mekosotto committed on
Commit
978f645
·
1 Parent(s): cf5c011

feat(rag): RAGRetriever (load + search → chunks with scores)

Browse files
Files changed (2) hide show
  1. src/rag/retrieve.py +40 -0
  2. tests/rag/test_retrieve.py +45 -0
src/rag/retrieve.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query → top-k chunks. Encapsulates the embedder + store pair so callers
2
+ don't have to assemble both. Loads from disk lazily.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from pathlib import Path
7
+
8
+ from src.core.logger import get_logger
9
+ from src.rag.embed import EMBEDDING_DIM, Embedder
10
+ from src.rag.store import FAISSStore
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
class RAGRetriever:
    """Pair an embedder with a vector store behind a single query interface.

    Build instances with `RAGRetriever.load(index_dir)`. The constructor is
    for callers (and tests) that already hold both collaborators.
    """

    def __init__(self, store: FAISSStore, embedder: Embedder) -> None:
        self._store = store
        self._embedder = embedder

    @classmethod
    def load(cls, index_dir: Path) -> "RAGRetriever":
        """Load the index persisted under `index_dir` and attach a new `Embedder`.

        `index_dir` is wrapped in `Path`, so a plain string is accepted too.
        """
        store = FAISSStore.load(Path(index_dir), dim=EMBEDDING_DIM)
        return cls(store=store, embedder=Embedder())

    def __len__(self) -> int:
        """Return the length reported by the underlying store."""
        return len(self._store)

    def search(self, query: str, k: int = 5) -> list[dict]:
        """Return up to `k` chunks for `query`, each with a `score` entry.

        Every result is a copy of the chunk dict yielded by the store with a
        `score` key added. Results keep the order in which the store returns
        them (expected to be best-first; the ordering is decided by the
        store, not here).

        Returns [] for a blank query, an empty store, or a non-positive `k`.
        """
        # Guard before embedding: encoding is the expensive step, and a
        # non-positive k can never produce a hit.
        if k <= 0 or not query.strip() or len(self._store) == 0:
            return []
        vec = self._embedder.encode([query])
        hits = self._store.search(vec[0], k=k)
        # Coerce to the builtin float: array-backed stores may hand back
        # NumPy scalars, which are not `float` instances and do not
        # serialise to JSON.
        return [{**chunk, "score": float(score)} for chunk, score in hits]
tests/rag/test_retrieve.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.retrieve — query → top-k chunks."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.rag.ingest import ingest_directory
9
+ from src.rag.retrieve import RAGRetriever
10
+
11
+
12
+ _FIXTURE_KB = Path(__file__).parent.parent / "fixtures" / "kb_sample"
13
+
14
+
15
class TestRAGRetriever:
    """End-to-end retrieval checks against the sample knowledge base."""

    @pytest.fixture(scope="class")
    def retriever(self, tmp_path_factory: pytest.TempPathFactory) -> RAGRetriever:
        """Ingest the fixture KB into a temp dir once per class and load it."""
        index_root = tmp_path_factory.mktemp("rag_idx")
        ingest_directory(_FIXTURE_KB, index_root)
        return RAGRetriever.load(index_root)

    def test_bbb_query_returns_lipinski_chunk(self, retriever: RAGRetriever) -> None:
        """A blood-brain-barrier question should surface the Lipinski document."""
        results = retriever.search("Why does ethanol cross the blood-brain barrier?", k=3)
        assert len(results) == 3
        assert "lipinski_rule_of_five.md" in [hit["source"] for hit in results]
        # The best-ranked hit, not just any hit, must come from that document.
        best = results[0]
        assert best["source"] == "lipinski_rule_of_five.md"

    def test_combat_query_returns_combat_chunk(self, retriever: RAGRetriever) -> None:
        """A ComBat question should rank the harmonization primer first."""
        results = retriever.search("How does ComBat remove scanner bias from MRI data?", k=2)
        best = results[0]
        assert best["source"] == "combat_harmonization_primer.md"

    def test_eeg_query_returns_ica_chunk(self, retriever: RAGRetriever) -> None:
        """An EEG artifact question should rank the ICA basics document first."""
        results = retriever.search("How do you remove eye blink artifacts from EEG?", k=2)
        best = results[0]
        assert best["source"] == "mne_ica_basics.md"

    def test_search_includes_score_and_text(self, retriever: RAGRetriever) -> None:
        """Each hit exposes text, source and a float score within [0, 1]."""
        top = retriever.search("BBB permeability", k=1)[0]
        for field in ("text", "source", "score"):
            assert field in top
        assert isinstance(top["score"], float)
        assert 0.0 <= top["score"] <= 1.0