"""
deduplication.py — Fast near-duplicate removal for large datasets.

Uses a two-phase strategy:

1. **Exact dedup** — hash-based O(n) removal of identical texts.
2. **Near-dedup via Sentence-BERT** — encode texts, compare embeddings by
   cosine similarity, and greedily drop every row whose similarity to an
   earlier *kept* row meets a configurable threshold.

Similarity is computed in row chunks with vectorised NumPy comparisons. The
comparison is full pairwise for moderate sizes and block-diagonal for very
large inputs, which keeps memory and runtime feasible on 100K+ rows.

The ``all-MiniLM-L6-v2`` model is used for embedding by default.
"""

import hashlib
import logging
import time
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "all-MiniLM-L6-v2"


# ═══════════════════════════════════════════════════════════
# Phase 1: Exact dedup (hash-based, O(n))
# ═══════════════════════════════════════════════════════════

def _exact_dedup(df: pd.DataFrame, text_column: str) -> Tuple[pd.DataFrame, int]:
    """Remove rows with identical text via SHA-256 hashing.

    The first occurrence of each text is kept. Missing values are treated as
    the empty string.

    Args:
        df: Input DataFrame.
        text_column: Column to hash for exact comparison.

    Returns:
        (deduplicated DataFrame with a fresh RangeIndex, number of rows removed).
    """
    before = len(df)
    seen: Set[bytes] = set()
    keep = np.zeros(before, dtype=bool)
    for pos, txt in enumerate(df[text_column].fillna("").astype(str)):
        # "surrogatepass" keeps the str -> bytes mapping injective. With
        # errors="replace", distinct strings such as "a\ud800" and "a?" both
        # encoded to b"a?" and were wrongly treated as duplicates.
        digest = hashlib.sha256(txt.encode("utf-8", errors="surrogatepass")).digest()
        if digest not in seen:
            seen.add(digest)
            keep[pos] = True

    df_out = df.loc[keep].reset_index(drop=True)
    removed = before - len(df_out)
    logger.info("Exact dedup: removed %d / %d identical rows", removed, before)
    return df_out, removed


# ═══════════════════════════════════════════════════════════
# Phase 2: Semantic near-dedup (Sentence-BERT + chunked cosine)
# ═══════════════════════════════════════════════════════════

def _mark_near_duplicates(
    embeddings: np.ndarray,
    threshold: float,
    block_size: int,
    chunk_size: int = 2_000,
) -> np.ndarray:
    """Greedily flag rows that are near-duplicates of an earlier kept row.

    Rows are visited in order. A row that is not yet flagged is "kept", and
    every *later* row in the same block whose cosine similarity to it is
    ``>= threshold`` gets flagged. Rows in different blocks of *block_size*
    rows are never compared. Pass ``block_size >= len(embeddings)`` for a
    full pairwise comparison.

    Args:
        embeddings: 2-D array with one embedding per row.
        threshold: Cosine similarity cutoff (inclusive).
        block_size: Number of consecutive rows compared against each other.
        chunk_size: Rows per similarity slab; bounds peak memory to
            ``chunk_size × block_size`` similarity values.

    Returns:
        Boolean mask of length ``len(embeddings)``; ``True`` marks a duplicate.

    Raises:
        ValueError: If *block_size* or *chunk_size* is not positive.
    """
    if block_size < 1 or chunk_size < 1:
        raise ValueError("block_size and chunk_size must be positive")

    emb = np.asarray(embeddings)
    n = emb.shape[0]
    is_dup = np.zeros(n, dtype=bool)
    if n == 0:
        return is_dup

    # Normalise rows so a plain dot product equals cosine similarity.
    # Zero-norm rows are left as zeros (similarity 0 to everything).
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    emb = emb / norms

    for block_start in range(0, n, block_size):
        block_end = min(block_start + block_size, n)
        if block_size < n:
            logger.info(
                " Block [%d:%d] (%d rows) …",
                block_start, block_end, block_end - block_start,
            )
        block = emb[block_start:block_end]
        for row_start in range(block_start, block_end, chunk_size):
            row_end = min(row_start + chunk_size, block_end)
            sim = emb[row_start:row_end] @ block.T
            for offset in range(row_end - row_start):
                gi = row_start + offset
                if is_dup[gi]:
                    # Already dropped rows never knock out other rows.
                    continue
                # Only look at later rows in the block: columns after gi.
                first_col = gi - block_start + 1
                hits = np.flatnonzero(sim[offset, first_col:] >= threshold)
                is_dup[hits + gi + 1] = True
    return is_dup


