"""
Central dedup hash store: single source of truth for "have we seen this pair?"

Every writer (dataset-enrich, GitHub crawler, agentic crawler, orchestrate,
threat-intel, SRE postmortem, synthetic-data) MUST call DedupStore.is_new()
before appending to training-pairs.jsonl.

Hash: md5(prompt[:500])[:16] -> 64-bit collision space, ~1 in 10^19 false-dup
rate at our scale. Stored in ~/.surrogate/state/dedup.db (SQLite, thread-safe).

Usage:
    from lib.dedup import DedupStore

    if DedupStore.is_new(prompt, source="github-crawl-pr"):
        write_pair(...)
    # else skip
"""

from __future__ import annotations

import hashlib
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Iterable

DB_PATH = Path.home() / ".surrogate/state/dedup.db"
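
# In WAL mode SQLite keeps two sidecar files next to the DB (dedup.db-wal,
# dedup.db-shm); the corruption-reset paths below remove them together with
# the main file so a stale WAL cannot re-poison a freshly created database.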


class DedupStore:
    _lock = threading.Lock()
    _conn: sqlite3.Connection | None = None

    @classmethod
    def _connection(cls) -> sqlite3.Connection:
        if cls._conn is None:
            DB_PATH.parent.mkdir(parents=True, exist_ok=True)
            # Auto-recover from corruption (16 parallel shards can corrupt SQLite)
            for attempt in range(3):
                try:
                    c = sqlite3.connect(str(DB_PATH), check_same_thread=False,
                                        timeout=30, isolation_level=None)
                    c.execute("PRAGMA journal_mode=WAL")
                    c.execute("PRAGMA synchronous=NORMAL")
                    c.execute("PRAGMA busy_timeout=30000")  # 30s wait on lock
                    c.execute("PRAGMA wal_autocheckpoint=1000")
                    c.executescript("""
                        CREATE TABLE IF NOT EXISTS seen_hashes (
                            hash TEXT PRIMARY KEY,
                            source TEXT NOT NULL,
                            ts INTEGER NOT NULL
                        );
                        CREATE INDEX IF NOT EXISTS idx_seen_source ON seen_hashes(source);
                        CREATE INDEX IF NOT EXISTS idx_seen_ts ON seen_hashes(ts);
                    """)
                    # Smoke-test the table
                    c.execute("SELECT 1 FROM seen_hashes LIMIT 1").fetchall()
                    cls._conn = c
                    break
                except sqlite3.DatabaseError as e:
                    if "malformed" in str(e).lower() or "corrupt" in str(e).lower():
                        # Backup + reset corrupted DB, then retry once more
                        backup = DB_PATH.with_suffix(f".corrupt-{int(time.time())}.bak")
                        try:
                            DB_PATH.rename(backup)
                            for ext in ("-wal", "-shm"):
                                p = Path(str(DB_PATH) + ext)
                                if p.exists():
                                    p.unlink()
                        except Exception:
                            pass
                        if attempt < 2:
                            continue
                    raise
        return cls._conn

    @classmethod
    def hash_key(cls, prompt: str) -> str:
        return hashlib.md5(prompt[:500].encode("utf-8", errors="ignore")).hexdigest()[:16]
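
    # Note: only the first 500 chars feed the hash, so two prompts that diverge
    # only after character 500 map to the same key and the later one is skipped
    # as a duplicate.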

    @classmethod
    def _force_reset(cls) -> None:
        """Backup the corrupt DB and clear the cached connection so the next
        _connection() call rebuilds from scratch. Caller must hold cls._lock."""
        if cls._conn is not None:
            try:
                cls._conn.close()
            except Exception:
                pass
        cls._conn = None
        try:
            if DB_PATH.exists():
                backup = DB_PATH.with_suffix(f".corrupt-{int(time.time())}.bak")
                DB_PATH.rename(backup)
            for ext in ("-wal", "-shm"):
                p = Path(str(DB_PATH) + ext)
                if p.exists():
                    p.unlink()
        except Exception:
            pass

    @staticmethod
    def _is_corruption(e: Exception) -> bool:
        msg = str(e).lower()
        return "malformed" in msg or "corrupt" in msg or "not a database" in msg

    @staticmethod
    def _is_transient(e: Exception) -> bool:
        """Disk I/O contention or lock pile-up under 16 parallel writers.
        Caller should backoff + retry, NOT wipe the DB."""
        msg = str(e).lower()
        return ("disk i/o error" in msg or "database is locked" in msg
                or "cannot start a transaction" in msg)

    @classmethod
    def is_new(cls, prompt: str, source: str = "unknown") -> bool:
        """Atomic check-and-insert. Returns True if the hash was newly added
        (writer should emit the pair); False if already seen (writer should skip).

        Resilient against:
          - hard corruption -> reset DB once, retry
          - transient I/O / lock contention -> backoff + retry up to 3x
        """
        if not prompt:
            return False
        h = cls.hash_key(prompt)
        for attempt in range(4):
            try:
                with cls._lock:
                    con = cls._connection()
                    cur = con.execute(
                        "INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
                        (h, source, int(time.time())),
                    )
                    con.commit()
                    return cur.rowcount > 0
            except sqlite3.DatabaseError as e:
                if cls._is_corruption(e) and attempt == 0:
                    with cls._lock:
                        cls._force_reset()
                    continue
                if cls._is_transient(e) and attempt < 3:
                    time.sleep(0.4 * (2 ** attempt))  # 0.4s, 0.8s, 1.6s backoff
                    continue
                # Last resort: don't crash the caller. Treat the pair as new so a
                # single retry-exhaustion can't silently drop a whole batch; the
                # worst case is one duplicate slipping through.
                return True

    @classmethod
    def bulk_seen(cls, prompts: Iterable[str], source: str = "bootstrap") -> int:
        """Mark a batch of prompts as seen. Returns the count newly added.
        Same resilience model as is_new()."""
        rows = [(cls.hash_key(p), source, int(time.time())) for p in prompts if p]
        if not rows:
            return 0
        for attempt in range(4):
            try:
                with cls._lock:
                    con = cls._connection()
                    before = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
                    con.executemany(
                        "INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
                        rows,
                    )
                    con.commit()
                    after = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
                    return after - before
            except sqlite3.DatabaseError as e:
                if cls._is_corruption(e) and attempt == 0:
                    with cls._lock:
                        cls._force_reset()
                    continue
                if cls._is_transient(e) and attempt < 3:
                    time.sleep(0.4 * (2 ** attempt))
                    continue
                return 0
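
    # Design note: the newly-added count is taken as a COUNT(*) delta under
    # cls._lock rather than via cursor.rowcount after executemany(). The delta
    # is exact for this process; a second *process* writing between the two
    # counts could skew it slightly.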

    @classmethod
    def stats(cls) -> dict:
        """Return DB stats. Safe against corruption: resets and returns empty
        stats rather than crashing the caller (callers like dataset-enrich
        treat stats as diagnostic, not load-bearing)."""
        for attempt in range(2):
            try:
                with cls._lock:
                    con = cls._connection()
                    total = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
                    by_source = dict(con.execute(
                        "SELECT source, COUNT(*) FROM seen_hashes GROUP BY source ORDER BY 2 DESC LIMIT 20"
                    ).fetchall())
                    mn, mx = con.execute("SELECT MIN(ts), MAX(ts) FROM seen_hashes").fetchone()
                    return {"total": total, "by_source": by_source, "first_ts": mn, "latest_ts": mx}
            except sqlite3.DatabaseError as e:
                if cls._is_corruption(e) and attempt == 0:
                    with cls._lock:
                        cls._force_reset()
                    continue
                if cls._is_transient(e) and attempt < 1:
                    time.sleep(0.5)
                    continue
                # Last resort: never let stats() crash a caller
                return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None,
                        "error": str(e)}
        return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None}


def write_pair_dedup(record: dict, output_path: Path | str, prompt: str | None = None) -> bool:
    """Convenience helper: only append record if its prompt is new.
    Returns True if written, False if skipped as a duplicate.
    """
    p = prompt or record.get("prompt") or record.get("instruction") or ""
    if not p:
        return False
    src = record.get("source", "unknown")
    if not DedupStore.is_new(p, src):
        return False
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return True
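
# Example (record shape and output path are illustrative, not a fixed schema):
#   write_pair_dedup(
#       {"prompt": "...", "completion": "...", "source": "github-crawl-pr"},
#       "data/training-pairs.jsonl",
#   )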


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "stats":
        print(json.dumps(DedupStore.stats(), indent=2))
    elif len(sys.argv) > 1 and sys.argv[1] == "bootstrap":
        # Read jsonl from stdin, mark all prompts as seen
        added = 0
        prompts = []
        src = sys.argv[2] if len(sys.argv) > 2 else "bootstrap"
        for line in sys.stdin:
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(d, dict):
                continue
            p = d.get("prompt") or d.get("instruction")
            if p:
                prompts.append(p)
            if len(prompts) >= 5000:
                added += DedupStore.bulk_seen(prompts, src)
                prompts = []
        if prompts:
            added += DedupStore.bulk_seen(prompts, src)
        print(f"bootstrapped {added} new hashes (source={src})")
    else:
        print("usage: dedup.py stats | bootstrap [source] < input.jsonl")