mekosotto commited on
Commit
6b2c154
·
1 Parent(s): ac78b6f

feat(rag): clinical TF-IDF index loader with __main__.Chunk routing

Browse files
src/rag/clinical/__init__.py ADDED
File without changes
src/rag/clinical/loader.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load (or rebuild) the TF-IDF clinical RAG index.
2
+
3
+ The user's rag.py builds the index from `__main__.Chunk` (a frozen
4
+ dataclass). Picking up that pickle from a different module path requires
5
+ a custom Unpickler that re-routes the class — see _ChunkRoutingUnpickler.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import pickle
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ from src.core.logger import get_logger
14
+ from src.rag.clinical.types import ClinicalChunk
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ class _ChunkRoutingUnpickler(pickle.Unpickler):
20
+ """Pickle's `find_class` hook lets us swap `__main__.Chunk` (and
21
+ `rag.Chunk` if the user later runs the builder as a module) for our
22
+ `ClinicalChunk` — both are frozen dataclasses with the same fields,
23
+ so the swap is structurally safe.
24
+ """
25
+
26
+ def find_class(self, module: str, name: str):
27
+ if name == "Chunk" and module in {"__main__", "rag", "rag.rag"}:
28
+ return ClinicalChunk
29
+ return super().find_class(module, name)
30
+
31
+
32
+ def load_index(path: Path) -> dict[str, Any]:
33
+ """Unpickle a TF-IDF index produced by the user's rag.py."""
34
+ path = Path(path)
35
+ if not path.exists():
36
+ raise FileNotFoundError(f"clinical RAG index not found: {path}")
37
+ with path.open("rb") as f:
38
+ payload = _ChunkRoutingUnpickler(f).load()
39
+ expected = {"chunks", "vectorizer", "matrix"}
40
+ if not expected <= set(payload):
41
+ raise ValueError(
42
+ f"clinical RAG index missing expected keys: have {sorted(payload)}, need {sorted(expected)}"
43
+ )
44
+ logger.info("loaded clinical RAG index: %d chunks from %s", len(payload["chunks"]), path)
45
+ return payload
src/rag/clinical/types.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Types shared across clinical-RAG modules."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class ClinicalChunk:
11
+ """Mirrors the Chunk dataclass produced by the user's rag.py builder."""
12
+ chunk_id: int
13
+ source: str
14
+ page_start: int
15
+ page_end: int
16
+ text: str
17
+
18
+
19
+ class ClinicalEvidence(BaseModel):
20
+ sentence: str
21
+ source: str
22
+ page_start: int
23
+ page_end: int
24
+ score: float = Field(..., ge=0.0)
25
+
26
+
27
+ class ClinicalRetrievalResult(BaseModel):
28
+ query: str
29
+ evidence: list[ClinicalEvidence]
30
+ summary_text: str = Field(..., description="Pre-formatted RAG feedback for the agent")
tests/fixtures/build_tiny_clinical_index.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a synthetic TF-IDF clinical-RAG index for tests.
2
+
3
+ Avoids needing real PDFs. Constructs the same payload schema the user's
4
+ rag.py produces so the loader can be tested independently of pypdf.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import pickle
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ from sklearn.feature_extraction.text import TfidfVectorizer
13
+
14
+ from src.rag.clinical.types import ClinicalChunk
15
+
16
+
17
+ def build(path: Path) -> Path:
18
+ """Save a tiny TF-IDF index at `path`."""
19
+ path = Path(path)
20
+ if path.exists():
21
+ return path
22
+ path.parent.mkdir(parents=True, exist_ok=True)
23
+
24
+ chunks = [
25
+ ClinicalChunk(0, "alzheimers_lifestyle.pdf", 1, 1,
26
+ "Aerobic exercise and Mediterranean diet are associated with reduced cognitive decline in older adults at risk for Alzheimer's disease."),
27
+ ClinicalChunk(1, "parkinsons_motor.pdf", 1, 1,
28
+ "Levodopa remains the most effective symptomatic treatment for motor symptoms of Parkinson's disease."),
29
+ ClinicalChunk(2, "alzheimers_mci.pdf", 2, 2,
30
+ "Mild cognitive impairment may progress to dementia; MMSE and MoCA are standard screening tools."),
31
+ ClinicalChunk(3, "parkinsons_nutrition.pdf", 1, 1,
32
+ "Dietary patterns rich in antioxidants and omega-3 fatty acids are linked to lower Parkinson's risk."),
33
+ ]
34
+
35
+ vectorizer = TfidfVectorizer(lowercase=True, ngram_range=(1, 2), min_df=1, norm="l2")
36
+ matrix = vectorizer.fit_transform([c.text for c in chunks])
37
+
38
+ payload = {
39
+ "created_at": datetime.now().isoformat(timespec="seconds"),
40
+ "source_dir": str(path.parent),
41
+ "chunk_words": 220,
42
+ "overlap_words": 45,
43
+ "chunks": chunks,
44
+ "vectorizer": vectorizer,
45
+ "matrix": matrix,
46
+ }
47
+ with path.open("wb") as f:
48
+ pickle.dump(payload, f)
49
+ return path
tests/rag/test_clinical_loader.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.clinical.loader."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.rag.clinical import loader
9
+ from tests.fixtures.build_tiny_clinical_index import build as build_tiny
10
+
11
+
12
+ class TestLoadIndex:
13
+ def test_load_returns_payload_with_expected_keys(self, tmp_path: Path) -> None:
14
+ idx_path = build_tiny(tmp_path / "tiny.pkl")
15
+ payload = loader.load_index(idx_path)
16
+ assert {"chunks", "vectorizer", "matrix"} <= set(payload)
17
+ assert len(payload["chunks"]) == 4
18
+
19
+ def test_missing_index_raises(self, tmp_path: Path) -> None:
20
+ with pytest.raises(FileNotFoundError, match="clinical RAG index not found"):
21
+ loader.load_index(tmp_path / "nope.pkl")
22
+
23
+ def test_unique_sources(self, tmp_path: Path) -> None:
24
+ idx_path = build_tiny(tmp_path / "tiny.pkl")
25
+ payload = loader.load_index(idx_path)
26
+ sources = {c.source for c in payload["chunks"]}
27
+ assert sources == {
28
+ "alzheimers_lifestyle.pdf", "parkinsons_motor.pdf",
29
+ "alzheimers_mci.pdf", "parkinsons_nutrition.pdf",
30
+ }