mx_tag: concurrent DNS resolution (40 workers, 3s timeout) for bulk speed
The serial path (verifier's 8s+6s lifetime per domain) was far too slow for bulk tagging -- 0 tagged in 3 min on dead domains. Self-contained fast resolver + ThreadPoolExecutor(40) resolves thousands of domains in minutes.
This commit is contained in:
parent
9e40965092
commit
60e6dc5d19
1 changed files with 58 additions and 11 deletions
|
|
@ -25,11 +25,49 @@ if ROOT not in sys.path:
|
||||||
sys.path.insert(0, ROOT)
|
sys.path.insert(0, ROOT)
|
||||||
|
|
||||||
import psycopg2 # noqa: E402
|
import psycopg2 # noqa: E402
|
||||||
# Reuse the verifier's MX classifier + cache (one source of truth, no extra DNS).
|
import dns.resolver # noqa: E402
|
||||||
from scripts.verify_csv_emails import mx_provider, get_mx_hosts # noqa: E402
|
from concurrent.futures import ThreadPoolExecutor # noqa: E402
|
||||||
|
|
||||||
DB_URL = os.getenv("DATABASE_URL", "")
|
DB_URL = os.getenv("DATABASE_URL", "")
|
||||||
|
|
||||||
|
# Fast, bounded MX resolution (the shared verifier uses an 8s+6s lifetime which
|
||||||
|
# is far too slow for bulk tagging across many dead domains). 3s, no A-fallback.
|
||||||
|
_resolver = dns.resolver.Resolver()
|
||||||
|
_resolver.lifetime = 3.0
|
||||||
|
_resolver.timeout = 3.0
|
||||||
|
|
||||||
|
|
||||||
|
def classify(domain: str) -> str:
|
||||||
|
try:
|
||||||
|
ans = _resolver.resolve(domain, "MX")
|
||||||
|
hosts = " ".join(str(r.exchange).rstrip(".") for r in ans).lower()
|
||||||
|
except Exception:
|
||||||
|
return "no_mx"
|
||||||
|
if "protection.outlook" in hosts or "outlook.com" in hosts or "office365" in hosts:
|
||||||
|
return "microsoft"
|
||||||
|
if "aspmx.l.google" in hosts or "googlemail" in hosts or "google.com" in hosts:
|
||||||
|
return "google"
|
||||||
|
if "pphosted" in hosts or "ppe-hosted" in hosts or "proofpoint" in hosts:
|
||||||
|
return "proofpoint"
|
||||||
|
if "mimecast" in hosts:
|
||||||
|
return "mimecast"
|
||||||
|
if "iphmx.com" in hosts or "cisco" in hosts:
|
||||||
|
return "cisco"
|
||||||
|
if "barracuda" in hosts:
|
||||||
|
return "barracuda"
|
||||||
|
if "messagelabs" in hosts or "symantec" in hosts or "broadcom" in hosts:
|
||||||
|
return "broadcom"
|
||||||
|
if "secureserver.net" in hosts:
|
||||||
|
return "godaddy"
|
||||||
|
if "zoho" in hosts:
|
||||||
|
return "zoho"
|
||||||
|
if "emailsrvr" in hosts or "rackspace" in hosts:
|
||||||
|
return "rackspace"
|
||||||
|
if not hosts.strip():
|
||||||
|
return "no_mx"
|
||||||
|
root = hosts.split()[0].rstrip(".").split(".")
|
||||||
|
return "mx:" + (".".join(root[-2:]) if len(root) >= 2 else hosts.split()[0])
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
|
|
@ -58,21 +96,30 @@ def main() -> int:
|
||||||
LIMIT %s
|
LIMIT %s
|
||||||
""", (args.limit_domains,))
|
""", (args.limit_domains,))
|
||||||
domains = [r[0] for r in cur.fetchall() if r[0]]
|
domains = [r[0] for r in cur.fetchall() if r[0]]
|
||||||
print(f"resolving MX for {len(domains)} distinct domains...", file=sys.stderr)
|
print(f"resolving MX for {len(domains)} distinct domains (concurrent)...", file=sys.stderr)
|
||||||
|
|
||||||
tagged_domains = 0
|
# Resolve concurrently -- DNS is I/O-bound, so a thread pool gives a huge
|
||||||
for i, dom in enumerate(domains, 1):
|
# speedup over the serial 3s-timeout-per-domain path.
|
||||||
get_mx_hosts(dom) # populates the cache (DNS)
|
results: dict[str, str] = {}
|
||||||
prov = mx_provider(dom) # classify from cache
|
done = 0
|
||||||
|
with ThreadPoolExecutor(max_workers=40) as ex:
|
||||||
|
for dom, prov in zip(domains, ex.map(classify, domains)):
|
||||||
|
results[dom] = prov
|
||||||
|
done += 1
|
||||||
|
if done % 500 == 0:
|
||||||
|
print(f" resolved {done}/{len(domains)}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Batch-write the tags.
|
||||||
|
tagged = 0
|
||||||
|
for dom, prov in results.items():
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
UPDATE fmcsa_carriers SET mx_provider = %s
|
UPDATE fmcsa_carriers SET mx_provider = %s
|
||||||
WHERE lower(split_part(email_address, '@', 2)) = %s
|
WHERE lower(split_part(email_address, '@', 2)) = %s
|
||||||
AND mx_provider IS NULL
|
AND mx_provider IS NULL
|
||||||
""", (prov, dom))
|
""", (prov, dom))
|
||||||
tagged_domains += 1
|
tagged += 1
|
||||||
if i % 200 == 0:
|
if tagged % 500 == 0:
|
||||||
conn.commit()
|
conn.commit()
|
||||||
print(f" {i}/{len(domains)} domains", file=sys.stderr)
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# Report the operator distribution of what we just tagged.
|
# Report the operator distribution of what we just tagged.
|
||||||
|
|
@ -84,7 +131,7 @@ def main() -> int:
|
||||||
for prov, n in cur.fetchall():
|
for prov, n in cur.fetchall():
|
||||||
print(f" {prov}: {n:,}", file=sys.stderr)
|
print(f" {prov}: {n:,}", file=sys.stderr)
|
||||||
conn.close()
|
conn.close()
|
||||||
print(f"done: tagged {tagged_domains} domains", file=sys.stderr)
|
print(f"done: tagged {tagged} domains", file=sys.stderr)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue