| """Load (or rebuild) the TF-IDF clinical RAG index. |
| |
| The user's rag.py builds the index from `__main__.Chunk` (a frozen |
| dataclass). Picking up that pickle from a different module path requires |
| a custom Unpickler that re-routes the class — see _ChunkRoutingUnpickler. |
| """ |
| from __future__ import annotations |
|
|
| import pickle |
| from pathlib import Path |
| from typing import Any |
|
|
| from src.core.logger import get_logger |
| from src.rag.clinical.types import ClinicalChunk |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class _ChunkRoutingUnpickler(pickle.Unpickler):
    """Unpickler that resolves the index builder's `Chunk` to `ClinicalChunk`.

    The pickle stores the class under the module path the builder ran as
    (`__main__` for a script run, `rag` / `rag.rag` if it was run as a
    module). Those paths do not lead to the class from here, so
    `find_class` hands back `ClinicalChunk` instead. The substitution
    relies on both classes being frozen dataclasses with the same fields.
    """

    # Module paths under which the builder's `Chunk` may have been pickled.
    _BUILDER_MODULES = frozenset({"__main__", "rag", "rag.rag"})

    def find_class(self, module: str, name: str):
        """Return the object pickle should use for `module.name`.

        A `Chunk` recorded under one of the builder's module paths maps to
        `ClinicalChunk`; every other lookup is delegated unchanged to
        `pickle.Unpickler.find_class`.
        """
        is_builder_chunk = name == "Chunk" and module in self._BUILDER_MODULES
        if not is_builder_chunk:
            return super().find_class(module, name)
        return ClinicalChunk
|
|
|
|
def load_index(path: Path) -> dict[str, Any]:
    """Unpickle a TF-IDF index produced by the user's rag.py.

    Args:
        path: Location of the pickled index (anything `Path()` accepts).

    Returns:
        The unpickled dict, which contains at least the keys "chunks",
        "vectorizer" and "matrix".

    Raises:
        FileNotFoundError: If `path` does not exist.
        ValueError: If the pickle does not hold a dict, or the dict lacks
            one of the expected keys.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"clinical RAG index not found: {path}")

    # SECURITY: unpickling can execute arbitrary code. The routing unpickler
    # only redirects the Chunk class; it does not restrict which other
    # globals may be loaded. Only point this at index files you built
    # yourself or otherwise trust.
    with path.open("rb") as f:
        payload = _ChunkRoutingUnpickler(f).load()

    # Reject non-dict payloads up front: otherwise they surface as an obscure
    # TypeError from the key check or from the indexing below.
    if not isinstance(payload, dict):
        raise ValueError(
            f"clinical RAG index must be a dict, got {type(payload).__name__}"
        )

    expected = {"chunks", "vectorizer", "matrix"}
    if not expected <= payload.keys():
        # Stringify keys before sorting so stray non-string keys cannot make
        # the error message itself fail; output is unchanged for str keys.
        raise ValueError(
            f"clinical RAG index missing expected keys: have {sorted(map(str, payload))}, need {sorted(expected)}"
        )

    logger.info("loaded clinical RAG index: %d chunks from %s", len(payload["chunks"]), path)
    return payload
|
|