nkshirsa committed
Commit ecdb8ec · verified · 1 Parent(s): 056b450

Add SPECTER2 embedding-based deduplication (replaces Jaccard word overlap)

phd_research_os_v2/layer3/embedding_dedup.py ADDED
@@ -0,0 +1,222 @@
+ """
+ Layer 3: Embedding-Based Claim Deduplication (SPECTER2)
+ =========================================================
+ Replaces Jaccard word-overlap deduplication with SPECTER2 scientific
+ embeddings for semantic matching.
+
+ Addresses blindspots: M-1, M-2, PA-1
+ Source: SYSTEM_INSPIRATIONS.md DA-1
+
+ Dependencies:
+     pip install adapters torch
+
+ Falls back to Jaccard if adapters/torch are not available.
+ """
+
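+ # Quick-start sketch (illustrative, not part of the module API; assumes the
+ # repo root is on sys.path so the package path below resolves):
+ #     from phd_research_os_v2.layer3.embedding_dedup import batch_deduplicate
+ #     result = batch_deduplicate(["claim A", "claim A, rephrased", "claim B"])
+ #     print(result["similarity_method"], result["duplicates"])
+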
+ import logging
+ import re
+
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+ # ── Try to load SPECTER2 ──────────────────────────────────────────────
+ _SPECTER2_AVAILABLE = False
+ _specter2_model = None
+ _specter2_tokenizer = None
+ _specter2_load_attempted = False
+
+ def _load_specter2():
+     """Lazy-load the SPECTER2 model and adapter. Attempted once, on first use."""
+     global _SPECTER2_AVAILABLE, _specter2_model, _specter2_tokenizer, _specter2_load_attempted
+
+     # Return the cached result after the first attempt, success or failure,
+     # so a missing dependency is not re-imported on every call.
+     if _specter2_load_attempted:
+         return _SPECTER2_AVAILABLE
+     _specter2_load_attempted = True
+
+     try:
+         from adapters import AutoAdapterModel
+         from transformers import AutoTokenizer
+
+         logger.info("Loading SPECTER2 base model...")
+         _specter2_tokenizer = AutoTokenizer.from_pretrained("allenai/specter2_base")
+         _specter2_model = AutoAdapterModel.from_pretrained("allenai/specter2_base")
+
+         logger.info("Loading SPECTER2 proximity adapter...")
+         _specter2_model.load_adapter("allenai/specter2", source="hf", set_active=True)
+         _specter2_model.eval()
+
+         _SPECTER2_AVAILABLE = True
+         logger.info("SPECTER2 loaded successfully (768-dim embeddings)")
+         return True
+
+     except ImportError:
+         logger.warning(
+             "SPECTER2 not available (install: pip install adapters torch). "
+             "Falling back to Jaccard word overlap for deduplication."
+         )
+         _SPECTER2_AVAILABLE = False
+         return False
+     except Exception as e:
+         logger.warning(f"SPECTER2 failed to load: {e}. Using Jaccard fallback.")
+         _SPECTER2_AVAILABLE = False
+         return False
+
+ def embed_claims(texts: list[str]) -> np.ndarray:
+     """
+     Embed a list of claim texts using SPECTER2.
+     Returns a (N, 768) numpy array of L2-normalized embeddings.
+
+     SPECTER2's expected input format for papers is:
+         title + [SEP] + abstract
+     For claims (which have no title), we pass the claim text directly.
+     """
+     if not _load_specter2():
+         raise RuntimeError("SPECTER2 not available")
+
+     # Safe to import here: a successful _load_specter2() implies torch is installed.
+     import torch
+
+     inputs = _specter2_tokenizer(
+         texts,
+         padding=True,
+         truncation=True,
+         max_length=512,
+         return_tensors="pt"
+     )
+
+     with torch.no_grad():
+         outputs = _specter2_model(**inputs)
+
+     # CLS token embedding
+     embeddings = outputs.last_hidden_state[:, 0, :].numpy()
+
+     # L2-normalize so cosine similarity reduces to a dot product
+     norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+     norms = np.where(norms == 0, 1, norms)
+     embeddings = embeddings / norms
+
+     return embeddings
+
+
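+ # Input-formatting note (editor's sketch): for titled documents, SPECTER2
+ # expects "title [SEP] abstract". With this module's tokenizer that would be,
+ # hypothetically (doc_title / doc_abstract are placeholder names):
+ #     text = doc_title + _specter2_tokenizer.sep_token + doc_abstract
+ #     embeddings = embed_claims([text])
+
+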
+ def cosine_similarity(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
+     """Cosine similarity between two L2-normalized embeddings."""
+     return float(np.dot(emb_a, emb_b))
+
+
+ def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
+     """Full pairwise cosine similarity matrix (for batch operations)."""
+     return embeddings @ embeddings.T
+
+
+ # ── Jaccard fallback (identical to existing canonicalizer.py) ─────────
+
+ _STOPWORDS = {
+     'the', 'a', 'an', 'is', 'was', 'were', 'are', 'been', 'be',
+     'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
+     'could', 'should', 'may', 'might', 'in', 'on', 'at', 'to',
+     'for', 'of', 'with', 'by', 'from', 'and', 'or', 'but', 'not',
+     'this', 'that', 'it', 'its', 'we', 'our', 'they'
+ }
+
+ def _normalize(text: str) -> str:
+     """Lowercase, collapse whitespace, and strip most punctuation."""
+     t = text.lower().strip()
+     t = re.sub(r'\s+', ' ', t)
+     t = re.sub(r'[^\w\s\.\,\-\+\=\<\>\(\)]', '', t)
+     return t
+
+ def jaccard_similarity(text_a: str, text_b: str) -> float:
+     """Stopword-filtered Jaccard word overlap, in [0, 1]."""
+     words_a = set(_normalize(text_a).split()) - _STOPWORDS
+     words_b = set(_normalize(text_b).split()) - _STOPWORDS
+     if not words_a or not words_b:
+         return 0.0
+     intersection = words_a & words_b
+     union = words_a | words_b
+     return len(intersection) / len(union) if union else 0.0
+
+
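+ # Worked example (editor's illustration): "attention improves translation
+ # quality" vs. "attention improves accuracy" gives stopword-filtered sets
+ # sharing {attention, improves} out of 5 distinct words: Jaccard = 2/5 = 0.4.
+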
+ # ── Unified similarity function ───────────────────────────────────────
+
+ def claim_similarity(text_a: str, text_b: str, method: str = "auto") -> float:
+     """
+     Compute similarity between two claim texts.
+
+     method:
+         "auto"     - SPECTER2 if available, else Jaccard
+         "specter2" - Force SPECTER2 (raises if not available)
+         "jaccard"  - Force Jaccard word overlap
+     """
+     if method == "jaccard":
+         return jaccard_similarity(text_a, text_b)
+
+     if method == "auto":
+         if _load_specter2():
+             method = "specter2"
+         else:
+             return jaccard_similarity(text_a, text_b)
+
+     # SPECTER2
+     embeddings = embed_claims([text_a, text_b])
+     return cosine_similarity(embeddings[0], embeddings[1])
+
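+ # Scale note (editor's): SPECTER2 returns a cosine score in [-1, 1], while
+ # Jaccard lies in [0, 1], so a threshold tuned for one method will not
+ # transfer directly to the other.
+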
+
+ def batch_deduplicate(texts: list[str], threshold: float = 0.85,
+                       method: str = "auto") -> dict:
+     """
+     Batch deduplication. Returns a mapping of duplicate indices to their
+     canonical index.
+
+     Returns:
+         {
+             "canonical_indices": [0, 2, 5, ...],  # indices of unique claims
+             "duplicates": {1: 0, 3: 0, 4: 2},     # duplicate_idx -> canonical_idx
+             "similarity_method": "specter2" | "jaccard"
+         }
+     """
+     n = len(texts)
+     if n == 0:
+         return {"canonical_indices": [], "duplicates": {}, "similarity_method": "none"}
+     if n == 1:
+         return {"canonical_indices": [0], "duplicates": {}, "similarity_method": "none"}
+
+     use_specter = (method == "specter2") or (method == "auto" and _load_specter2())
+
+     if use_specter:
+         embeddings = embed_claims(texts)
+         sim_matrix = cosine_similarity_matrix(embeddings)
+         actual_method = "specter2"
+     else:
+         # Build the pairwise Jaccard matrix (symmetric, so fill both halves)
+         sim_matrix = np.zeros((n, n))
+         for i in range(n):
+             for j in range(i, n):
+                 sim = jaccard_similarity(texts[i], texts[j])
+                 sim_matrix[i, j] = sim
+                 sim_matrix[j, i] = sim
+         actual_method = "jaccard"
+
+     # Greedy deduplication: the earliest claim in each near-duplicate group
+     # becomes canonical; later claims above the threshold map onto it.
+     canonical_indices = []
+     duplicates = {}
+     removed = set()
+
+     for i in range(n):
+         if i in removed:
+             continue
+         canonical_indices.append(i)
+         for j in range(i + 1, n):
+             if j in removed:
+                 continue
+             if sim_matrix[i, j] >= threshold:
+                 duplicates[j] = i
+                 removed.add(j)
+
+     return {
+         "canonical_indices": canonical_indices,
+         "duplicates": duplicates,
+         "similarity_method": actual_method,
+     }
+
+
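+ # Illustrative usage (editor's sketch; actual grouping depends on model scores,
+ # so the output shown is hypothetical):
+ #     claims = [
+ #         "Transformers outperform RNNs on machine translation.",
+ #         "On translation tasks, transformer models beat RNNs.",
+ #         "Dropout reduces overfitting in deep networks.",
+ #     ]
+ #     result = batch_deduplicate(claims, threshold=0.85)
+ #     # -> {"canonical_indices": [0, 2], "duplicates": {1: 0},
+ #     #     "similarity_method": "specter2"}
+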
+ def is_available() -> bool:
+     """Check if SPECTER2 is available for embedding-based dedup."""
+     return _load_specter2()