surrogate-1/bin/lib/dedup.py
Ashira Pitchayapakayakul
fix: 16-shard write contention crashed Space - flock + backoff retry
9bafe64
"""
Central dedup hash store - the single source of truth for "have we seen this pair?"

Every writer (dataset-enrich, GitHub crawler, agentic crawler, orchestrate,
threat-intel, SRE postmortem, synthetic-data) MUST call DedupStore.is_new()
before appending to training-pairs.jsonl.

Hash: md5(prompt[:500]).hexdigest()[:16] - a 64-bit collision space, so an
accidental match between two distinct prompts is roughly a 1-in-10^19 event,
negligible at our scale. Stored in ~/.surrogate/state/dedup.db (SQLite,
thread-safe).

Usage:

    from lib.dedup import DedupStore

    if DedupStore.is_new(prompt, source="github-crawl-pr"):
        write_pair(...)
    # else: skip the duplicate
"""
from __future__ import annotations
import hashlib
import sqlite3
import threading
import time
from pathlib import Path
from typing import Iterable
DB_PATH = Path.home() / ".surrogate/state/dedup.db"
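
# Concurrency model in brief: a process-local threading.Lock serialises access
# within one process, while SQLite's WAL journal mode plus the 30s busy_timeout
# absorbs contention between the parallel writer shards sharing dedup.db.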
class DedupStore:
_lock = threading.Lock()
_conn: sqlite3.Connection | None = None
@classmethod
def _connection(cls) -> sqlite3.Connection:
if cls._conn is None:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
# Auto-recover from corruption (16 parallel shards can corrupt SQLite)
for attempt in range(3):
try:
c = sqlite3.connect(str(DB_PATH), check_same_thread=False,
timeout=30, isolation_level=None)
                    c.execute("PRAGMA journal_mode=WAL")         # readers don't block the writer
                    c.execute("PRAGMA synchronous=NORMAL")       # fewer fsyncs; crash-safe under WAL
                    c.execute("PRAGMA busy_timeout=30000")       # wait up to 30s on a locked DB
                    c.execute("PRAGMA wal_autocheckpoint=1000")  # fold WAL into the DB every ~1000 pages
c.executescript("""
CREATE TABLE IF NOT EXISTS seen_hashes (
hash TEXT PRIMARY KEY,
source TEXT NOT NULL,
ts INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_seen_source ON seen_hashes(source);
CREATE INDEX IF NOT EXISTS idx_seen_ts ON seen_hashes(ts);
""")
# Smoke-test the table
c.execute("SELECT 1 FROM seen_hashes LIMIT 1").fetchall()
cls._conn = c
break
                except sqlite3.DatabaseError as e:
                    if cls._is_corruption(e):
                        # Back up the corrupted DB and its WAL/SHM sidecars, then retry
                        backup = DB_PATH.with_suffix(f".corrupt-{int(time.time())}.bak")
                        try:
                            DB_PATH.rename(backup)
                            for ext in ("-wal", "-shm"):
                                p = Path(str(DB_PATH) + ext)
                                if p.exists():
                                    p.unlink()
                        except Exception:
                            pass
                        if attempt < 2:
                            continue
                    raise
return cls._conn
@classmethod
def hash_key(cls, prompt: str) -> str:
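        # md5 here is a cheap fingerprint, not a security primitive: hashing only
        # the first 500 chars bounds the cost on very long prompts, and keeping
        # 16 hex chars (64 bits) matches the collision-space note in the docstring.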
return hashlib.md5(prompt[:500].encode("utf-8", errors="ignore")).hexdigest()[:16]
@classmethod
def _force_reset(cls) -> None:
"""Backup the corrupt DB and clear the cached connection so the next
_connection() call rebuilds from scratch. Caller must hold cls._lock."""
if cls._conn is not None:
try:
cls._conn.close()
except Exception:
pass
cls._conn = None
try:
if DB_PATH.exists():
backup = DB_PATH.with_suffix(f".corrupt-{int(time.time())}.bak")
DB_PATH.rename(backup)
for ext in ("-wal", "-shm"):
p = Path(str(DB_PATH) + ext)
if p.exists():
p.unlink()
except Exception:
pass
@staticmethod
def _is_corruption(e: Exception) -> bool:
msg = str(e).lower()
return "malformed" in msg or "corrupt" in msg or "not a database" in msg
@staticmethod
def _is_transient(e: Exception) -> bool:
"""Disk I/O contention or lock pile-up under 16 parallel writers.
Caller should backoff + retry, NOT wipe the DB."""
msg = str(e).lower()
return ("disk i/o error" in msg or "database is locked" in msg
or "cannot start a transaction" in msg)
@classmethod
def is_new(cls, prompt: str, source: str = "unknown") -> bool:
"""Atomic check-and-insert. Returns True if hash newly added (writer should
emit the pair); False if already seen (writer should skip).
Resilient against:
- hard corruption -> reset DB once, retry
- transient I/O / lock contention -> backoff + retry up to 3x"""
if not prompt:
return False
h = cls.hash_key(prompt)
for attempt in range(4):
try:
with cls._lock:
con = cls._connection()
cur = con.execute(
"INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
(h, source, int(time.time())),
)
con.commit()
return cur.rowcount > 0
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 3:
time.sleep(0.4 * (2 ** attempt)) # 0.4s, 0.8s, 1.6s backoff
continue
                # Last resort: don't crash the caller. Better to risk writing one
                # duplicate than to drop the pair (or the whole batch) after
                # retries are exhausted.
                return True  # treat as new; worst case is a single duplicate
@classmethod
def bulk_seen(cls, prompts: Iterable[str], source: str = "bootstrap") -> int:
"""Mark a batch of prompts as seen. Returns count newly added.
Same resilience model as is_new()."""
rows = [(cls.hash_key(p), source, int(time.time())) for p in prompts if p]
if not rows:
return 0
for attempt in range(4):
try:
with cls._lock:
con = cls._connection()
                    # COUNT(*) before/after measures newly added rows; another process
                    # committing between the two queries can skew it, so treat the
                    # return value as approximate.
                    before = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
con.executemany(
"INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
rows,
)
con.commit()
after = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
return after - before
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 3:
time.sleep(0.4 * (2 ** attempt))
continue
return 0
@classmethod
    def stats(cls) -> dict:
        """Return DB stats. Safe against corruption - resets and returns empty
stats rather than crashing the caller (callers like dataset-enrich
treat stats as diagnostic, not load-bearing)."""
for attempt in range(2):
try:
with cls._lock:
con = cls._connection()
total = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
by_source = dict(con.execute(
"SELECT source, COUNT(*) FROM seen_hashes GROUP BY source ORDER BY 2 DESC LIMIT 20"
).fetchall())
mn, mx = con.execute("SELECT MIN(ts), MAX(ts) FROM seen_hashes").fetchone()
return {"total": total, "by_source": by_source, "first_ts": mn, "latest_ts": mx}
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 1:
time.sleep(0.5)
continue
# Last-resort: never let stats() crash a caller
return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None, "error": str(e)}
return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None}
def write_pair_dedup(record: dict, output_path: Path | str, prompt: str | None = None) -> bool:
"""Convenience helper: only append record if its prompt is new.
Returns True if written, False if skipped as duplicate.
"""
p = prompt or record.get("prompt") or record.get("instruction") or ""
if not p:
return False
src = record.get("source", "unknown")
if not DedupStore.is_new(p, src):
return False
import json as _json
out = Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "a", encoding="utf-8") as f:
f.write(_json.dumps(record, ensure_ascii=False) + "\n")
return True
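
# Illustrative use of write_pair_dedup (a sketch; the record shape and output
# path below are assumptions, not part of this module's contract):
#
#     record = {"prompt": "...", "completion": "...", "source": "github-crawl-pr"}
#     write_pair_dedup(record, Path.home() / ".surrogate/data/training-pairs.jsonl")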
if __name__ == "__main__":
import sys, json
if len(sys.argv) > 1 and sys.argv[1] == "stats":
print(json.dumps(DedupStore.stats(), indent=2))
elif len(sys.argv) > 1 and sys.argv[1] == "bootstrap":
# Read jsonl from stdin, mark all prompts as seen
added = 0
prompts = []
src = sys.argv[2] if len(sys.argv) > 2 else "bootstrap"
        for line in sys.stdin:
            # Skip lines that aren't valid JSON objects
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(d, dict):
                continue
            p = d.get("prompt") or d.get("instruction")
            if p:
                prompts.append(p)
            if len(prompts) >= 5000:
                added += DedupStore.bulk_seen(prompts, src)
                prompts = []
if prompts:
added += DedupStore.bulk_seen(prompts, src)
print(f"bootstrapped {added} new hashes (source={src})")
else:
print("usage: dedup.py stats | bootstrap [source] < input.jsonl")
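
# Illustrative invocations (a sketch; file locations and source name are assumptions):
#   python dedup.py stats
#   python dedup.py bootstrap github-crawl-pr < training-pairs.jsonl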