mx_tag: concurrent DNS resolution (40 workers, 3s timeout) for bulk speed

The serial path (verifier's 8s+6s lifetime per domain) was far too slow for
bulk tagging -- 0 tagged in 3 min on dead domains. Self-contained fast resolver
+ ThreadPoolExecutor(40) resolves thousands of domains in minutes.
This commit is contained in:
justin 2026-06-14 21:16:17 -05:00
parent 9e40965092
commit 60e6dc5d19

View file

@ -25,11 +25,49 @@ if ROOT not in sys.path:
sys.path.insert(0, ROOT) sys.path.insert(0, ROOT)
import psycopg2 # noqa: E402 import psycopg2 # noqa: E402
# Reuse the verifier's MX classifier + cache (one source of truth, no extra DNS). import dns.resolver # noqa: E402
from scripts.verify_csv_emails import mx_provider, get_mx_hosts # noqa: E402 from concurrent.futures import ThreadPoolExecutor # noqa: E402
DB_URL = os.getenv("DATABASE_URL", "") DB_URL = os.getenv("DATABASE_URL", "")
# Fast, bounded MX resolution (the shared verifier uses an 8s+6s lifetime which
# is far too slow for bulk tagging across many dead domains). 3s, no A-fallback.
_resolver = dns.resolver.Resolver()
_resolver.lifetime = 3.0
_resolver.timeout = 3.0
def classify(domain: str) -> str:
try:
ans = _resolver.resolve(domain, "MX")
hosts = " ".join(str(r.exchange).rstrip(".") for r in ans).lower()
except Exception:
return "no_mx"
if "protection.outlook" in hosts or "outlook.com" in hosts or "office365" in hosts:
return "microsoft"
if "aspmx.l.google" in hosts or "googlemail" in hosts or "google.com" in hosts:
return "google"
if "pphosted" in hosts or "ppe-hosted" in hosts or "proofpoint" in hosts:
return "proofpoint"
if "mimecast" in hosts:
return "mimecast"
if "iphmx.com" in hosts or "cisco" in hosts:
return "cisco"
if "barracuda" in hosts:
return "barracuda"
if "messagelabs" in hosts or "symantec" in hosts or "broadcom" in hosts:
return "broadcom"
if "secureserver.net" in hosts:
return "godaddy"
if "zoho" in hosts:
return "zoho"
if "emailsrvr" in hosts or "rackspace" in hosts:
return "rackspace"
if not hosts.strip():
return "no_mx"
root = hosts.split()[0].rstrip(".").split(".")
return "mx:" + (".".join(root[-2:]) if len(root) >= 2 else hosts.split()[0])
def main() -> int: def main() -> int:
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
@ -58,21 +96,30 @@ def main() -> int:
LIMIT %s LIMIT %s
""", (args.limit_domains,)) """, (args.limit_domains,))
domains = [r[0] for r in cur.fetchall() if r[0]] domains = [r[0] for r in cur.fetchall() if r[0]]
print(f"resolving MX for {len(domains)} distinct domains...", file=sys.stderr) print(f"resolving MX for {len(domains)} distinct domains (concurrent)...", file=sys.stderr)
tagged_domains = 0 # Resolve concurrently -- DNS is I/O-bound, so a thread pool gives a huge
for i, dom in enumerate(domains, 1): # speedup over the serial 3s-timeout-per-domain path.
get_mx_hosts(dom) # populates the cache (DNS) results: dict[str, str] = {}
prov = mx_provider(dom) # classify from cache done = 0
with ThreadPoolExecutor(max_workers=40) as ex:
for dom, prov in zip(domains, ex.map(classify, domains)):
results[dom] = prov
done += 1
if done % 500 == 0:
print(f" resolved {done}/{len(domains)}", file=sys.stderr)
# Batch-write the tags.
tagged = 0
for dom, prov in results.items():
cur.execute(""" cur.execute("""
UPDATE fmcsa_carriers SET mx_provider = %s UPDATE fmcsa_carriers SET mx_provider = %s
WHERE lower(split_part(email_address, '@', 2)) = %s WHERE lower(split_part(email_address, '@', 2)) = %s
AND mx_provider IS NULL AND mx_provider IS NULL
""", (prov, dom)) """, (prov, dom))
tagged_domains += 1 tagged += 1
if i % 200 == 0: if tagged % 500 == 0:
conn.commit() conn.commit()
print(f" {i}/{len(domains)} domains", file=sys.stderr)
conn.commit() conn.commit()
# Report the operator distribution of what we just tagged. # Report the operator distribution of what we just tagged.
@ -84,7 +131,7 @@ def main() -> int:
for prov, n in cur.fetchall(): for prov, n in cur.fetchall():
print(f" {prov}: {n:,}", file=sys.stderr) print(f" {prov}: {n:,}", file=sys.stderr)
conn.close() conn.close()
print(f"done: tagged {tagged_domains} domains", file=sys.stderr) print(f"done: tagged {tagged} domains", file=sys.stderr)
return 0 return 0