Subhadip007 committed on
Commit
f780124
·
1 Parent(s): daafb32

feat: retrieval optimization pipeline complete

Browse files

- BM25 sparse index: 15,664 documents, 39.3MB
- Hybrid retrieval: RRF fusion (dense 0.7 + sparse 0.3)
- Cross-encoder re-ranking: ms-marco-MiniLM-L-6-v2
- Diversity filter: max 2 chunks per paper
- Fixed Qdrant Range filter: publication_year as integer field
- CE score range: 4.3-8.3 (strong relevance signal)
- Query latency: 3-17s (first query loads models, subsequent ~4s)

.vscode/settings.json CHANGED
@@ -1,2 +1,5 @@
1
  {
 
 
 
2
  }
 
1
  {
2
+ "cSpell.words": [
3
+ "reranked"
4
+ ]
5
  }
src/retrieval/__init__.py ADDED
File without changes
src/retrieval/hybrid_retriever.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid retriever combining dense (Qdrant) and sparse (BM25) search.
3
+
4
+ RECIPROCAL RANK FUSION (RRF) EXPLAINED:
5
+
6
+ Instead of trying to normalize scores across two completely different
7
+ scoring systems (cosine similarity vs BM25 score), RRF uses RANKS.
8
+
9
+ For each result, we compute:
10
+ RRF_score = 1 / (k + rank_in_dense_results)
11
+ + 1 / (k + rank_in_bm25_results)
12
+
13
+ Where k=60 is a constant that dampens the impact of very high ranks.
14
+
15
+ Example:
16
+ Chunk A: rank 1 in dense, rank 3 in BM25
17
+ RRF = 1/(60+1) + 1/(60+3) = 0.0164 + 0.0159 = 0.0323
18
+
19
+ Chunk B: rank 2 in dense, not in BM25
20
+ RRF = 1/(60+2) + 0 = 0.0161
21
+
22
+ Chunk C: rank 5 in dense, rank 1 in BM25
23
+ RRF = 1/(60+5) + 1/(60+1) = 0.0154 + 0.0164 = 0.0317
24
+
25
+ Chunk A wins - it ranked highly in BOTH systems.
26
+ Chunk C is second - it was top in BM25 and decent in dense.
27
+
28
+ WHY RRF OVER SCORE NORMALIZATION:
29
+ BM25 scores range 0-15 typically.
30
+ Cosine similarity scores range 0-1.
31
+ Normalizing these to the same scale requires knowing
32
+ the distribution of each, which changes per query.
33
+ RRF sidesteps this entirely by using ranks.
34
+
35
+ This is why RRF is the industry standard for hybrid search.
36
+ """
37
+
38
+ from typing import Optional
39
+ import numpy as np
40
+
41
+
42
+ from src.vectorstore.qdrant_store import QdrantStore
43
+ from src.vectorstore.bm25_store import BM25Store
44
+ from src.embeddings.embedding_model import EmbeddingModel
45
+ from src.utils.logger import get_logger
46
+ from config.settings import TOP_K_RETRIEVAL
47
+
48
+
49
+ logger = get_logger(__name__)
50
+
51
+ # RRF constant - 60 is the standard value from the original paper
52
+ RRF_K = 60
53
+
54
+
55
+
56
class HybridRetriever:
    """
    Combines dense vector search and BM25 keyword search
    using Reciprocal Rank Fusion (RRF) for score merging.
    """

    def __init__(
        self,
        qdrant_store: QdrantStore,
        bm25_store: BM25Store,
        embedding_model: EmbeddingModel,
    ):
        self.qdrant = qdrant_store
        self.bm25 = bm25_store
        self.embedder = embedding_model

    def retrieve(
        self,
        query: str,
        top_k: int = TOP_K_RETRIEVAL,
        filter_category: Optional[str] = None,
        filter_year_gte: Optional[int] = None,
        dense_weight: float = 0.7,
        sparse_weight: float = 0.3,
    ) -> list[dict]:
        """
        Hybrid retrieval with RRF fusion.

        Args:
            query: User's raw query string
            top_k: Final number of results to return
            filter_category: ArXiv category filter (e.g. "cs.LG")
            filter_year_gte: Only papers from this year onwards
            dense_weight: Weight for dense retrieval in fusion (0-1)
            sparse_weight: Weight for BM25 retrieval in fusion (0-1)

        Returns:
            List of result dicts sorted by RRF score (best first).
            Each dict carries chunk_id, text, "rrf_score" and
            "retrieval": "hybrid".

        WHY dense_weight = 0.7, sparse_weight = 0.3:
            Research papers use technical language where semantic
            understanding (dense) matters more than exact keyword
            matching (sparse). For a code search system, you'd
            flip these weights.
        """

        # -------------- Step 1: Dense retrieval --------------
        query_vector = self.embedder.embed_query(query)
        dense_results = self.qdrant.search(
            query_vector=query_vector,
            top_k=top_k * 2,  # Retrieve more for fusion
            filter_category=filter_category,
            filter_year_gte=filter_year_gte,
        )

        # -------------- Step 2: Sparse (BM25) retrieval --------------
        sparse_results = self.bm25.search(query, top_k=top_k * 2)

        # -------------- Step 3: Build chunk_id -> full data lookup --------------
        # FIX: index BOTH result sets. The original only indexed dense
        # results, so a chunk surfaced exclusively by BM25 was emitted in
        # the final list with no chunk_id/text payload at all (empty dict
        # merge), which downstream re-ranking then scored as empty text.
        chunk_data: dict[str, dict] = {}
        for r in dense_results:
            if r["chunk_id"] not in chunk_data:
                chunk_data[r["chunk_id"]] = {
                    "chunk_id": r["chunk_id"],
                    "text": r["text"],
                    "score": 0.0,
                }
        for r in sparse_results:
            if r["chunk_id"] not in chunk_data:
                chunk_data[r["chunk_id"]] = {
                    "chunk_id": r["chunk_id"],
                    "text": r["text"],
                    "score": 0.0,
                }

        # -------------- Step 4: Compute weighted RRF score --------------
        # rank is 0-based from enumerate, hence the "+ 1" to get 1-based
        # ranks as in the RRF formula 1 / (k + rank).
        rrf_scores: dict[str, float] = {}

        for rank, result in enumerate(dense_results):
            cid = result["chunk_id"]
            rrf_scores[cid] = rrf_scores.get(cid, 0.0) + dense_weight * (1.0 / (RRF_K + rank + 1))

        for rank, result in enumerate(sparse_results):
            cid = result["chunk_id"]
            rrf_scores[cid] = rrf_scores.get(cid, 0.0) + sparse_weight * (1.0 / (RRF_K + rank + 1))

        # -------------- Step 5: Sort by RRF score, best first --------------
        sorted_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)

        # -------------- Step 6: Build final results --------------
        final_results = []
        for cid in sorted_ids[:top_k]:
            data = chunk_data.get(cid, {})
            final_results.append(
                {
                    **data,
                    "rrf_score": round(rrf_scores[cid], 6),
                    "retrieval": "hybrid",
                }
            )

        logger.debug(
            f"Hybrid retrieval: {len(dense_results)} dense + "
            f"{len(sparse_results)} sparse -> {len(final_results)} merged"
        )

        return final_results
src/retrieval/reranker.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cross-encoder re-ranking for improved retrieval precision.
3
+
4
+ THE DIFFERENCE BETWEEN BI-ENCODER AND CROSS-ENCODER:
5
+
6
+ Bi-encoder (what BGE does):
7
+ embed(query) → vector_q
8
+ embed(chunk) → vector_c
9
+ score = cosine(vector_q, vector_c)
10
+
11
+ Query and chunk are embedded INDEPENDENTLY.
12
+ Fast (vectors pre-computed), but loses interaction signal.
13
+
14
+ Cross-encoder (what we use for re-ranking):
15
+ score = model(query + [SEP] + chunk)
16
+
17
+ Query and chunk are processed TOGETHER by the model.
18
+ The model can see how query tokens relate to chunk tokens.
19
+ Slower (cannot pre-compute), but much more accurate.
20
+
21
+ THE TWO-STAGE PATTERN:
22
+ Stage 1 (Retrieval): Bi-encoder -> top-20 candidates (fast, approximate)
23
+ Stage 2 (Re-ranking): Cross-encoder -> re-score top-20 (slow, accurate)
24
+
25
+ We only run the expensive cross-encoder on 20 candidates,
26
+ not all 15,664 chunks. This gives us accuracy without
27
+ paying the full cost for every chunk.
28
+
29
+ MODEL: cross-encoder/ms-marco-MiniLM-L-6-v2
30
+ - Trained on MS MARCO passage retrieval dataset (500K+ queries)
31
+ - MiniLM architecture: fast on CPU
32
+ - Output: relevance score (-inf to +inf, higher = more relevant)
33
+ - Size: ~80MB
34
+ """
35
+
36
+ import logging
37
+ logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
38
+
39
+ from sentence_transformers import CrossEncoder
40
+ from src.utils.logger import get_logger
41
+
42
+ logger = get_logger(__name__)
43
+
44
+ RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
45
+
46
+
47
class CrossEncoderReranker:
    """
    Re-ranks retrieved chunks using a cross-encoder model.

    The model is lazy-loaded on first use so constructing the reranker
    is cheap; the first rerank() call pays the model-load cost.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        self._model = None
        self._model_name = model_name
        logger.info(f"CrossEncoderReranker initialized: {model_name}")

    @property
    def model(self) -> CrossEncoder:
        """Lazy-load cross-encoder model."""
        if self._model is None:
            logger.info(f"Loading cross-encoder: {self._model_name}")
            self._model = CrossEncoder(
                self._model_name,
                max_length = 512  # Max tokens for query+chunk combined
            )
            logger.info("Cross-encoder loaded")

        return self._model

    def rerank(
        self,
        query: str,
        results: list[dict],
        top_k: int = 5
    ) -> list[dict]:
        """
        Re-rank a list of retrieved chunks using cross-encoder scoring.

        Args:
            query: Original user query
            results: List of retrieved chunk dicts (from hybrid retriever).
                NOTE: each dict is mutated in place — a "ce_score" key
                is added.
            top_k: How many top results to return after re-ranking

        Returns:
            Top-k results sorted by cross-encoder relevance score
            (descending). Empty list if `results` is empty.
        """

        if not results:
            return []

        # Build (query, chunk_text) pairs for batch scoring
        pairs = [
            (query, r.get("text", ""))
            for r in results
        ]

        # Score all pairs in one batch.
        # predict() returns a numpy array of relevance scores.
        scores = self.model.predict(
            pairs,
            show_progress_bar = False,
            batch_size = 32,
        )

        # Attach cross-encoder score to each result
        for result, score in zip(results, scores):
            result["ce_score"] = round(float(score), 4)

        # Sort by cross-encoder score (descending)
        reranked = sorted(results, key = lambda x: x["ce_score"], reverse = True)

        # FIX: the original nested double quotes inside a double-quoted
        # f-string (reranked[-1]["ce_score"]), which is a SyntaxError on
        # every Python version before 3.12. Single quotes are portable.
        logger.debug(
            f"Re-ranked {len(results)} -> top-{top_k}. "
            f"Score range: [{reranked[-1]['ce_score']:.2f}, "
            f"{reranked[0]['ce_score']:.2f}]"
        )

        return reranked[:top_k]
137
+
138
+
139
+
140
def diversity_filter(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    """
    Cap how many chunks any single paper may contribute to the results.

    Given a relevance-sorted result list, keep at most `max_per_paper`
    entries per paper, preserving the original ordering. Results without
    a "paper_id" key are grouped together under "unknown".

    Args:
        results: Result dicts, already sorted by relevance (best first)
        max_per_paper: Upper bound on chunks from one paper

    Returns:
        Filtered list in the same relative order.

    WHY THIS MATTERS FOR USER EXPERIENCE:
        Without the cap, a single highly relevant paper can fill the
        whole top-k. Spreading the slots across several papers gives
        the user multiple perspectives and broader coverage.
    """

    per_paper_counts: dict[str, int] = {}
    kept: list[dict] = []

    for entry in results:
        pid = entry.get("paper_id", "unknown")
        used = per_paper_counts.get(pid, 0)
        if used >= max_per_paper:
            continue
        per_paper_counts[pid] = used + 1
        kept.append(entry)

    return kept
src/retrieval/retrieval_pipeline.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Orchestrates the full retrieval pipeline:
3
+ 1. Hybrid retrieval (dense + BM25)
4
+ 2. Cross-encoder re-ranking
5
+ 3. Diversity filtering
6
+
7
+ This is the component that the RAG pipeline (Phase 9) will call.
8
+ It takes a query string and returns the best chunks.
9
+ """
10
+
11
+ from typing import Optional
12
+
13
+
14
+ from src.retrieval.hybrid_retriever import HybridRetriever
15
+ from src.retrieval.reranker import CrossEncoderReranker, diversity_filter
16
+ from src.vectorstore.qdrant_store import QdrantStore
17
+ from src.vectorstore.bm25_store import BM25Store
18
+ from src.embeddings.embedding_model import EmbeddingModel
19
+ from src.utils.logger import get_logger
20
+ from config.settings import TOP_K_RETRIEVAL, TOP_K_RERANK
21
+
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
+
27
+
28
class RetrievalPipeline:
    """
    Full retrieval pipeline with hybrid search + re-ranking.

    Usage:
        pipeline = RetrievalPipeline()
        results = pipeline.retrieve("how does LoRA fine-tuning work?")
        for r in results:
            print(r["title"], r["ce_score"], r["text"][:100])
    """

    def __init__(self):
        """Wire up the dense store, sparse store, embedder and re-ranker."""
        logger.info("Initializing RetrievalPipeline...")

        dense_store = QdrantStore()
        embedder = EmbeddingModel()

        # BM25 index: load the persisted one, or build it on first run.
        sparse_store = BM25Store()
        if not sparse_store.load():
            logger.info("BM25 index not found - building now...")
            sparse_store.build_index()

        self.hybrid_retriever = HybridRetriever(
            qdrant_store=dense_store,
            bm25_store=sparse_store,
            embedding_model=embedder,
        )
        self.reranker = CrossEncoderReranker()

        logger.info("RetrievalPipeline ready")

    def retrieve(
        self,
        query: str,
        top_k_final: int = TOP_K_RERANK,
        filter_category: Optional[str] = None,
        filter_year_gte: Optional[int] = None,
    ) -> list[dict]:
        """
        Full retrieval: hybrid search → re-rank → diversity filter.

        Args:
            query: User's natural language question
            top_k_final: Number of chunks to return
            filter_category: ArXiv category filter
            filter_year_gte: Year filter

        Returns:
            List of top chunks with all metadata and scores
        """
        logger.debug(f"Retrieving for query: '{query[:60]}'")

        # Stage 1: wide hybrid candidate pool (dense + BM25, RRF-fused)
        hybrid_hits = self.hybrid_retriever.retrieve(
            query=query,
            top_k=TOP_K_RETRIEVAL * 2,
            filter_category=filter_category,
            filter_year_gte=filter_year_gte,
        )

        if not hybrid_hits:
            logger.warning(f"No candidates found for query: {query}")
            return []

        # Stage 2: cross-encoder re-scoring. Keep twice the final count
        # so the diversity filter still has enough to choose from.
        scored = self.reranker.rerank(
            query=query,
            results=hybrid_hits,
            top_k=top_k_final * 2,
        )

        # Stage 3: spread slots across papers (max 2 chunks each),
        # then cut down to the requested size.
        deduped = diversity_filter(scored, max_per_paper=2)
        top_chunks = deduped[:top_k_final]

        logger.debug(
            f"Pipeline: {len(hybrid_hits)} candidates -> "
            f"{len(scored)} reranked -> "
            f"{len(top_chunks)} final"
        )

        return top_chunks