def _semantic_dedup(
    df: pd.DataFrame,
    text_column: str,
    threshold: float,
    batch_size: int,
    model_name: str,
    max_rows_for_pairwise: int = 30_000,
    block_size: int = 10_000,
) -> Tuple[pd.DataFrame, int]:
    """Remove near-duplicate rows using Sentence-BERT cosine similarity.

    For datasets larger than *max_rows_for_pairwise*, the comparison is done
    block-diagonally (each block of *block_size* rows against itself only) to
    keep computation tractable. This relies on the assumption that
    cross-block near-duplicates are rare. Exact dedup has already removed
    identical pairs.

    Args:
        df: Input DataFrame (already exact-deduped).
        text_column: Column to encode.
        threshold: Cosine similarity cutoff.
        batch_size: Encoding batch size.
        model_name: Sentence-BERT model name.
        max_rows_for_pairwise: Max rows for full pairwise comparison.
        block_size: Block length used when *df* exceeds *max_rows_for_pairwise*.

    Returns:
        (deduplicated DataFrame, number of rows removed).
    """
    n = len(df)
    if n < 2:
        # Nothing to compare, so skip loading the (heavy) model dependency.
        return df.copy(), 0

    from sentence_transformers import SentenceTransformer

    # Truncate long texts to the first 256 chars for fast encoding.
    texts: List[str] = [t[:256] for t in df[text_column].fillna("").astype(str)]

    logger.info(
        "Encoding %d texts with %s (batch_size=%d) …",
        n, model_name, batch_size,
    )
    model = SentenceTransformer(model_name)
    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    if n <= max_rows_for_pairwise:
        logger.info("Running full pairwise cosine similarity (%d × %d) …", n, n)
        effective_block = n
    else:
        logger.info(
            "Dataset too large (%d) for full pairwise — using block dedup", n,
        )
        effective_block = block_size

    is_dup = _mark_near_duplicates(embeddings, threshold, effective_block)
    removed = int(is_dup.sum())

    if removed > 0:
        df_out = df.loc[~is_dup].reset_index(drop=True)
    else:
        df_out = df.copy()

    logger.info("Semantic dedup: removed %d / %d near-duplicate rows", removed, n)
    return df_out, removed


# ═══════════════════════════════════════════════════════════
# Public API
# ═══════════════════════════════════════════════════════════

def deduplicate_dataframe(
    df: pd.DataFrame,
    text_column: str = "text",
    threshold: float = 0.92,
    batch_size: int = 64,
    model_name: str = _DEFAULT_MODEL,
    origin_column: Optional[str] = "dataset_origin",
    max_rows_for_pairwise: int = 30_000,
) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Remove duplicate rows from *df* (exact + semantic).

    Args:
        df: Input DataFrame (must contain *text_column*).
        text_column: Column to use for duplicate detection.
        threshold: Cosine similarity cutoff for near-dedup.
        batch_size: Encoding batch size.
        model_name: Sentence-BERT model identifier.
        origin_column: Optional column for per-origin stats.
        max_rows_for_pairwise: Above this many rows (after exact dedup), the
            semantic pass switches to block-diagonal comparison.

    Returns:
        (cleaned DataFrame, stats dict). If *origin_column* is present, the
        stats map each origin value to its number of removed rows (plain
        ``int``). Otherwise the stats are ``{"total": removed}``.

    Raises:
        KeyError: If *text_column* is not a column of *df*.
    """
    if text_column not in df.columns:
        raise KeyError(f"text column {text_column!r} not found in DataFrame")

    t0 = time.perf_counter()
    logger.info("=" * 60)
    logger.info("Starting deduplication pipeline (threshold=%.2f) …", threshold)

    n_before = len(df)

    # Phase 1: exact
    df_exact, _exact_removed = _exact_dedup(df, text_column)

    # Phase 2: semantic
    df_final, _semantic_removed = _semantic_dedup(
        df_exact,
        text_column=text_column,
        threshold=threshold,
        batch_size=batch_size,
        model_name=model_name,
        max_rows_for_pairwise=max_rows_for_pairwise,
    )

    total_removed = n_before - len(df_final)

    # Build per-origin stats (plain ints so the dict is JSON-serialisable).
    stats: Dict[str, int] = {}
    if origin_column and origin_column in df.columns:
        before_counts = df[origin_column].value_counts()
        after_counts = df_final[origin_column].value_counts()
        for origin, count in before_counts.items():
            stats[origin] = int(count - after_counts.get(origin, 0))
    else:
        stats["total"] = total_removed

    elapsed = time.perf_counter() - t0
    logger.info(
        "Dedup complete: %d → %d rows (removed %d, %.1f%%) in %.1fs",
        n_before, len(df_final), total_removed,
        100 * total_removed / max(n_before, 1), elapsed,
    )
    for origin, cnt in stats.items():
        if cnt > 0:
            logger.info(" %-30s %6d removed", origin, cnt)
    logger.info("=" * 60)

    return df_final, stats


# ─── standalone test ────────────────────────────────────────

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample = pd.DataFrame({
        "text": [
            "The president signed the bill into law today.",
            "The president signed the bill into law today.",  # exact dup
            "Scientists discover a new species of frog in the Amazon.",
            "A new frog species has been found in the Amazon rainforest.",  # near dup
            "Stock markets rallied after a strong jobs report.",
        ],
        "dataset_origin": ["a", "a", "b", "b", "c"],
    })

    clean, info = deduplicate_dataframe(sample, threshold=0.92)
    print(f"\nKept {len(clean)} / {len(sample)} rows")
    print("Stats:", info)
    print(clean[["text", "dataset_origin"]])