feat(rag): RAGRetriever (load + search → chunks with scores)
Browse files- src/rag/retrieve.py +40 -0
- tests/rag/test_retrieve.py +45 -0
src/rag/retrieve.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query → top-k chunks. Encapsulates the embedder + store pair so callers
|
| 2 |
+
don't have to assemble both. Loads from disk lazily.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from src.core.logger import get_logger
|
| 9 |
+
from src.rag.embed import EMBEDDING_DIM, Embedder
|
| 10 |
+
from src.rag.store import FAISSStore
|
| 11 |
+
|
| 12 |
+
# Module-level logger, obtained from the project's logger factory using this module's name.
logger = get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RAGRetriever:
    """Bundle an (embedder, store) pair behind a single query → chunks API.

    Use `RAGRetriever.load(dir)` to construct from an index directory; the
    plain constructor accepts an already-built store and embedder.
    """

    def __init__(self, store: FAISSStore, embedder: Embedder) -> None:
        """Keep references to the vector `store` and the query `embedder`."""
        self._store = store
        self._embedder = embedder

    @classmethod
    def load(cls, index_dir: Path) -> "RAGRetriever":
        """Build a retriever from the index saved under `index_dir`.

        `index_dir` is passed through `Path(...)`, so a plain string path is
        accepted as well. The store is loaded with `dim=EMBEDDING_DIM` and
        paired with a default-constructed `Embedder`.
        """
        store = FAISSStore.load(Path(index_dir), dim=EMBEDDING_DIM)
        return cls(store=store, embedder=Embedder())

    def __len__(self) -> int:
        """Return the length reported by the underlying store."""
        return len(self._store)

    def search(self, query: str, k: int = 5) -> list[dict]:
        """Return up to `k` chunks most relevant to `query`.

        Each returned dict is a copy of the chunk dict yielded by the store
        with a `score` key added (the store's chunks are expected to carry
        `text`, `source`, `chunk_index`). Results keep the order the store
        returns them in — presumably best score first; confirm against
        `FAISSStore.search`.

        Returns [] for an empty/whitespace-only query, an empty store, or a
        non-positive `k`.
        """
        # A non-positive k can never yield results; short-circuit instead of
        # forwarding it to the store, where it may be rejected or misread.
        if k <= 0:
            return []
        if not query.strip() or len(self._store) == 0:
            return []
        # The embedder works on batches: encode a one-element batch and use
        # its single row as the query vector.
        vec = self._embedder.encode([query])
        hits = self._store.search(vec[0], k=k)
        return [{**chunk, "score": score} for chunk, score in hits]
|
tests/rag/test_retrieve.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.rag.retrieve — query → top-k chunks."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.rag.ingest import ingest_directory
|
| 9 |
+
from src.rag.retrieve import RAGRetriever
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Sample knowledge base fed to ingest: <grandparent dir of this file>/fixtures/kb_sample.
_FIXTURE_KB = Path(__file__).parent.parent / "fixtures" / "kb_sample"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestRAGRetriever:
    """Retrieval checks run against an index built from the sample KB fixture."""

    @pytest.fixture(scope="class")
    def retriever(self, tmp_path_factory: pytest.TempPathFactory) -> RAGRetriever:
        """Ingest the fixture KB into a temp dir once per class and load a retriever from it."""
        index_location = tmp_path_factory.mktemp("rag_idx")
        ingest_directory(_FIXTURE_KB, index_location)
        loaded = RAGRetriever.load(index_location)
        return loaded

    def test_bbb_query_returns_lipinski_chunk(self, retriever: RAGRetriever) -> None:
        """A blood-brain-barrier question returns three hits led by the Lipinski doc."""
        results = retriever.search("Why does ethanol cross the blood-brain barrier?", k=3)
        assert len(results) == 3
        found_sources = [item["source"] for item in results]
        assert "lipinski_rule_of_five.md" in found_sources
        # The Lipinski document must rank first, not merely appear.
        assert results[0]["source"] == "lipinski_rule_of_five.md"

    def test_combat_query_returns_combat_chunk(self, retriever: RAGRetriever) -> None:
        """A ComBat question ranks the ComBat primer first."""
        results = retriever.search("How does ComBat remove scanner bias from MRI data?", k=2)
        best = results[0]
        assert best["source"] == "combat_harmonization_primer.md"

    def test_eeg_query_returns_ica_chunk(self, retriever: RAGRetriever) -> None:
        """An EEG eye-blink question ranks the MNE ICA doc first."""
        results = retriever.search("How do you remove eye blink artifacts from EEG?", k=2)
        best = results[0]
        assert best["source"] == "mne_ica_basics.md"

    def test_search_includes_score_and_text(self, retriever: RAGRetriever) -> None:
        """The top hit exposes text/source/score, with a float score in [0, 1]."""
        results = retriever.search("BBB permeability", k=1)
        top = results[0]
        for required_key in ("text", "source", "score"):
            assert required_key in top
        score_value = top["score"]
        assert isinstance(score_value, float)
        assert 0.0 <= score_value <= 1.0
|