src/vectorstore/bm25_store.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BM25 sparse retrieval index for keyword-based search.
3
+
4
+ BM25 (Best Match 25) is the gold standard keyword search algorithm.
5
+ It powers Elasticsearch, Solr, and was the backbone of Google Search
6
+ before neural methods. It rewards:
7
+ - Term frequency: how often the query word appears in the chunk
8
+ - Inverse document frequency: rare words are more discriminative
9
+ - Document length normalization: prevents long chunks from dominating
10
+
11
+ WHY WE NEED THIS ALONGSIDE VECTOR SEARCH:
12
+ Query: "what is LoRA fine-tuning?"
13
+
14
+ Vector search: finds chunks about "parameter-efficient training"
15
+ (semantically related but may miss the exact acronym)
16
+
17
+ BM25: finds chunks containing the EXACT token "LoRA"
18
+ (exact match, regardless of semantic similarity)
19
+
20
+ Hybrid: finds chunks that are BOTH semantically relevant
21
+ AND contain the keyword - best of both worlds.
22
+ """
23
+
24
+ from copyreg import pickle
25
+ import json
26
+ import pickle
27
+ import re
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ from rank_bm25 import BM25Okapi
32
+
33
+ from src.utils.logger import get_logger
34
+ from config.settings import CHUNKS_DIR, EMBEDDINGS_DIR
35
+
36
+ logger = get_logger(__name__)
37
+
38
+
39
+ # Where we persist the BM25 index
40
+ BM25_INDEX_PATH = EMBEDDINGS_DIR / "bm25_index.pkl"
41
+
42
+
43
+
44
def tokenize(text: str) -> list[str]:
    """
    Lowercase `text` and split it into BM25 tokens.

    A token is a run of letters/digits, optionally joined by hyphens,
    so "fine-tuning" survives as one token while punctuation is dropped.

    WHY NOT USE NLTK/SPACY:
        For BM25 in a RAG pipeline, simple regex tokenization is
        sufficient and avoids heavy dependencies; the retrieval-quality
        difference is minimal.
    """
    return re.findall(r'[a-z0-9]+(?:-[a-z0-9]+)*', text.lower())
61
+
62
+
63
+
64
class BM25Store:
    """
    Manages a BM25 index over all chunk texts.

    The index is built once and persisted to disk as a pickle file.
    Loading from pickle is near-instant vs rebuilding from scratch.
    """

    def __init__(self):
        # FIX: the original annotation was `BM25Okapi` while assigning
        # None (wrong, and evaluated at runtime). The quoted optional
        # annotation is accurate and lazily evaluated.
        self.bm25: "BM25Okapi | None" = None
        self.chunk_ids: list[str] = []
        self.texts: list[str] = []

    def build_index(self) -> None:
        """
        Build the BM25 index from all chunk files.

        Loads every *_semantic.json file under CHUNKS_DIR, tokenizes
        the chunk texts, builds the BM25 index and persists it to disk.

        Raises:
            RuntimeError: if no chunk files/texts are found — BM25Okapi
                on an empty corpus fails with a confusing internal error.
        """
        logger.info("Building BM25 index from chunk files...")

        chunk_ids: list[str] = []
        texts: list[str] = []

        for cf in CHUNKS_DIR.glob("*_semantic.json"):
            with open(cf, "r", encoding='utf-8') as f:
                chunks = json.load(f)

            for chunk in chunks:
                chunk_ids.append(chunk["chunk_id"])
                texts.append(chunk["text"])

        # FIX: fail loudly on an empty corpus instead of letting
        # BM25Okapi crash with an opaque error.
        if not texts:
            raise RuntimeError(
                f"No chunks found in {CHUNKS_DIR} - cannot build BM25 index"
            )

        logger.info(f"Tokenizing {len(texts):,} chunks...")

        # BM25Okapi expects a list of token lists, one per document
        tokenized_corpus = [tokenize(text) for text in texts]

        # Build the standard Okapi BM25 variant over the corpus
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_ids = chunk_ids
        self.texts = texts

        logger.info(f"BM25 index built: {len(chunk_ids):,} documents")

        # Persist to disk
        self._save()

    def _save(self) -> None:
        """Save index to disk using pickle."""
        data = {
            "bm25": self.bm25,
            "chunk_ids": self.chunk_ids,
            "texts": self.texts,
        }

        with open(BM25_INDEX_PATH, "wb") as f:
            pickle.dump(data, f)
        size_mb = BM25_INDEX_PATH.stat().st_size / 1024 / 1024
        logger.info(f"BM25 index saved: {BM25_INDEX_PATH} ({size_mb:.1f} MB)")

    def load(self) -> bool:
        """
        Load the index from disk.

        Returns:
            True if loaded, False if no index file exists.

        NOTE: pickle.load is only safe here because the file is written
        by our own _save(); never unpickle data from an untrusted source.
        """
        if not BM25_INDEX_PATH.exists():
            logger.info("No BM25 index found on disk")
            return False

        logger.info("Loading BM25 index from disk...")
        with open(BM25_INDEX_PATH, "rb") as f:
            data = pickle.load(f)

        self.bm25 = data["bm25"]
        self.chunk_ids = data["chunk_ids"]
        self.texts = data["texts"]

        logger.info(f"BM25 index loaded: {len(self.chunk_ids):,} documents")
        return True

    def search(self, query: str, top_k: int = 20) -> list[dict]:
        """
        Search the BM25 index with a text query.

        Args:
            query: Raw query string (NOT embedded - BM25 uses tokens)
            top_k: Number of top results to return

        Returns:
            List of dicts with chunk_id, bm25_score, text — zero-score
            documents (no keyword overlap) are dropped, so fewer than
            top_k entries may be returned.

        Raises:
            RuntimeError: if neither build_index() nor load() has run.
        """
        if self.bm25 is None:
            raise RuntimeError("BM25 index not loaded. Call build_index() or load() first.")

        query_tokens = tokenize(query)

        if not query_tokens:
            return []

        # get_scores returns an array of shape (n_documents,)
        # with the BM25 score for each document
        scores = self.bm25.get_scores(query_tokens)

        # Indices of the top-k scores (argsort ascending, take last k, reverse)
        top_indices = np.argsort(scores)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            score = float(scores[idx])
            if score <= 0:
                # Skip zero-score results - no keyword overlap at all
                continue
            results.append(
                {
                    "chunk_id": self.chunk_ids[idx],
                    "bm25_score": round(score, 4),
                    "text": self.texts[idx],
                }
            )

        return results
src/vectorstore/qdrant_store.py CHANGED
@@ -175,7 +175,8 @@ class QdrantStore:
175
  payload = {
176
  # Store ALL metadata in payload for retrieval
177
  **metadata[i],
178
- "text": texts[i], # Inlcude chunk text
 
179
  }
180
  )
181
  points.append(point)
@@ -286,13 +287,12 @@ class QdrantStore:
286
  )
287
 
288
  if year_gte:
289
- # published_date is stored as "YYYY-MM-DD" string
290
- # We filter by string comparison: "2024-01-01" <= date
291
- # This works because ISO date strings sort lexicographically
292
  conditions.append(
293
  FieldCondition(
294
- key = "published_date",
295
- range = Range(gte = f"{year_gte}-01-01")
296
  )
297
  )
298
 
 
175
  payload = {
176
  # Store ALL metadata in payload for retrieval
177
  **metadata[i],
178
+ "text": texts[i], # Include chunk text
179
+ "publication_year": int(metadata[i].get("published_date", "0000")[:4]),
180
  }
181
  )
182
  points.append(point)
 
287
  )
288
 
289
  if year_gte:
290
+ # publication_year is stored as an integer (e.g. 2026)
291
+ # Range(gte=year_gte) filters to papers from that year onwards
 
292
  conditions.append(
293
  FieldCondition(
294
+ key = "publication_year",
295
+ range = Range(gte = year_gte)
296
  )
297
  )
298
 
test_retrieval.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test the full retrieval pipeline: hybrid search + re-ranking + diversity.
3
+ Compare it against pure dense search to show the improvement.
4
+ """
5
+
6
+ import time
7
+ from src.utils.logger import setup_logger, get_logger
8
+ from src.retrieval.retrieval_pipeline import RetrievalPipeline
9
+ from src.vectorstore.qdrant_store import QdrantStore
10
+ from src.embeddings.embedding_model import EmbeddingModel
11
+
12
+ setup_logger()
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ def test_pipeline(pipeline: RetrievalPipeline, query: str):
17
+ print(f"\n{'='*60}")
18
+ print(f"QUERY: {query}")
19
+ print(f"{'='*60}")
20
+
21
+ start = time.time()
22
+ results = pipeline.retrieve(query, top_k_final=5)
23
+ elapsed = time.time() - start
24
+
25
+ print(f"Retrieved {len(results)} results in {elapsed:.2f}s\n")
26
+
27
+ for i, r in enumerate(results):
28
+ print(f"[{i+1}] CE Score: {r.get('ce_score', 'N/A'):>7} | "
29
+ f"RRF: {r.get('rrf_score', 'N/A'):.4f}")
30
+ print(f" {r.get('title','')[:65]}...")
31
+ print(f" {r.get('text','')[:120].replace(chr(10),' ')}...")
32
+ print()
33
+
34
+
35
def main():
    """Exercise the retrieval pipeline on a few representative queries."""
    logger.info("Initializing full retrieval pipeline...")
    pipeline = RetrievalPipeline()

    # Unfiltered queries: conceptual, keyword-heavy (BM25 keyword
    # advantage), and comparison-style.
    for query in (
        "how does self-attention mechanism work in transformers",
        "LoRA low-rank adaptation fine-tuning",
        "reinforcement learning reward shaping techniques",
    ):
        test_pipeline(pipeline, query)

    # Final check: retrieval with a publication-year filter applied.
    print(f"\n{'='*60}")
    print("FILTERED: 'graph neural networks' (2026 only)")
    print(f"{'='*60}")

    filtered_hits = pipeline.retrieve(
        "graph neural networks",
        filter_year_gte=2026,
        top_k_final=3,
    )

    for idx, hit in enumerate(filtered_hits):
        print(
            f"[{idx+1}] {hit.get('published_date', 'N/A')} | "
            f"CE: {hit.get('ce_score','N/A'):>6} | "
            f"{hit.get('title','')[:55]}..."
        )

    logger.info("\n✅ Retrieval pipeline test complete")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()