feat(rag): clinical TF-IDF index loader with __main__.Chunk routing
Browse files
src/rag/clinical/__init__.py
ADDED
|
File without changes
|
src/rag/clinical/loader.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load (or rebuild) the TF-IDF clinical RAG index.
|
| 2 |
+
|
| 3 |
+
The user's rag.py builds the index from `__main__.Chunk` (a frozen
|
| 4 |
+
dataclass). Picking up that pickle from a different module path requires
|
| 5 |
+
a custom Unpickler that re-routes the class — see _ChunkRoutingUnpickler.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import pickle
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from src.core.logger import get_logger
|
| 14 |
+
from src.rag.clinical.types import ClinicalChunk
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _ChunkRoutingUnpickler(pickle.Unpickler):
|
| 20 |
+
"""Pickle's `find_class` hook lets us swap `__main__.Chunk` (and
|
| 21 |
+
`rag.Chunk` if the user later runs the builder as a module) for our
|
| 22 |
+
`ClinicalChunk` — both are frozen dataclasses with the same fields,
|
| 23 |
+
so the swap is structurally safe.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def find_class(self, module: str, name: str):
|
| 27 |
+
if name == "Chunk" and module in {"__main__", "rag", "rag.rag"}:
|
| 28 |
+
return ClinicalChunk
|
| 29 |
+
return super().find_class(module, name)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_index(path: Path) -> dict[str, Any]:
|
| 33 |
+
"""Unpickle a TF-IDF index produced by the user's rag.py."""
|
| 34 |
+
path = Path(path)
|
| 35 |
+
if not path.exists():
|
| 36 |
+
raise FileNotFoundError(f"clinical RAG index not found: {path}")
|
| 37 |
+
with path.open("rb") as f:
|
| 38 |
+
payload = _ChunkRoutingUnpickler(f).load()
|
| 39 |
+
expected = {"chunks", "vectorizer", "matrix"}
|
| 40 |
+
if not expected <= set(payload):
|
| 41 |
+
raise ValueError(
|
| 42 |
+
f"clinical RAG index missing expected keys: have {sorted(payload)}, need {sorted(expected)}"
|
| 43 |
+
)
|
| 44 |
+
logger.info("loaded clinical RAG index: %d chunks from %s", len(payload["chunks"]), path)
|
| 45 |
+
return payload
|
src/rag/clinical/types.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Types shared across clinical-RAG modules."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
|
| 10 |
+
class ClinicalChunk:
|
| 11 |
+
"""Mirrors the Chunk dataclass produced by the user's rag.py builder."""
|
| 12 |
+
chunk_id: int
|
| 13 |
+
source: str
|
| 14 |
+
page_start: int
|
| 15 |
+
page_end: int
|
| 16 |
+
text: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ClinicalEvidence(BaseModel):
|
| 20 |
+
sentence: str
|
| 21 |
+
source: str
|
| 22 |
+
page_start: int
|
| 23 |
+
page_end: int
|
| 24 |
+
score: float = Field(..., ge=0.0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ClinicalRetrievalResult(BaseModel):
|
| 28 |
+
query: str
|
| 29 |
+
evidence: list[ClinicalEvidence]
|
| 30 |
+
summary_text: str = Field(..., description="Pre-formatted RAG feedback for the agent")
|
tests/fixtures/build_tiny_clinical_index.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build a synthetic TF-IDF clinical-RAG index for tests.
|
| 2 |
+
|
| 3 |
+
Avoids needing real PDFs. Constructs the same payload schema the user's
|
| 4 |
+
rag.py produces so the loader can be tested independently of pypdf.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import pickle
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 13 |
+
|
| 14 |
+
from src.rag.clinical.types import ClinicalChunk
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def build(path: Path) -> Path:
|
| 18 |
+
"""Save a tiny TF-IDF index at `path`."""
|
| 19 |
+
path = Path(path)
|
| 20 |
+
if path.exists():
|
| 21 |
+
return path
|
| 22 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
chunks = [
|
| 25 |
+
ClinicalChunk(0, "alzheimers_lifestyle.pdf", 1, 1,
|
| 26 |
+
"Aerobic exercise and Mediterranean diet are associated with reduced cognitive decline in older adults at risk for Alzheimer's disease."),
|
| 27 |
+
ClinicalChunk(1, "parkinsons_motor.pdf", 1, 1,
|
| 28 |
+
"Levodopa remains the most effective symptomatic treatment for motor symptoms of Parkinson's disease."),
|
| 29 |
+
ClinicalChunk(2, "alzheimers_mci.pdf", 2, 2,
|
| 30 |
+
"Mild cognitive impairment may progress to dementia; MMSE and MoCA are standard screening tools."),
|
| 31 |
+
ClinicalChunk(3, "parkinsons_nutrition.pdf", 1, 1,
|
| 32 |
+
"Dietary patterns rich in antioxidants and omega-3 fatty acids are linked to lower Parkinson's risk."),
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
vectorizer = TfidfVectorizer(lowercase=True, ngram_range=(1, 2), min_df=1, norm="l2")
|
| 36 |
+
matrix = vectorizer.fit_transform([c.text for c in chunks])
|
| 37 |
+
|
| 38 |
+
payload = {
|
| 39 |
+
"created_at": datetime.now().isoformat(timespec="seconds"),
|
| 40 |
+
"source_dir": str(path.parent),
|
| 41 |
+
"chunk_words": 220,
|
| 42 |
+
"overlap_words": 45,
|
| 43 |
+
"chunks": chunks,
|
| 44 |
+
"vectorizer": vectorizer,
|
| 45 |
+
"matrix": matrix,
|
| 46 |
+
}
|
| 47 |
+
with path.open("wb") as f:
|
| 48 |
+
pickle.dump(payload, f)
|
| 49 |
+
return path
|
tests/rag/test_clinical_loader.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.rag.clinical.loader."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.rag.clinical import loader
|
| 9 |
+
from tests.fixtures.build_tiny_clinical_index import build as build_tiny
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestLoadIndex:
|
| 13 |
+
def test_load_returns_payload_with_expected_keys(self, tmp_path: Path) -> None:
|
| 14 |
+
idx_path = build_tiny(tmp_path / "tiny.pkl")
|
| 15 |
+
payload = loader.load_index(idx_path)
|
| 16 |
+
assert {"chunks", "vectorizer", "matrix"} <= set(payload)
|
| 17 |
+
assert len(payload["chunks"]) == 4
|
| 18 |
+
|
| 19 |
+
def test_missing_index_raises(self, tmp_path: Path) -> None:
|
| 20 |
+
with pytest.raises(FileNotFoundError, match="clinical RAG index not found"):
|
| 21 |
+
loader.load_index(tmp_path / "nope.pkl")
|
| 22 |
+
|
| 23 |
+
def test_unique_sources(self, tmp_path: Path) -> None:
|
| 24 |
+
idx_path = build_tiny(tmp_path / "tiny.pkl")
|
| 25 |
+
payload = loader.load_index(idx_path)
|
| 26 |
+
sources = {c.source for c in payload["chunks"]}
|
| 27 |
+
assert sources == {
|
| 28 |
+
"alzheimers_lifestyle.pdf", "parkinsons_motor.pdf",
|
| 29 |
+
"alzheimers_mci.pdf", "parkinsons_nutrition.pdf",
|
| 30 |
+
}
|