"""
Central dedup hash store: single source of truth for "have we seen this pair?"
Every writer (dataset-enrich, GitHub crawler, agentic crawler, orchestrate,
threat-intel, SRE postmortem, synthetic-data) MUST call DedupStore.is_new()
before appending to training-pairs.jsonl.
Hash: md5(prompt[:500])[:16] gives a 64-bit collision space; a new prompt
false-matches one of n stored hashes with probability ~n/2^64 (2^64 is about
1.8e19), negligible at our scale. Stored in ~/.surrogate/state/dedup.db
(SQLite, thread-safe).
Usage:

    from lib.dedup import DedupStore

    if DedupStore.is_new(prompt, source="github-crawl-pr"):
        write_pair(...)
    # else: skip the duplicate
"""
from __future__ import annotations
import hashlib
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Iterable
DB_PATH = Path.home() / ".surrogate/state/dedup.db"
class DedupStore:
_lock = threading.Lock()
_conn: sqlite3.Connection | None = None
@classmethod
def _connection(cls) -> sqlite3.Connection:
if cls._conn is None:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
# Auto-recover from corruption (16 parallel shards can corrupt SQLite)
for attempt in range(3):
try:
c = sqlite3.connect(str(DB_PATH), check_same_thread=False,
timeout=30, isolation_level=None)
c.execute("PRAGMA journal_mode=WAL")
c.execute("PRAGMA synchronous=NORMAL")
c.execute("PRAGMA busy_timeout=30000") # 30s wait on lock
c.execute("PRAGMA wal_autocheckpoint=1000")
c.executescript("""
CREATE TABLE IF NOT EXISTS seen_hashes (
hash TEXT PRIMARY KEY,
source TEXT NOT NULL,
ts INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_seen_source ON seen_hashes(source);
CREATE INDEX IF NOT EXISTS idx_seen_ts ON seen_hashes(ts);
""")
# Smoke-test the table
c.execute("SELECT 1 FROM seen_hashes LIMIT 1").fetchall()
cls._conn = c
break
                except sqlite3.DatabaseError as e:
                    if cls._is_corruption(e):
                        # Back up + reset the corrupt files so the next attempt
                        # rebuilds from scratch (cls._conn is still None here,
                        # so _force_reset() only moves the files aside).
                        cls._force_reset()
                    if attempt < 2:
                        continue
                    raise
return cls._conn
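
    # If a shard keeps tripping the corruption path, the DB can be inspected
    # by hand with the stock sqlite3 CLI (path as configured above):
    #
    #   sqlite3 ~/.surrogate/state/dedup.db "PRAGMA integrity_check;"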
    @classmethod
    def hash_key(cls, prompt: str) -> str:
        # MD5 here is a non-adversarial fingerprint, not a security boundary.
        # Only prompt[:500] is hashed, so pairs sharing their first 500 chars
        # are deliberately collapsed into one.
        return hashlib.md5(prompt[:500].encode("utf-8", errors="ignore")).hexdigest()[:16]
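
    # Back-of-envelope for the 16-hex-char (64-bit) truncation above, pure
    # arithmetic with an assumed corpus size: against n stored hashes, a new
    # prompt false-matches with probability ~ n / 2**64.
    #
    #   >>> n = 10_000_000                 # hypothetical corpus size
    #   >>> f"{n / 2**64:.2e}"
    #   '5.42e-13'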
@classmethod
def _force_reset(cls) -> None:
"""Backup the corrupt DB and clear the cached connection so the next
_connection() call rebuilds from scratch. Caller must hold cls._lock."""
if cls._conn is not None:
try:
cls._conn.close()
except Exception:
pass
cls._conn = None
try:
if DB_PATH.exists():
backup = DB_PATH.with_suffix(f".corrupt-{int(time.time())}.bak")
DB_PATH.rename(backup)
for ext in ("-wal", "-shm"):
p = Path(str(DB_PATH) + ext)
if p.exists():
p.unlink()
except Exception:
pass
@staticmethod
def _is_corruption(e: Exception) -> bool:
msg = str(e).lower()
return "malformed" in msg or "corrupt" in msg or "not a database" in msg
@staticmethod
def _is_transient(e: Exception) -> bool:
"""Disk I/O contention or lock pile-up under 16 parallel writers.
Caller should backoff + retry, NOT wipe the DB."""
msg = str(e).lower()
return ("disk i/o error" in msg or "database is locked" in msg
or "cannot start a transaction" in msg)
@classmethod
def is_new(cls, prompt: str, source: str = "unknown") -> bool:
"""Atomic check-and-insert. Returns True if hash newly added (writer should
emit the pair); False if already seen (writer should skip).
Resilient against:
- hard corruption -> reset DB once, retry
- transient I/O / lock contention -> backoff + retry up to 3x"""
if not prompt:
return False
h = cls.hash_key(prompt)
for attempt in range(4):
try:
with cls._lock:
con = cls._connection()
cur = con.execute(
"INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
(h, source, int(time.time())),
)
con.commit()
return cur.rowcount > 0
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 3:
time.sleep(0.4 * (2 ** attempt)) # 0.4s, 0.8s, 1.6s backoff
continue
                # Last resort: don't crash the caller. Treat the pair as new;
                # the worst case is one duplicate slipping through, which beats
                # losing the whole batch to retry exhaustion.
                return True
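
    # Writer-side sketch (crawl_results() and write_pair() are hypothetical
    # stand-ins for a real writer, mirroring the module docstring):
    #
    #   for rec in crawl_results():
    #       if DedupStore.is_new(rec["prompt"], source="github-crawl-pr"):
    #           write_pair(rec)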
@classmethod
def bulk_seen(cls, prompts: Iterable[str], source: str = "bootstrap") -> int:
"""Mark a batch of prompts as seen. Returns count newly added.
Same resilience model as is_new()."""
rows = [(cls.hash_key(p), source, int(time.time())) for p in prompts if p]
if not rows:
return 0
for attempt in range(4):
try:
with cls._lock:
con = cls._connection()
                    # executemany sums affected rows across the batch, and
                    # INSERT OR IGNORE skips rows that already exist, so
                    # rowcount is exactly the newly-added count. (A
                    # before/after COUNT(*) would miscount under concurrent
                    # cross-process writers.)
                    cur = con.executemany(
                        "INSERT OR IGNORE INTO seen_hashes (hash, source, ts) VALUES (?, ?, ?)",
                        rows,
                    )
                    con.commit()
                    return cur.rowcount
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 3:
time.sleep(0.4 * (2 ** attempt))
continue
return 0
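
    # Bootstrap sketch: seed the store from an existing JSONL dump. The file
    # name and field name are assumptions mirroring the __main__ handler below:
    #
    #   import json
    #   with open("training-pairs.jsonl") as f:
    #       prompts = [json.loads(line).get("prompt", "") for line in f]
    #   added = DedupStore.bulk_seen(prompts, source="bootstrap")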
@classmethod
def stats(cls) -> dict:
"""Return DB stats. Safe against corruption β resets and returns empty
stats rather than crashing the caller (callers like dataset-enrich
treat stats as diagnostic, not load-bearing)."""
for attempt in range(2):
try:
with cls._lock:
con = cls._connection()
total = con.execute("SELECT COUNT(*) FROM seen_hashes").fetchone()[0]
by_source = dict(con.execute(
"SELECT source, COUNT(*) FROM seen_hashes GROUP BY source ORDER BY 2 DESC LIMIT 20"
).fetchall())
mn, mx = con.execute("SELECT MIN(ts), MAX(ts) FROM seen_hashes").fetchone()
return {"total": total, "by_source": by_source, "first_ts": mn, "latest_ts": mx}
except sqlite3.DatabaseError as e:
if cls._is_corruption(e) and attempt == 0:
with cls._lock:
cls._force_reset()
continue
if cls._is_transient(e) and attempt < 1:
time.sleep(0.5)
continue
# Last-resort: never let stats() crash a caller
return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None, "error": str(e)}
return {"total": 0, "by_source": {}, "first_ts": None, "latest_ts": None}
def write_pair_dedup(record: dict, output_path: Path | str, prompt: str | None = None) -> bool:
"""Convenience helper: only append record if its prompt is new.
Returns True if written, False if skipped as duplicate.
"""
p = prompt or record.get("prompt") or record.get("instruction") or ""
if not p:
return False
src = record.get("source", "unknown")
if not DedupStore.is_new(p, src):
return False
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
return True
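
# Caller-side sketch for write_pair_dedup (record fields and the output path
# are illustrative; any dict with a "prompt" or "instruction" key works):
#
#   rec = {"prompt": "Explain WAL mode", "response": "...", "source": "docs"}
#   if write_pair_dedup(rec, "data/training-pairs.jsonl"):
#       print("written")
#   else:
#       print("skipped duplicate")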
if __name__ == "__main__":
    import sys
if len(sys.argv) > 1 and sys.argv[1] == "stats":
print(json.dumps(DedupStore.stats(), indent=2))
elif len(sys.argv) > 1 and sys.argv[1] == "bootstrap":
# Read jsonl from stdin, mark all prompts as seen
added = 0
prompts = []
src = sys.argv[2] if len(sys.argv) > 2 else "bootstrap"
        for line in sys.stdin:
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip malformed lines instead of aborting the stream
            if not isinstance(d, dict):
                continue
            p = d.get("prompt") or d.get("instruction")
            if p:
                prompts.append(p)
            if len(prompts) >= 5000:  # flush in chunks to bound memory
                added += DedupStore.bulk_seen(prompts, src)
                prompts = []
if prompts:
added += DedupStore.bulk_seen(prompts, src)
print(f"bootstrapped {added} new hashes (source={src})")
else:
print("usage: dedup.py stats | bootstrap [source] < input.jsonl")