diff --git a/scripts/mx_tag_carriers.py b/scripts/mx_tag_carriers.py index d2ebaaa..ac189c6 100644 --- a/scripts/mx_tag_carriers.py +++ b/scripts/mx_tag_carriers.py @@ -97,27 +97,36 @@ def main() -> int: domains = [r[0] for r in cur.fetchall() if r[0]] print(f"resolving MX for {len(domains)} distinct domains (concurrent)...", file=sys.stderr) - # Resolve concurrently and write incrementally as each completes, so a few - # slow/hung domains never hold up the whole batch (as_completed, not map). + # Resolve concurrently (DNS is I/O-bound; per-call Resolver = thread-safe). from concurrent.futures import as_completed - tagged = 0 + resolved: list[tuple[str, str]] = [] with ThreadPoolExecutor(max_workers=60) as ex: futs = {ex.submit(classify, d): d for d in domains} - for fut in as_completed(futs): + for i, fut in enumerate(as_completed(futs), 1): dom = futs[fut] try: - prov = fut.result() + resolved.append((dom, fut.result())) except Exception: - prov = "no_mx" - cur.execute(""" - UPDATE fmcsa_carriers SET mx_provider = %s - WHERE lower(split_part(email_address, '@', 2)) = %s - AND mx_provider IS NULL - """, (prov, dom)) - tagged += 1 - if tagged % 300 == 0: - conn.commit() - print(f" tagged {tagged}/{len(domains)} domains", file=sys.stderr) + resolved.append((dom, "no_mx")) + if i % 1000 == 0: + print(f" resolved {i}/{len(domains)}", file=sys.stderr) + + # ONE bulk UPDATE via a temp table + join. The per-domain UPDATE was doing a + # full 1.49M-row scan EACH time (no functional index on the email-domain + # expression); the join scans the table once. + print(f"bulk-writing {len(resolved)} domain tags...", file=sys.stderr) + cur.execute("CREATE TEMP TABLE _mx_map (domain text PRIMARY KEY, provider text) ON COMMIT DROP") + from psycopg2.extras import execute_values + execute_values(cur, "INSERT INTO _mx_map (domain, provider) VALUES %s ON CONFLICT (domain) DO NOTHING", + resolved, page_size=1000) + cur.execute(""" + UPDATE fmcsa_carriers f + SET mx_provider = m.provider + FROM _mx_map m + WHERE lower(split_part(f.email_address, '@', 2)) = m.domain + AND f.mx_provider IS NULL + """) + tagged = cur.rowcount conn.commit() # Report the operator distribution of what we just tagged.