mx_tag: concurrent DNS resolution (40 workers, 3s timeout) for bulk speed
The serial path (verifier's 8s+6s lifetime per domain) was far too slow for bulk tagging -- 0 tagged in 3 min on dead domains. Self-contained fast resolver + ThreadPoolExecutor(40) resolves thousands of domains in minutes.
This commit is contained in:
parent
9e40965092
commit
60e6dc5d19
1 changed files with 58 additions and 11 deletions
|
|
@ -25,11 +25,49 @@ if ROOT not in sys.path:
|
|||
sys.path.insert(0, ROOT)
|
||||
|
||||
import psycopg2 # noqa: E402
|
||||
# Reuse the verifier's MX classifier + cache (one source of truth, no extra DNS).
|
||||
from scripts.verify_csv_emails import mx_provider, get_mx_hosts # noqa: E402
|
||||
import dns.resolver # noqa: E402
|
||||
from concurrent.futures import ThreadPoolExecutor # noqa: E402
|
||||
|
||||
DB_URL = os.getenv("DATABASE_URL", "")
|
||||
|
||||
# Fast, bounded MX resolution (the shared verifier uses an 8s+6s lifetime which
|
||||
# is far too slow for bulk tagging across many dead domains). 3s, no A-fallback.
|
||||
_resolver = dns.resolver.Resolver()
|
||||
_resolver.lifetime = 3.0
|
||||
_resolver.timeout = 3.0
|
||||
|
||||
|
||||
def classify(domain: str) -> str:
|
||||
try:
|
||||
ans = _resolver.resolve(domain, "MX")
|
||||
hosts = " ".join(str(r.exchange).rstrip(".") for r in ans).lower()
|
||||
except Exception:
|
||||
return "no_mx"
|
||||
if "protection.outlook" in hosts or "outlook.com" in hosts or "office365" in hosts:
|
||||
return "microsoft"
|
||||
if "aspmx.l.google" in hosts or "googlemail" in hosts or "google.com" in hosts:
|
||||
return "google"
|
||||
if "pphosted" in hosts or "ppe-hosted" in hosts or "proofpoint" in hosts:
|
||||
return "proofpoint"
|
||||
if "mimecast" in hosts:
|
||||
return "mimecast"
|
||||
if "iphmx.com" in hosts or "cisco" in hosts:
|
||||
return "cisco"
|
||||
if "barracuda" in hosts:
|
||||
return "barracuda"
|
||||
if "messagelabs" in hosts or "symantec" in hosts or "broadcom" in hosts:
|
||||
return "broadcom"
|
||||
if "secureserver.net" in hosts:
|
||||
return "godaddy"
|
||||
if "zoho" in hosts:
|
||||
return "zoho"
|
||||
if "emailsrvr" in hosts or "rackspace" in hosts:
|
||||
return "rackspace"
|
||||
if not hosts.strip():
|
||||
return "no_mx"
|
||||
root = hosts.split()[0].rstrip(".").split(".")
|
||||
return "mx:" + (".".join(root[-2:]) if len(root) >= 2 else hosts.split()[0])
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
|
|
@ -58,21 +96,30 @@ def main() -> int:
|
|||
LIMIT %s
|
||||
""", (args.limit_domains,))
|
||||
domains = [r[0] for r in cur.fetchall() if r[0]]
|
||||
print(f"resolving MX for {len(domains)} distinct domains...", file=sys.stderr)
|
||||
print(f"resolving MX for {len(domains)} distinct domains (concurrent)...", file=sys.stderr)
|
||||
|
||||
tagged_domains = 0
|
||||
for i, dom in enumerate(domains, 1):
|
||||
get_mx_hosts(dom) # populates the cache (DNS)
|
||||
prov = mx_provider(dom) # classify from cache
|
||||
# Resolve concurrently -- DNS is I/O-bound, so a thread pool gives a huge
|
||||
# speedup over the serial 3s-timeout-per-domain path.
|
||||
results: dict[str, str] = {}
|
||||
done = 0
|
||||
with ThreadPoolExecutor(max_workers=40) as ex:
|
||||
for dom, prov in zip(domains, ex.map(classify, domains)):
|
||||
results[dom] = prov
|
||||
done += 1
|
||||
if done % 500 == 0:
|
||||
print(f" resolved {done}/{len(domains)}", file=sys.stderr)
|
||||
|
||||
# Batch-write the tags.
|
||||
tagged = 0
|
||||
for dom, prov in results.items():
|
||||
cur.execute("""
|
||||
UPDATE fmcsa_carriers SET mx_provider = %s
|
||||
WHERE lower(split_part(email_address, '@', 2)) = %s
|
||||
AND mx_provider IS NULL
|
||||
""", (prov, dom))
|
||||
tagged_domains += 1
|
||||
if i % 200 == 0:
|
||||
tagged += 1
|
||||
if tagged % 500 == 0:
|
||||
conn.commit()
|
||||
print(f" {i}/{len(domains)} domains", file=sys.stderr)
|
||||
conn.commit()
|
||||
|
||||
# Report the operator distribution of what we just tagged.
|
||||
|
|
@ -84,7 +131,7 @@ def main() -> int:
|
|||
for prov, n in cur.fetchall():
|
||||
print(f" {prov}: {n:,}", file=sys.stderr)
|
||||
conn.close()
|
||||
print(f"done: tagged {tagged_domains} domains", file=sys.stderr)
|
||||
print(f"done: tagged {tagged} domains", file=sys.stderr)
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue