dtufail commited on
Commit
8168ff0
·
verified ·
1 Parent(s): 9f50407

Upload retriever.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. retriever.py +647 -0
retriever.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ retriever.py — Nuremberg Scholar Hybrid Retriever (HuggingFace Spaces / ZeroGPU)
3
+ ====================================================================================
4
+ Changes from local/SageMaker version:
5
+ - index_dir parameter : Retriever accepts an explicit path instead of
6
+ hardcoded Path("output/index"). On Spaces this
7
+ comes from snapshot_download(); locally it falls
8
+ back to the default ./output/index/.
9
+ - CPU-first model loading : QueryEncoder and Reranker load to CPU at init.
10
+ rag.py moves them to CUDA inside the @spaces.GPU
11
+ window and back to CPU after. The .device attribute
12
+ on QueryEncoder and Reranker is updated by rag.py
13
+ before each call so encode()/rerank() run on the
14
+ correct device.
15
+ - dtype= replaces torch_dtype : fixes the transformers deprecation warning.
16
+ - CLI smoke test preserved : `python retriever.py --query "..." ` still works
17
+ for local testing; it auto-detects CUDA availability.
18
+
19
+ Pipeline (paper-backed):
20
+ 1. Query encoding : BGE-M3 dense (1024d) + sparse (lexical weights)
21
+ 2. Dense retrieval : FAISS FlatIP top-N (cosine via L2-norm + inner product)
22
+ 3. Sparse retrieval: dot-product over CSR sparse matrix, top-N
23
+ 4. RRF fusion : k=60, merge dense+sparse ranked lists -> top-K candidates
24
+ 5. Reranking : bge-reranker-v2-m3 cross-encoder -> sigmoid scores -> top-K_final
25
+ 6. Return : list of ranked Result objects with metadata + scores
26
+
27
+ Design decisions from literature:
28
+ - RRF k=60: industry standard, robust across domains (Cormack et al. 2009)
29
+ - Dense N=100, Sparse N=100 -> RRF top-25 -> rerank to top-5
30
+ (two-stage funnel: high recall first, high precision second)
31
+ - BGE-M3 paper recommends dense+sparse hybrid for long-document corpus;
32
+ sparse alone outperforms dense by ~10 NDCG points on long docs (MLDR)
33
+ - bge-reranker-v2-m3 is the official reranker pairing for bge-m3 embeddings
34
+ - Scores sigmoid-mapped to [0,1] for interpretability at generation time
35
+ - No query instruction prefix needed for BGE-M3 (unlike BGE v1.5)
36
+
37
+ Usage:
38
+ from retriever import Retriever
39
+ r = Retriever()
40
+ results = r.retrieve("What did Goring say about the Luftwaffe?", top_k=5)
41
+ for res in results:
42
+ print(res)
43
+
44
+ # CLI smoke test
45
+ python retriever.py --query "crimes against humanity Article 6c"
46
+ python retriever.py --query "Ohlendorf Einsatzgruppen" --top-k 3 --no-rerank
47
+ python retriever.py --query "London Agreement 1945" --dense-only
48
+ """
49
+
50
+ import json
51
+ import time
52
+ import argparse
53
+ from pathlib import Path
54
+ from dataclasses import dataclass, field
55
+ from typing import Optional
56
+
57
+ # ── Defaults ──────────────────────────────────────────────────────────────────
58
+
59
+ DEFAULT_INDEX_DIR = Path("output/index")
60
+
61
+ EMBED_MODEL = "BAAI/bge-m3"
62
+ RERANK_MODEL = "BAAI/bge-reranker-v2-m3"
63
+
64
+ EMBED_DIM = 1024
65
+ RRF_K = 60 # Cormack et al. 2009 — robust standard
66
+ DENSE_N = 100 # candidates from dense retrieval
67
+ SPARSE_N = 100 # candidates from sparse retrieval
68
+ RERANK_INPUT = 25 # max chunks sent to reranker (post-RRF)
69
+ DEFAULT_TOP_K = 5 # final chunks returned to generator
70
+ MAX_Q_TOKENS = 512 # query max tokens (queries are short)
71
+
72
+ # ── Result dataclass ──────────────────────────────────────────────────────────
73
+
74
+ @dataclass
75
+ class Result:
76
+ chunk_id: str
77
+ body: str
78
+ collection: str
79
+ date_iso: Optional[str]
80
+ speaker: Optional[str]
81
+ source_url: Optional[str]
82
+ page_number: Optional[int]
83
+ slug: Optional[str]
84
+ # Scores
85
+ dense_rank: Optional[int] = None
86
+ sparse_rank: Optional[int] = None
87
+ rrf_score: float = 0.0
88
+ rerank_score: Optional[float] = None # sigmoid [0,1], None if bypassed
89
+
90
+ def __str__(self):
91
+ rerank = f" rerank={self.rerank_score:.4f}" if self.rerank_score is not None else ""
92
+ return (
93
+ f"[{self.collection}] {self.date_iso or '?'} {self.slug or ''}\n"
94
+ f" speaker={self.speaker or '-'} page={self.page_number or '?'}\n"
95
+ f" rrf={self.rrf_score:.5f}{rerank}\n"
96
+ f" {self.body[:200]}..."
97
+ )
98
+
99
+
100
+ # ── BGE-M3 query encoder ──────────────────────────────────────────────────────
101
+
102
+ UNUSED_TOKENS = [0, 1, 2] # <s>, <pad>, </s>
103
+
104
+
105
+ class QueryEncoder:
106
+ """
107
+ Encodes a query into:
108
+ dense_vec : np.ndarray (1024,) L2-normalised
109
+ sparse_weights : dict {token_str: score}
110
+
111
+ sparse_linear = Linear(1024, 1) — scalar weight per token position.
112
+ Scatter onto input_ids vocab positions via scatter_reduce("amax").
113
+
114
+ ZeroGPU note:
115
+ Loads to CPU at init. rag.py moves self.model to CUDA inside the
116
+ @spaces.GPU window by calling self.model.to("cuda") and updating
117
+ self.device. encode() uses self.device for all tensor ops, so it
118
+ runs on whichever device the model currently sits on.
119
+ """
120
+
121
+ def __init__(self, model_name: str, device: str = "cpu"):
122
+ import torch
123
+ import torch.nn as nn
124
+ from transformers import AutoTokenizer, AutoModel
125
+ from huggingface_hub import hf_hub_download
126
+
127
+ self.device = torch.device(device)
128
+ self.torch = torch
129
+ self.fp16 = device != "cpu"
130
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
131
+ self.vocab_size = self.tokenizer.vocab_size # 250002
132
+
133
+ # CPU-first: always load to CPU, let caller move to GPU when needed.
134
+ # dtype= replaces deprecated torch_dtype=
135
+ self.model = AutoModel.from_pretrained(
136
+ model_name,
137
+ dtype=torch.float16 if self.fp16 else torch.float32,
138
+ )
139
+ self.model.to(self.device)
140
+ self.model.eval()
141
+
142
+ sparse_path = hf_hub_download(repo_id=model_name, filename="sparse_linear.pt")
143
+ raw = torch.load(sparse_path, map_location="cpu", weights_only=True)
144
+ in_f, out_f = raw["weight"].shape[1], raw["weight"].shape[0]
145
+ self.sparse_linear = nn.Linear(in_f, out_f, bias=True)
146
+ self.sparse_linear.load_state_dict(raw, strict=True)
147
+ if self.fp16:
148
+ self.sparse_linear = self.sparse_linear.half()
149
+ self.sparse_linear.to(self.device)
150
+ self.sparse_linear.eval()
151
+
152
+ def encode(self, query: str) -> dict:
153
+ """
154
+ Encode a query string. Uses self.device for all tensor placement,
155
+ so this works on both CPU and CUDA depending on where the model
156
+ has been moved by the caller.
157
+ """
158
+ import torch
159
+ import numpy as np
160
+ import torch.nn.functional as F
161
+
162
+ # Resolve current device from the model parameters — this handles
163
+ # the case where rag.py has moved self.model to CUDA but self.device
164
+ # hasn't been explicitly updated yet.
165
+ device = next(self.model.parameters()).device
166
+
167
+ enc = self.tokenizer(
168
+ [query],
169
+ padding=True,
170
+ truncation=True,
171
+ max_length=MAX_Q_TOKENS,
172
+ return_tensors="pt",
173
+ )
174
+ enc = {k: v.to(device) for k, v in enc.items()}
175
+
176
+ with torch.no_grad():
177
+ out = self.model(**enc, return_dict=True)
178
+ last_hidden = out.last_hidden_state
179
+
180
+ dense = F.normalize(last_hidden[:, 0, :].float(), p=2, dim=-1)
181
+ dense_np = dense.cpu().numpy().astype("float32")[0] # (1024,)
182
+
183
+ # sparse_linear may be on a different device if only self.model
184
+ # was moved — move it to match
185
+ if next(self.sparse_linear.parameters()).device != device:
186
+ self.sparse_linear.to(device)
187
+
188
+ token_weights = torch.relu(
189
+ self.sparse_linear(last_hidden)
190
+ ).squeeze(-1).float()
191
+
192
+ sparse_emb = torch.zeros(
193
+ 1, self.vocab_size, dtype=torch.float32, device=device
194
+ )
195
+ sparse_emb = sparse_emb.scatter_reduce(
196
+ dim=1,
197
+ index=enc["input_ids"],
198
+ src=token_weights,
199
+ reduce="amax",
200
+ include_self=False,
201
+ )
202
+ for uid in UNUSED_TOKENS:
203
+ if uid < self.vocab_size:
204
+ sparse_emb[0, uid] = 0.0
205
+
206
+ nonzero = sparse_emb[0].nonzero(as_tuple=True)[0].tolist()
207
+ scores = sparse_emb[0][nonzero].cpu().tolist()
208
+ sparse = {}
209
+ for tid, score in zip(nonzero, scores):
210
+ if score <= 0:
211
+ continue
212
+ tok = self.tokenizer.decode([tid]).strip()
213
+ if tok:
214
+ sparse[tok] = round(float(score), 4)
215
+
216
+ return {"dense_vec": dense_np, "sparse_weights": sparse}
217
+
218
+
219
+ # ── Reranker ──────────────────────────────────────────────────────────────────
220
+
221
+ class Reranker:
222
+ """
223
+ bge-reranker-v2-m3 cross-encoder.
224
+ Scores sigmoid-mapped to [0,1] per HF model card recommendation.
225
+
226
+ ZeroGPU note:
227
+ Same CPU-first pattern as QueryEncoder. rag.py moves self.model
228
+ to CUDA inside the @spaces.GPU window. rerank() resolves device
229
+ from model parameters.
230
+ """
231
+
232
+ def __init__(self, model_name: str, device: str = "cpu"):
233
+ import torch
234
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
235
+
236
+ self.device = torch.device(device)
237
+ self.torch = torch
238
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
239
+
240
+ # dtype= replaces deprecated torch_dtype=
241
+ self.model = AutoModelForSequenceClassification.from_pretrained(
242
+ model_name,
243
+ dtype=torch.float16 if device != "cpu" else torch.float32,
244
+ )
245
+ self.model.to(self.device)
246
+ self.model.eval()
247
+
248
+ def rerank(self, query: str, candidates: list[Result],
249
+ batch_size: int = 32) -> list[Result]:
250
+ import torch
251
+
252
+ # Resolve current device from model parameters
253
+ device = next(self.model.parameters()).device
254
+
255
+ pairs = [[query, c.body] for c in candidates]
256
+ all_scores = []
257
+
258
+ for i in range(0, len(pairs), batch_size):
259
+ batch = pairs[i:i + batch_size]
260
+ enc = self.tokenizer(
261
+ batch,
262
+ padding=True,
263
+ truncation=True,
264
+ max_length=512,
265
+ return_tensors="pt",
266
+ )
267
+ enc = {k: v.to(device) for k, v in enc.items()}
268
+ with torch.no_grad():
269
+ logits = self.model(**enc, return_dict=True).logits.view(-1).float()
270
+ scores = torch.sigmoid(logits).cpu().tolist()
271
+ all_scores.extend(scores)
272
+
273
+ for candidate, score in zip(candidates, all_scores):
274
+ candidate.rerank_score = round(score, 6)
275
+
276
+ return sorted(candidates, key=lambda x: x.rerank_score, reverse=True)
277
+
278
+
279
+ # ── Sparse index ──────────────────────────────────────────────────────────────
280
+
281
+ class SparseIndex:
282
+ """
283
+ CSR sparse matrix index over sparse.jsonl.
284
+
285
+ Layout: matrix shape (num_tokens, num_chunks), float32.
286
+ rows = tokens (indexed via token_to_row dict)
287
+ columns = chunks (same order as metadata.jsonl / FAISS rows)
288
+ values = BGE-M3 sparse weights
289
+
290
+ Query:
291
+ 1. Build a 1-row CSR query vector from query token weights.
292
+ 2. query_vec @ matrix -> dense (num_chunks,) score array. One BLAS call.
293
+ 3. np.argpartition for top-n, argsort only the top slice.
294
+
295
+ Why CSR vs dict-of-lists:
296
+ - RAM : 54 MB vs 608 MB (-554 MB measured on this corpus)
297
+ - Query: single scipy sparse matmul vs Python loop over posting lists
298
+ - Load : ~4s vs ~24s
299
+
300
+ query() signature is identical to the old implementation.
301
+ """
302
+
303
+ def __init__(self, sparse_path: Path):
304
+ import numpy as np
305
+ from scipy.sparse import csr_matrix
306
+
307
+ print(f" Loading sparse index from {sparse_path}...")
308
+ t0 = time.time()
309
+
310
+ token_to_row: dict[str, int] = {}
311
+ chunk_ids: list[str] = []
312
+ rows: list[int] = []
313
+ cols: list[int] = []
314
+ data: list[float] = []
315
+
316
+ with sparse_path.open(encoding="utf-8") as f:
317
+ for chunk_idx, line in enumerate(f):
318
+ line = line.strip()
319
+ if not line:
320
+ continue
321
+ obj = json.loads(line)
322
+ chunk_ids.append(obj["chunk_id"])
323
+ for token, weight in obj.get("weights", {}).items():
324
+ if token not in token_to_row:
325
+ token_to_row[token] = len(token_to_row)
326
+ rows.append(token_to_row[token])
327
+ cols.append(chunk_idx)
328
+ data.append(weight)
329
+
330
+ num_tokens = len(token_to_row)
331
+ num_chunks = len(chunk_ids)
332
+
333
+ self.matrix = csr_matrix(
334
+ (
335
+ np.array(data, dtype=np.float32),
336
+ (np.array(rows, dtype=np.int32),
337
+ np.array(cols, dtype=np.int32)),
338
+ ),
339
+ shape=(num_tokens, num_chunks),
340
+ )
341
+ self.token_to_row = token_to_row
342
+ self.chunk_ids = chunk_ids
343
+
344
+ elapsed = time.time() - t0
345
+ ram_mb = (self.matrix.data.nbytes
346
+ + self.matrix.indices.nbytes
347
+ + self.matrix.indptr.nbytes) / 1024**2
348
+
349
+ print(f" Sparse index: {num_chunks:,} chunks, "
350
+ f"{num_tokens:,} unique tokens, "
351
+ f"{self.matrix.nnz:,} nnz "
352
+ f"({elapsed:.1f}s, {ram_mb:.1f} MB CSR)")
353
+
354
+ def query(self, sparse_weights: dict[str, float],
355
+ top_n: int) -> list[tuple[int, float]]:
356
+ """
357
+ Returns list of (chunk_idx, score) sorted descending, length <= top_n.
358
+ Identical signature to the old dict-of-lists implementation.
359
+ """
360
+ import numpy as np
361
+ from scipy.sparse import csr_matrix
362
+
363
+ if not sparse_weights:
364
+ return []
365
+
366
+ q_rows, q_cols, q_data = [], [], []
367
+ for token, weight in sparse_weights.items():
368
+ row = self.token_to_row.get(token)
369
+ if row is not None:
370
+ q_rows.append(0)
371
+ q_cols.append(row)
372
+ q_data.append(weight)
373
+
374
+ if not q_data:
375
+ return []
376
+
377
+ num_tokens = self.matrix.shape[0]
378
+ q_vec = csr_matrix(
379
+ (np.array(q_data, dtype=np.float32),
380
+ (np.array(q_rows, dtype=np.int32),
381
+ np.array(q_cols, dtype=np.int32))),
382
+ shape=(1, num_tokens),
383
+ )
384
+
385
+ # (1, num_tokens) @ (num_tokens, num_chunks) -> (1, num_chunks)
386
+ # todense() ensures we always get a plain numpy matrix, not sparse
387
+ scores = np.asarray((q_vec @ self.matrix).todense()).ravel() # (num_chunks,)
388
+
389
+ if top_n >= len(scores):
390
+ top_indices = np.argsort(scores)[::-1]
391
+ else:
392
+ top_indices = np.argpartition(scores, -top_n)[-top_n:]
393
+ top_indices = top_indices[np.argsort(scores[top_indices])[::-1]]
394
+
395
+ return [
396
+ (int(idx), float(scores[idx]))
397
+ for idx in top_indices
398
+ if float(scores[idx]) > 0
399
+ ]
400
+
401
+
402
+ # ── RRF fusion ────────────────────────────────────────────────────────────────
403
+
404
+ def reciprocal_rank_fusion(
405
+ dense_ranked: list[tuple[int, float]],
406
+ sparse_ranked: list[tuple[int, float]],
407
+ k: int = RRF_K,
408
+ ) -> list[tuple[int, float]]:
409
+ """
410
+ RRF(d) = sum( 1 / (k + rank_r(d)) )
411
+ Returns list of (chunk_idx, rrf_score) sorted descending.
412
+ """
413
+ rrf: dict[int, float] = {}
414
+ for rank, (chunk_idx, _) in enumerate(dense_ranked, start=1):
415
+ rrf[chunk_idx] = rrf.get(chunk_idx, 0.0) + 1.0 / (k + rank)
416
+ for rank, (chunk_idx, _) in enumerate(sparse_ranked, start=1):
417
+ rrf[chunk_idx] = rrf.get(chunk_idx, 0.0) + 1.0 / (k + rank)
418
+ return sorted(rrf.items(), key=lambda x: x[1], reverse=True)
419
+
420
+
421
+ # ── Main Retriever ────────────────────────────────────────────────────────────
422
+
423
+ class Retriever:
424
+ """
425
+ Full hybrid retrieval pipeline.
426
+
427
+ Parameters
428
+ ----------
429
+ index_dir : Path to directory containing dense.faiss, metadata.jsonl,
430
+ sparse.jsonl. Defaults to ./output/index/ for local dev.
431
+ On Spaces, rag.py passes the snapshot_download() cache path.
432
+ device : "cuda" / "cpu". On Spaces this is "cpu" at init time;
433
+ rag.py moves models to CUDA inside the @spaces.GPU window.
434
+ dense_n : candidates from FAISS (default 100)
435
+ sparse_n : candidates from sparse index (default 100)
436
+ rerank_input : max chunks sent to reranker (default 25)
437
+ top_k : final results returned (default 5)
438
+ use_reranker : bool (default True)
439
+ dense_only : skip sparse + RRF, just return FAISS top-k (baseline mode)
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ index_dir: Optional[str] = None,
445
+ device: str = "cpu",
446
+ dense_n: int = DENSE_N,
447
+ sparse_n: int = SPARSE_N,
448
+ rerank_input: int = RERANK_INPUT,
449
+ top_k: int = DEFAULT_TOP_K,
450
+ use_reranker: bool = True,
451
+ dense_only: bool = False,
452
+ ):
453
+ import faiss
454
+
455
+ # Resolve index directory
456
+ idx_dir = Path(index_dir) if index_dir else DEFAULT_INDEX_DIR
457
+ dense_file = idx_dir / "dense.faiss"
458
+ sparse_file = idx_dir / "sparse.jsonl"
459
+ meta_file = idx_dir / "metadata.jsonl"
460
+
461
+ self.device = device
462
+ self.dense_n = dense_n
463
+ self.sparse_n = sparse_n
464
+ self.rerank_input = rerank_input
465
+ self.top_k = top_k
466
+ self.use_reranker = use_reranker
467
+ self.dense_only = dense_only
468
+
469
+ if not dense_file.exists():
470
+ raise FileNotFoundError(f"Dense index not found: {dense_file}")
471
+ print(f" Loading FAISS index from {idx_dir}...")
472
+ self.faiss_index = faiss.read_index(str(dense_file))
473
+ print(f" FAISS: {self.faiss_index.ntotal:,} vectors")
474
+
475
+ print(f" Loading metadata...")
476
+ self.metadata: list[dict] = []
477
+ with meta_file.open(encoding="utf-8") as f:
478
+ for line in f:
479
+ line = line.strip()
480
+ if line:
481
+ self.metadata.append(json.loads(line))
482
+ print(f" Metadata: {len(self.metadata):,} records")
483
+
484
+ self.chunk_id_to_idx = {m["chunk_id"]: i for i, m in enumerate(self.metadata)}
485
+
486
+ if not dense_only:
487
+ self.sparse_index = SparseIndex(sparse_file)
488
+ else:
489
+ self.sparse_index = None
490
+
491
+ print(f" Loading query encoder ({EMBED_MODEL})...")
492
+ self.encoder = QueryEncoder(EMBED_MODEL, device)
493
+
494
+ self.reranker = None
495
+ if use_reranker:
496
+ print(f" Loading reranker ({RERANK_MODEL})...")
497
+ self.reranker = Reranker(RERANK_MODEL, device)
498
+
499
+ print(f"\n Retriever ready "
500
+ f"device={device} index={idx_dir} "
501
+ f"dense_n={dense_n} sparse_n={sparse_n} "
502
+ f"rerank={use_reranker} top_k={top_k}\n")
503
+
504
+ def retrieve(self, query: str, top_k: Optional[int] = None) -> list[Result]:
505
+ import numpy as np
506
+
507
+ top_k = top_k or self.top_k
508
+ t0 = time.time()
509
+
510
+ # ── 1. Encode query ───────────────────────────────────────────────────
511
+ encoded = self.encoder.encode(query)
512
+ dense_vec = encoded["dense_vec"]
513
+ sparse_w = encoded["sparse_weights"]
514
+
515
+ # ── 2. Dense retrieval (FAISS) ────────────────────────────────────────
516
+ q_vec = dense_vec.reshape(1, -1).astype("float32")
517
+ scores, indices = self.faiss_index.search(q_vec, self.dense_n)
518
+ dense_ranked = [
519
+ (int(idx), float(score))
520
+ for idx, score in zip(indices[0], scores[0])
521
+ if idx >= 0
522
+ ]
523
+
524
+ if self.dense_only:
525
+ results = self._build_results(
526
+ dense_ranked[:top_k],
527
+ dense_ranked=dense_ranked,
528
+ sparse_ranked=[],
529
+ )
530
+ if self.use_reranker and self.reranker:
531
+ results = self.reranker.rerank(query, results)
532
+ return results[:top_k]
533
+
534
+ # ── 3. Sparse retrieval ───────────────────────────────────────────────
535
+ sparse_ranked = self.sparse_index.query(sparse_w, self.sparse_n)
536
+
537
+ # ── 4. RRF fusion ─────────────────────────────────────────────────────
538
+ fused = reciprocal_rank_fusion(dense_ranked, sparse_ranked, k=RRF_K)
539
+ fused = fused[:self.rerank_input]
540
+
541
+ # ── 5. Build Result objects ───────────────────────────────────────────
542
+ dense_rank_map = {idx: r+1 for r, (idx, _) in enumerate(dense_ranked)}
543
+ sparse_rank_map = {idx: r+1 for r, (idx, _) in enumerate(sparse_ranked)}
544
+
545
+ candidates = []
546
+ for chunk_idx, rrf_score in fused:
547
+ if chunk_idx >= len(self.metadata):
548
+ continue
549
+ m = self.metadata[chunk_idx]
550
+ candidates.append(Result(
551
+ chunk_id = m.get("chunk_id", ""),
552
+ body = m.get("body", ""),
553
+ collection = m.get("collection", ""),
554
+ date_iso = m.get("date_iso"),
555
+ speaker = m.get("speaker"),
556
+ source_url = m.get("source_url"),
557
+ page_number = m.get("page_number"),
558
+ slug = m.get("slug"),
559
+ dense_rank = dense_rank_map.get(chunk_idx),
560
+ sparse_rank = sparse_rank_map.get(chunk_idx),
561
+ rrf_score = rrf_score,
562
+ ))
563
+
564
+ # ── 6. Rerank ─────────────────────────────────────────────────────────
565
+ if self.use_reranker and self.reranker and candidates:
566
+ candidates = self.reranker.rerank(query, candidates)
567
+
568
+ elapsed = time.time() - t0
569
+ print(f" Retrieved {len(candidates[:top_k])} results in {elapsed:.2f}s "
570
+ f"(dense={len(dense_ranked)} sparse={len(sparse_ranked)} "
571
+ f"fused={len(fused)} reranked={self.use_reranker})")
572
+
573
+ return candidates[:top_k]
574
+
575
+ def _build_results(self, ranked, dense_ranked, sparse_ranked) -> list[Result]:
576
+ dense_rank_map = {idx: r+1 for r, (idx, _) in enumerate(dense_ranked)}
577
+ sparse_rank_map = {idx: r+1 for r, (idx, _) in enumerate(sparse_ranked)}
578
+ results = []
579
+ for chunk_idx, rrf_score in ranked:
580
+ if chunk_idx >= len(self.metadata):
581
+ continue
582
+ m = self.metadata[chunk_idx]
583
+ results.append(Result(
584
+ chunk_id = m.get("chunk_id", ""),
585
+ body = m.get("body", ""),
586
+ collection = m.get("collection", ""),
587
+ date_iso = m.get("date_iso"),
588
+ speaker = m.get("speaker"),
589
+ source_url = m.get("source_url"),
590
+ page_number = m.get("page_number"),
591
+ slug = m.get("slug"),
592
+ dense_rank = dense_rank_map.get(chunk_idx),
593
+ sparse_rank = sparse_rank_map.get(chunk_idx),
594
+ rrf_score = rrf_score,
595
+ ))
596
+ return results
597
+
598
+
599
+ # ── CLI smoke test ──────────────────────────────────────────���─────────────────
600
+
601
+ def main():
602
+ ap = argparse.ArgumentParser(description="Nuremberg Scholar -- Retriever smoke test")
603
+ ap.add_argument("--query", required=True)
604
+ ap.add_argument("--top-k", type=int, default=DEFAULT_TOP_K)
605
+ ap.add_argument("--device", default="cuda")
606
+ ap.add_argument("--no-rerank", action="store_true")
607
+ ap.add_argument("--dense-only", action="store_true")
608
+ ap.add_argument("--dense-n", type=int, default=DENSE_N)
609
+ ap.add_argument("--sparse-n", type=int, default=SPARSE_N)
610
+ ap.add_argument("--index-dir", default=None,
611
+ help="Path to index directory (default: ./output/index/)")
612
+ args = ap.parse_args()
613
+
614
+ if args.device == "cuda":
615
+ try:
616
+ import torch
617
+ if not torch.cuda.is_available():
618
+ args.device = "cpu"
619
+ except ImportError:
620
+ args.device = "cpu"
621
+
622
+ print(f"\nNuremberg Scholar -- Retriever")
623
+ print("=" * 60)
624
+
625
+ retriever = Retriever(
626
+ index_dir = args.index_dir,
627
+ device = args.device,
628
+ dense_n = args.dense_n,
629
+ sparse_n = args.sparse_n,
630
+ top_k = args.top_k,
631
+ use_reranker = not args.no_rerank,
632
+ dense_only = args.dense_only,
633
+ )
634
+
635
+ print(f"\nQuery: {args.query}\n")
636
+ results = retriever.retrieve(args.query, top_k=args.top_k)
637
+
638
+ print(f"\n{'='*60}")
639
+ print(f"Top {len(results)} results:")
640
+ print(f"{'='*60}\n")
641
+ for i, r in enumerate(results, 1):
642
+ print(f" -- Result {i} --")
643
+ print(f" {r}\n")
644
+
645
+
646
+ if __name__ == "__main__":
647
+ main()