# hackathon / src/rag/clinical/loader.py
# Author: mekosotto
# Commit 6b2c154: feat(rag): clinical TF-IDF index loader with __main__.Chunk routing
"""Load (or rebuild) the TF-IDF clinical RAG index.

The user's rag.py builds the index from `__main__.Chunk` (a frozen
dataclass), so the pickle records that class under the `__main__` module
path. Unpickling it from any other entry point cannot resolve that path,
which is why loading goes through a custom Unpickler that re-routes the
class — see _ChunkRoutingUnpickler.
"""
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any
from src.core.logger import get_logger
from src.rag.clinical.types import ClinicalChunk
# Module logger from the project's get_logger helper, named after this module.
logger = get_logger(__name__)
class _ChunkRoutingUnpickler(pickle.Unpickler):
"""Pickle's `find_class` hook lets us swap `__main__.Chunk` (and
`rag.Chunk` if the user later runs the builder as a module) for our
`ClinicalChunk` — both are frozen dataclasses with the same fields,
so the swap is structurally safe.
"""
def find_class(self, module: str, name: str):
if name == "Chunk" and module in {"__main__", "rag", "rag.rag"}:
return ClinicalChunk
return super().find_class(module, name)
def load_index(path: Path) -> dict[str, Any]:
    """Unpickle a TF-IDF index produced by the user's rag.py.

    Args:
        path: Location of the pickled index. Anything ``pathlib.Path``
            accepts (e.g. a ``str``) works.

    Returns:
        The unpickled payload dict. It is guaranteed to contain at least the
        ``"chunks"``, ``"vectorizer"`` and ``"matrix"`` keys. Any extra keys
        are passed through untouched.

    Raises:
        FileNotFoundError: if nothing exists at ``path``.
        ValueError: if the payload is not a dict or lacks a required key.

    Warning:
        Unpickling can execute arbitrary code. Only point this at index files
        you produced yourself, never at files from an untrusted source.
    """
    path = Path(path)
    expected = {"chunks", "vectorizer", "matrix"}
    # EAFP: opening directly avoids the exists()/open() race of a pre-check
    # while still reporting a missing index with a descriptive message.
    try:
        f = path.open("rb")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"clinical RAG index not found: {path}") from err
    with f:
        payload = _ChunkRoutingUnpickler(f).load()
    # Check the type first: a pickled list/tuple holding the right strings
    # would otherwise pass the key check and fail later on payload["chunks"].
    if not isinstance(payload, dict):
        raise ValueError(
            f"clinical RAG index must be a dict, got {type(payload).__name__}"
        )
    if not expected.issubset(payload):
        # key=str keeps the message renderable even when keys have mixed types
        # (plain sorted() would raise TypeError and mask this ValueError).
        raise ValueError(
            f"clinical RAG index missing expected keys: have {sorted(payload, key=str)}, need {sorted(expected)}"
        )
    logger.info("loaded clinical RAG index: %d chunks from %s", len(payload["chunks"]), path)
    return payload