docs(plan): external assets integration (MRI DL 2D + TF-IDF RAG + OASIS tabular)
Roadmap + three sub-plans for integrating user-supplied external assets:
1. MRI DL 2D — pretrained resnet18 4-class Alzheimer's classifier from
user's training run (BEST_PARAMS: image_size=160, lr=3.75e-4,
weight_decay=1.96e-4, dropout=0.31). Adds src/models/mri_dl_2d.py
parallel to volumetric ONNX path, dispatched via MRI_MODEL_KIND env.
2. TF-IDF clinical RAG — 14 medical PDFs (Alzheimer/Parkinson/lifestyle)
with Turkish+English query expansion. Wraps user's pre-built sklearn
TF-IDF index as src/rag/clinical/. Existing FAISS RAG kept.
3. OASIS tabular classifier — sklearn RF on OASIS longitudinal biomarkers
(MMSE/eTIV/nWBV/ASF/...). NOTE: user described notebook as 'EEG model'
but it is OASIS tabular. Plan flags this prominently with branch 3a
(default: integrate as fusion modality) vs 3b (await real EEG model).
All three plans flag prerequisite blockers (artifact transfer, dataset
acquisition) and preserve independence guarantees from the clinical
platform roadmap. Each ends with subagent-driven-development handoff.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -0,0 +1,98 @@
# External Assets Integration — Roadmap

> **For agentic workers:** Index of three sub-plans for integrating the user's external assets. Each sub-plan executes via `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans`.

**Vision.** The user has supplied three external assets that should replace or extend our current placeholders:

| Asset | What it is | Replaces / extends |
|---|---|---|
| Pretrained MRI 2D classifier | PyTorch resnet18 trained on Kaggle's 4-class Alzheimer's MRI dataset (`MildDemented` / `ModerateDemented` / `NonDemented` / `VeryMildDemented`) | The dummy ONNX model in `tests/fixtures/build_dummy_mri_onnx.py`; the placeholder behaviour in `src/models/mri_model.py` |
| TF-IDF RAG corpus | 14 medical PDFs (Alzheimer + Parkinson + lifestyle/nutrition/exercise) with a pre-built TF-IDF index and Turkish query expansion | The existing FAISS+fastembed RAG in `src/rag/` (or runs alongside it) |
| OASIS tabular classifier (ipynb) | sklearn ensemble on OASIS longitudinal biomarkers (MMSE, eTIV, nWBV, ASF, …) | **Not an EEG model** — see sub-plan #3 for two routing options |

---

## Sub-projects

| # | Sub-plan file | Owner concern | Depends on | Demo on its own? |
|---|---|---|---|---|
| 1 | `2026-05-02-mri-dl-2d-integration.md` | Real MRI deep-learning model in production path | — (parallel to fusion) | yes (Streamlit + curl) |
| 2 | `2026-05-02-tfidf-rag-integration.md` | Lifestyle / clinical-paper RAG with Turkish support | — | yes (CLI + agent tool) |
| 3 | `2026-05-02-oasis-tabular-fusion-integration.md` | Tabular OASIS classifier as a fusion-engine feature **OR** wait for a real EEG model | fusion engine (#1 of clinical-platform-roadmap) | yes (POST /fusion/predict) |

---

## Build order

```
┌────────────────────────────┐        ┌────────────────────────────┐
│ #1 MRI DL 2D integration   │        │ #2 TF-IDF RAG integration  │
│ (independent)              │        │ (independent)              │
└─────────────┬──────────────┘        └────────────┬───────────────┘
              │                                    │
              └──────────────┬─────────────────────┘
                             │
                     ┌──────▼─────────┐
                     │ #3 OASIS       │
                     │ classifier as  │
                     │ fusion feature │
                     └────────────────┘
```

#1 and #2 can be built in parallel (different files). #3 should follow once both are stable so the demo flows end-to-end.

---

## Open prerequisites (user must resolve)

These are **not** dev gaps — they are inputs we need from outside this codebase. Each sub-plan calls them out explicitly in its preamble, but they are listed here as well so they live in one place.

### A. MRI checkpoint file is not on this machine

The user said the artifact lives at `outputs\checkpoints\best_model.pt` (Windows-style path). `find /Users/mertgungor` returns no `best_model.pt`. Sub-plan #1 cannot start until the file is at `data/processed/mri_dl_2d/best_model.pt` (gitignored — never commit a model binary). Confirm that the class index order matches the trainer:

```python
CLASS_TO_IDX = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
```

If the trainer used a different ordering (`ImageFolder` alphabetises by default), the labels we surface will be wrong. Sub-plan #1 ships a sanity test that catches this.
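A side note on that risk: torchvision's `ImageFolder` derives `class_to_idx` by sorting the class-directory names, and for these four names the alphabetical order happens to coincide with the trainer's mapping. A quick pure-Python illustration (not the shipped sanity test):

```python
# ImageFolder assigns indices by sorted() over class-directory names.
# For this dataset the alphabetical order matches the trainer's CLASS_TO_IDX,
# so even a default-alphabetised trainer would have produced this mapping.
class_dirs = ["NonDemented", "MildDemented", "VeryMildDemented", "ModerateDemented"]
imagefolder_mapping = {name: idx for idx, name in enumerate(sorted(class_dirs))}
assert imagefolder_mapping == {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
```

The sanity test still earns its keep: it catches a trainer that remapped classes by hand, or a future dataset with differently-named directories.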

### B. The "EEG ipynb" is OASIS tabular, not EEG

`/Users/mertgungor/Downloads/rag/detecting-early-alzheimer-s (1).ipynb` trains an sklearn ensemble (LogReg / SVM / DT / RF / AdaBoost) on the OASIS longitudinal MRI **tabular** dataset (`oasis_longitudinal.csv` — MMSE, eTIV, nWBV, ASF, EDUC, SES, …). It contains zero EEG signal processing and saves no model artifact.

Sub-plan #3 has **two branches**:

- **Branch 3a (default).** Treat the OASIS biomarker model as a clinical-tests extension to the fusion engine (which already accepts MMSE etc. as features — this just adds eTIV/nWBV/ASF and re-runs the trained sklearn model in-process).
- **Branch 3b.** If the user has a real EEG model elsewhere (a checkpoint file that consumes raw FIF / EDF data and emits class probabilities), the user must point us to it and we re-scope sub-plan #3 around that artifact.

The user must pick the branch before sub-plan #3 starts.
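To make branch 3a concrete, here is a toy sketch of the mechanic it relies on: a fitted sklearn estimator scoring a single OASIS-style biomarker row in-process. The data, feature subset, and model settings below are illustrative, not the notebook's.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

FEATURES = ["MMSE", "eTIV", "nWBV", "ASF"]  # assumed feature subset, illustrative only

# Toy training data standing in for oasis_longitudinal.csv rows.
rng = np.random.RandomState(0)
X = rng.rand(100, len(FEATURES))
y = rng.randint(0, 2, size=100)  # toy labels: 0 = nondemented, 1 = demented

clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# One incoming fusion-request row, keyed by biomarker name.
row = {"MMSE": 0.9, "eTIV": 0.4, "nWBV": 0.7, "ASF": 0.5}
proba = clf.predict_proba(np.array([[row[f] for f in FEATURES]]))[0]
assert proba.shape == (2,) and abs(proba.sum() - 1.0) < 1e-9
```

The real sub-plan would unpickle the notebook's fitted estimator (once it saves one) rather than retrain, and the probability vector would feed the fusion engine like any other modality score.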

### C. RAG corpus location

The new RAG lives at `/Users/mertgungor/Downloads/rag/`. It must be copied into the repo at `data/external_rag/` (or a symlink — but symlinks break in Docker). The pre-built `index/rag_index.pkl` is 12.9 MB — gitignore the binary and commit only the source PDFs (or, for hackathon speed, gitignore both and document the manual copy step). Sub-plan #2 commits the wrapper code and a small fixture copy of one PDF for tests; the full corpus stays out of git.
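For orientation, a minimal sketch of what the wrapper around the pre-built index has to do: vectorise the query and rank documents by cosine similarity. The toy corpus and vectoriser settings are assumptions; the real wrapper loads the user's `index/rag_index.pkl` and adds Turkish query expansion on top.

```python
# Minimal TF-IDF retrieval sketch over a toy corpus (illustrative settings).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

docs = [
    "Mediterranean diet and exercise may slow Alzheimer's progression.",
    "Parkinson's disease motor symptoms respond to levodopa.",
    "Sleep hygiene supports cognitive health in older adults.",
]
vec = TfidfVectorizer(stop_words="english")
matrix = vec.fit_transform(docs)

def retrieve(query: str, k: int = 2) -> list[str]:
    """Return the k corpus documents most similar to the query."""
    sims = cosine_similarity(vec.transform([query]), matrix)[0]
    return [docs[i] for i in sims.argsort()[::-1][:k]]

top = retrieve("diet for Alzheimer's")
assert top[0].startswith("Mediterranean")
```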

---

## Decoupling guarantees (carry forward from clinical-platform-roadmap.md)

Independence rules from the existing roadmap apply unchanged. Specifically:

- **MRI DL 2D model is a swap-in for the existing 3D ONNX path.** The `src/models/mri_model.py` API surface stays stable; the new module sits alongside as `src/models/mri_dl_2d.py` and is selected via env var (`MRI_MODEL_KIND=resnet18_2d` vs. `volumetric_onnx`). Pipelines that don't load it must continue to work.
- **TF-IDF RAG and FAISS RAG run side-by-side.** The agent tool `retrieve_context` is widened to accept a `corpus` parameter (`"clinical"` for the new TF-IDF, `"reference"` for the existing FAISS). Existing tests stay green.
- **BBB stays decoupled.** Same rule as in the fusion plan: no sub-plan here introduces a BBB↔MRI hard dependency.
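A sketch of the widened tool signature this implies, with the backends stubbed (the real ones would delegate to the FAISS index in `src/rag/` and the TF-IDF wrapper in `src/rag/clinical/`):

```python
# Hypothetical dispatch shape for the widened retrieve_context tool.
from typing import Callable

def _clinical_search(query: str) -> list[str]:   # stub for the TF-IDF corpus
    return [f"[clinical] {query}"]

def _reference_search(query: str) -> list[str]:  # stub for the FAISS corpus
    return [f"[reference] {query}"]

_BACKENDS: dict[str, Callable[[str], list[str]]] = {
    "clinical": _clinical_search,
    "reference": _reference_search,
}

def retrieve_context(query: str, corpus: str = "reference") -> list[str]:
    """Route a retrieval query to the chosen corpus; default keeps old behaviour."""
    if corpus not in _BACKENDS:
        raise ValueError(f"unknown corpus {corpus!r}; expected one of {sorted(_BACKENDS)}")
    return _BACKENDS[corpus](query)
```

Defaulting `corpus` to `"reference"` is what keeps the existing tests green: callers that never pass the new parameter hit the FAISS path exactly as before.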

---

## "When am I done?" gates (apply to every sub-plan)

1. All TDD tasks committed.
2. Full test suite passes (current baseline: 295 passed, 1 skipped).
3. Feature reachable end-to-end: the Streamlit UI **OR** a curl against `/predict/mri` / `/agent/run` / `/fusion/predict` works.
4. README has a one-paragraph note describing the new asset and how to swap it out.
5. Final code-reviewer subagent verdict: "Ready to merge".

@@ -0,0 +1,625 @@
# MRI DL 2D Integration Plan

> **For agentic workers:** REQUIRED SUB-SKILL: `superpowers:subagent-driven-development` (recommended). TDD throughout — failing test → minimal impl → passing test → commit.

**Goal.** Wire the user's pretrained PyTorch resnet18 (2D image, 4-class Alzheimer's: `MildDemented` / `ModerateDemented` / `NonDemented` / `VeryMildDemented`) into the production decision layer alongside the existing volumetric ONNX path. The model produces a probability vector that flows naturally into the fusion engine.

**Architecture.** Add `src/models/mri_dl_2d.py` parallel to the existing `src/models/mri_model.py`. A small selector picks between the paths based on the `MRI_MODEL_KIND` env var (`resnet18_2d` or `volumetric_onnx`). The 2D model loads a `state_dict` `.pt` checkpoint, applies the resnet18 preprocessing contract (resize to 160, ImageNet normalisation), and emits `MRIPredictResponse` in the same shape the existing surface produces — so the API and frontend need no behavioural change.

**Tech stack.** Python 3.11, PyTorch (CPU), torchvision, Pillow. PyTorch is not currently in `requirements.txt` — Task 0 adds it. No new web dependencies.

**Trainer's hyper-parameters (for reference; not all are relevant at inference):**

```python
BEST_PARAMS = {
    "image_size": 160,
    "model_name": "resnet18",
    "optimizer": "adamw",  # only relevant for training
    "lr": 0.000375191537539265,  # only relevant for training
    "weight_decay": 0.000196410142442417,
    "dropout": 0.31154239434523634,  # applied at inference iff the trainer baked dropout into the head
    "batch_size": 128,  # not used at inference (we infer one image at a time)
    "epochs": 10,  # training-only
}

CLASS_TO_IDX = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
```

---

## Prerequisite (controller blocker)

The artifact `best_model.pt` is **not** present on this filesystem. Before any task starts:

1. Copy the file from the trainer machine to `data/processed/mri_dl_2d/best_model.pt`.
2. Confirm with `python -c "import torch; sd = torch.load('data/processed/mri_dl_2d/best_model.pt', map_location='cpu'); print(type(sd), list(sd.keys())[:5] if isinstance(sd, dict) else sd)"`. Two possible structures:
   - **`state_dict` only** (most common): `dict[str, Tensor]`. Task 1 builds the resnet18 architecture and `load_state_dict`s it.
   - **Full model** (`torch.save(model, ...)`): a pickled `nn.Module`. Task 1 just calls `torch.load(...)`.
   - The plan defaults to **state_dict** (more portable). If the file turns out to be a full model, Task 1 has a fallback branch.
3. Add the artifact path to `.gitignore` if it isn't already covered (`data/processed/` should already be ignored — verify).

If step 2 fails, **stop and surface to the user** — the trainer either produced a different artifact or saved with an unexpected structure.
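The step-2 probe can also be phrased as a tiny standalone helper (a sketch; it mirrors the branching the later `load()` implementation performs):

```python
# Sketch: classify a .pt artifact so the right Task 1 branch is chosen.
from pathlib import Path

import torch


def checkpoint_kind(path: Path) -> str:
    """Return 'state_dict' or 'full_model' for a .pt checkpoint."""
    obj = torch.load(str(path), map_location="cpu", weights_only=False)
    if isinstance(obj, dict):
        return "state_dict"  # dict[str, Tensor] (Task 1's default branch)
    return "full_model"      # pickled nn.Module (Task 1's fallback branch)
```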

---

## File structure

| Path | Responsibility |
|---|---|
| Modify `requirements.txt` | add `torch`, `torchvision`, `pillow` (CPU wheels are fine) |
| Create `src/models/mri_dl_2d.py` | resnet18 4-class loader + preprocessing + `predict_image()` |
| Create `src/models/mri_selector.py` | tiny dispatcher: `load_default()` / `predict_default()` based on the `MRI_MODEL_KIND` env var |
| Modify `src/api/routes.py` | `predict_mri` chooses between volumetric and 2D paths via the selector |
| Modify `src/api/schemas.py` | `MRIPredictRequest.input_path` now accepts `.png/.jpg/.nii*` (already a string) — no schema change beyond a docstring tweak |
| Create `tests/fixtures/build_dummy_resnet18_2d.py` | helper that constructs a randomly-initialised 4-class resnet18 and saves a state_dict to a tmp path so tests don't need the real artifact |
| Create `tests/models/test_mri_dl_2d.py` | unit tests for the new module |
| Create `tests/api/test_mri_2d_route.py` | integration test through `POST /predict/mri` |
| Modify `README.md` | document the env var and the artifact location |

---

## Tasks

### Task 0: Dependencies

**Files:**
- Modify: `requirements.txt`

- [ ] **Step 1:** open `requirements.txt` and append (CPU wheels — torch is large, expect ~200 MB):

```
torch>=2.2,<3.0
torchvision>=0.17,<1.0
pillow>=10.0,<12.0
```

- [ ] **Step 2:** install: `pip install torch torchvision pillow`. Verify the import: `python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__)"`. Expect a version line, no error.

- [ ] **Step 3:** run `pytest -q` — expect the existing 295+1 baseline. No regressions before any code changes.

- [ ] **Step 4:** commit: `git commit -m "deps: add torch/torchvision/pillow for MRI DL 2D"`.

---

### Task 1: 2D model loader + preprocessing

**Files:**
- Create: `src/models/mri_dl_2d.py`
- Create: `tests/fixtures/build_dummy_resnet18_2d.py`
- Create: `tests/models/test_mri_dl_2d.py`

- [ ] **Step 1: Write the dummy-checkpoint fixture (so tests don't need the real .pt).**

`tests/fixtures/build_dummy_resnet18_2d.py`:

```python
"""Build a randomly-initialised 4-class resnet18 state_dict for tests."""
from __future__ import annotations

from pathlib import Path

import torch
from torchvision import models


def build(path: Path) -> Path:
    """Save a state_dict at `path` and return the path. Idempotent."""
    path = Path(path)
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 4)
    torch.save(model.state_dict(), str(path))
    return path
```

- [ ] **Step 2: Write the failing test.**

`tests/models/test_mri_dl_2d.py`:

```python
"""Tests for src.models.mri_dl_2d — pretrained 4-class Alzheimer's resnet18."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from src.models import mri_dl_2d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d


def _png(path: Path, size: tuple[int, int] = (200, 200)) -> Path:
    arr = (np.random.RandomState(0).rand(size[1], size[0], 3) * 255).astype(np.uint8)
    Image.fromarray(arr, mode="RGB").save(str(path))
    return path


class TestMRIDL2D:
    def test_class_to_idx_matches_trainer(self) -> None:
        assert mri_dl_2d.CLASS_TO_IDX == {
            "MildDemented": 0,
            "ModerateDemented": 1,
            "NonDemented": 2,
            "VeryMildDemented": 3,
        }

    def test_idx_to_class_is_consistent(self) -> None:
        for name, idx in mri_dl_2d.CLASS_TO_IDX.items():
            assert mri_dl_2d.IDX_TO_CLASS[idx] == name

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        with pytest.raises(FileNotFoundError, match="MRI 2D checkpoint not found"):
            mri_dl_2d.load(tmp_path / "nope.pt")

    def test_predict_image_returns_full_probs(self, tmp_path: Path) -> None:
        ckpt = build_dummy_2d(tmp_path / "best.pt")
        model = mri_dl_2d.load(ckpt)
        img = _png(tmp_path / "scan.png")

        result = mri_dl_2d.predict_image(model, img)

        assert set(result) == {"label", "label_text", "confidence", "probabilities"}
        assert result["label"] in {0, 1, 2, 3}
        assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX
        assert 0.0 <= result["confidence"] <= 1.0
        probs = result["probabilities"]
        assert len(probs) == 4
        assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5
        # Each probability item exposes the trainer's class label, not "class_N".
        assert {p["label_text"] for p in probs} == set(mri_dl_2d.CLASS_TO_IDX)

    def test_predict_works_for_grayscale_input(self, tmp_path: Path) -> None:
        ckpt = build_dummy_2d(tmp_path / "best.pt")
        model = mri_dl_2d.load(ckpt)
        # Single-channel grayscale, common for MRI slice exports.
        gray = (np.random.RandomState(1).rand(180, 180) * 255).astype(np.uint8)
        path = tmp_path / "gray.png"
        Image.fromarray(gray, mode="L").save(str(path))

        result = mri_dl_2d.predict_image(model, path)
        assert 0.0 <= result["confidence"] <= 1.0
```

Run: `pytest tests/models/test_mri_dl_2d.py -v` → expect ImportError on `src.models.mri_dl_2d`.

- [ ] **Step 3: Minimal implementation.**

`src/models/mri_dl_2d.py`:

```python
"""Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes).

Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads
either a state_dict (default) or a full pickled model, applies the trainer's
preprocessing (resize image_size=160, ImageNet normalisation), and emits the
same dict shape as src.models.mri_model.predict_with_proba so downstream
code paths don't care which backend produced the prediction.
"""
from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

from src.core.logger import get_logger

logger = get_logger(__name__)

CLASS_TO_IDX: dict[str, int] = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
IDX_TO_CLASS: dict[int, str] = {v: k for k, v in CLASS_TO_IDX.items()}

DEFAULT_IMAGE_SIZE = 160
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# torchvision transform reused for every prediction. Constructed once at
# import time — no per-call allocation.
_TRANSFORM = transforms.Compose([
    transforms.Resize((DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def _build_resnet18_4class() -> nn.Module:
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_TO_IDX))
    return model


def load(path: Path) -> nn.Module:
    """Load checkpoint. Supports state_dict (preferred) or full pickled model."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")
    obj = torch.load(str(path), map_location="cpu", weights_only=False)
    if isinstance(obj, nn.Module):
        model = obj
    else:
        model = _build_resnet18_4class()
        # Strip 'module.' prefix if the trainer used DataParallel / DDP.
        clean = {k.removeprefix("module."): v for k, v in obj.items()}
        model.load_state_dict(clean, strict=True)
    model.eval()
    return model


def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
    """Run inference on one image. Output shape mirrors mri_model.predict_with_proba."""
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"MRI image not found: {image_path}")
    img = Image.open(str(image_path)).convert("RGB")
    tensor = _TRANSFORM(img).unsqueeze(0)  # (1, 3, 160, 160)

    with torch.inference_mode():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    label_idx = int(np.argmax(probs))
    return {
        "label": label_idx,
        "label_text": IDX_TO_CLASS[label_idx],
        "confidence": float(probs[label_idx]),
        "probabilities": [
            {"label": i, "label_text": IDX_TO_CLASS[i], "probability": float(p)}
            for i, p in enumerate(probs)
        ],
    }
```

Run: `pytest tests/models/test_mri_dl_2d.py -v` → expect 5 passed.

- [ ] **Step 4:** `pytest -q` → expect 295+1 baseline + 5 new = ~300 passed.

- [ ] **Step 5:** commit:

```bash
git add src/models/mri_dl_2d.py tests/fixtures/build_dummy_resnet18_2d.py tests/models/test_mri_dl_2d.py
git commit -m "feat(models): add 2D resnet18 4-class Alzheimer's MRI inference module"
```

---

### Task 2: Selector for 3D vs 2D

**Files:**
- Create: `src/models/mri_selector.py`
- Create: `tests/models/test_mri_selector.py`

- [ ] **Step 1: Failing test.**

`tests/models/test_mri_selector.py`:

```python
"""Tests for src.models.mri_selector — env-var-driven 2D / 3D dispatch."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from src.models import mri_selector
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_3d
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d


_FIXTURE_MRI = Path(__file__).resolve().parents[1] / "fixtures" / "mri_sample" / "subject_0.nii.gz"


class TestSelector:
    def test_default_kind_is_volumetric(self, monkeypatch) -> None:
        monkeypatch.delenv("MRI_MODEL_KIND", raising=False)
        assert mri_selector.current_kind() == "volumetric_onnx"

    def test_explicit_2d_selection(self, monkeypatch) -> None:
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        assert mri_selector.current_kind() == "resnet18_2d"

    def test_unknown_kind_raises(self, monkeypatch) -> None:
        monkeypatch.setenv("MRI_MODEL_KIND", "neural_net_supreme")
        with pytest.raises(ValueError, match="unknown MRI_MODEL_KIND"):
            mri_selector.current_kind()

    def test_predict_routes_to_volumetric(self, monkeypatch, tmp_path) -> None:
        monkeypatch.setenv("MRI_MODEL_KIND", "volumetric_onnx")
        artifact = build_dummy_3d(tmp_path / "vol.onnx")
        result = mri_selector.predict(
            input_path=_FIXTURE_MRI,
            checkpoint_path=artifact,
            target_shape=(8, 8, 8),
            label_names=("control", "abnormal"),
        )
        assert result["label_text"] in {"control", "abnormal"}

    def test_predict_routes_to_2d(self, monkeypatch, tmp_path) -> None:
        monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
        artifact = build_dummy_2d(tmp_path / "best.pt")
        # Build a tiny PNG.
        img_path = tmp_path / "scan.png"
        Image.fromarray((np.random.RandomState(0).rand(160, 160, 3) * 255).astype("uint8")).save(str(img_path))
        result = mri_selector.predict(
            input_path=img_path,
            checkpoint_path=artifact,
        )
        assert result["label_text"] in mri_selector.label_names_for_kind("resnet18_2d")
```

Run: `pytest tests/models/test_mri_selector.py -v` → ImportError.

- [ ] **Step 2: Minimal impl.**

`src/models/mri_selector.py`:

```python
"""Env-var-driven dispatch between volumetric ONNX and 2D resnet18 MRI models."""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any

from src.core.logger import get_logger
from src.models import mri_dl_2d, mri_model

logger = get_logger(__name__)

VALID_KINDS = ("volumetric_onnx", "resnet18_2d")
_DEFAULT_KIND = "volumetric_onnx"


def current_kind() -> str:
    kind = os.environ.get("MRI_MODEL_KIND", _DEFAULT_KIND)
    if kind not in VALID_KINDS:
        raise ValueError(f"unknown MRI_MODEL_KIND={kind!r}; expected one of {VALID_KINDS}")
    return kind


def label_names_for_kind(kind: str) -> tuple[str, ...]:
    if kind == "resnet18_2d":
        return tuple(mri_dl_2d.IDX_TO_CLASS[i] for i in range(len(mri_dl_2d.CLASS_TO_IDX)))
    return mri_model.DEFAULT_LABEL_NAMES


def predict(
    input_path: Path,
    checkpoint_path: Path,
    target_shape: tuple[int, int, int] | None = None,
    label_names: tuple[str, ...] | None = None,
) -> dict[str, Any]:
    """Run the active MRI model on one input. Returns the unified prediction dict."""
    kind = current_kind()
    logger.info("dispatching MRI prediction kind=%s input=%s", kind, input_path)
    if kind == "resnet18_2d":
        model = mri_dl_2d.load(checkpoint_path)
        return mri_dl_2d.predict_image(model, input_path)
    model = mri_model.load(checkpoint_path)
    return mri_model.predict_nifti(
        model,
        input_path,
        target_shape=target_shape or mri_model.DEFAULT_TARGET_SHAPE,
        label_names=label_names,
    )
```

Run tests → 5 passed.
|
| 428 |
+
|
| 429 |
+
- [ ] **Step 3:** `pytest -q` → ~305 passed.
|
| 430 |
+
|
| 431 |
+
- [ ] **Step 4:** commit: `feat(models): selector dispatch for volumetric vs 2D MRI models`.
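`label_names_for_kind` assumes `mri_dl_2d` exposes both `CLASS_TO_IDX` and its inverse `IDX_TO_CLASS`. Deriving the inverse from the forward map (rather than hand-writing both) keeps the two from drifting apart; a standalone sketch, using the class order stated in the self-review checklist:

```python
# Assumed constants in mri_dl_2d (the index values come from the
# self-review checklist below); IDX_TO_CLASS is derived, not hand-written.
CLASS_TO_IDX = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
IDX_TO_CLASS = {idx: name for name, idx in CLASS_TO_IDX.items()}

# This is exactly what label_names_for_kind() computes for the 2D model.
label_names = tuple(IDX_TO_CLASS[i] for i in range(len(CLASS_TO_IDX)))
print(label_names)
# → ('MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented')
```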
---

### Task 3: Wire into `POST /predict/mri`

**Files:**
- Modify: `src/api/routes.py`
- Modify: `src/api/schemas.py` (docstring only — `input_path` now optionally accepts a 2D image)
- Create: `tests/api/test_mri_2d_route.py`

- [ ] **Step 1: Failing test.**

`tests/api/test_mri_2d_route.py`:

```python
"""Integration: POST /predict/mri with MRI_MODEL_KIND=resnet18_2d."""
from __future__ import annotations

import os
from pathlib import Path

import numpy as np
import pytest
from fastapi.testclient import TestClient
from PIL import Image

from src.api.main import app
from tests.fixtures.build_dummy_resnet18_2d import build as build_dummy_2d


@pytest.fixture()
def client_2d(monkeypatch, tmp_path):
    monkeypatch.setenv("MRI_MODEL_KIND", "resnet18_2d")
    ckpt = build_dummy_2d(tmp_path / "best.pt")
    monkeypatch.setenv("MRI_MODEL_PATH_2D", str(ckpt))
    return TestClient(app)


def test_predict_mri_2d_happy_path(client_2d, tmp_path):
    # Tiny RGB PNG.
    img_path = tmp_path / "scan.png"
    Image.fromarray((np.random.RandomState(0).rand(170, 170, 3) * 255).astype("uint8")).save(str(img_path))

    r = client_2d.post("/predict/mri", json={"input_path": str(img_path)})
    assert r.status_code == 200, r.text
    data = r.json()
    assert data["label_text"] in {
        "MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented",
    }
    assert 0.0 <= data["confidence"] <= 1.0
    assert len(data["probabilities"]) == 4
```

Run → expect 500 (route hardcoded to the volumetric path) or a schema error.

- [ ] **Step 2: Modify the route handler.**

In `src/api/routes.py`, find `predict_mri` (around line 318). Change the body to dispatch via the selector:

```python
@predict_router.post("/mri", response_model=MRIPredictResponse)
def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
    from src.models import mri_selector

    kind = mri_selector.current_kind()
    if kind == "resnet18_2d":
        ckpt = Path(os.environ.get("MRI_MODEL_PATH_2D", "data/processed/mri_dl_2d/best_model.pt"))
        result = mri_selector.predict(input_path=Path(req.input_path), checkpoint_path=ckpt)
        model_path = str(ckpt)
    else:
        ckpt = _mri_model_path()
        result = mri_selector.predict(
            input_path=Path(req.input_path),
            checkpoint_path=ckpt,
            target_shape=tuple(req.target_shape),
            label_names=tuple(req.label_names) if req.label_names else None,
        )
        model_path = str(ckpt)

    return MRIPredictResponse(
        **result,
        input_path=str(req.input_path),
        model_path=model_path,
    )
```

You'll need to add `import os` and `from pathlib import Path` if they are not already present at the top of the file (`Path` likely already is). Keep the existing `_mri_model_path()` helper.

- [ ] **Step 3:** Update `src/api/schemas.py`. The `MRIPredictRequest.input_path` field description currently says "Path to one .nii or .nii.gz MRI volume". Change it to:

```python
input_path: str = Field(..., description="Path to MRI input. With MRI_MODEL_KIND=volumetric_onnx (default), expects a .nii/.nii.gz volume. With MRI_MODEL_KIND=resnet18_2d, expects a 2D image (.png/.jpg).")
```

- [ ] **Step 4:** `pytest tests/api/test_mri_2d_route.py -v` → expect 1 passed.

- [ ] **Step 5:** `pytest -q` → expect no regressions vs the prior baseline + 1 new.

- [ ] **Step 6:** commit: `feat(api): dispatch /predict/mri via MRI_MODEL_KIND env var`.
---

### Task 4: Sanity check on real artifact (one-shot, runs only when artifact is present)

**Files:**
- Create: `tests/models/test_mri_dl_2d_real.py`

This test is opt-in via file presence: it runs only if `data/processed/mri_dl_2d/best_model.pt` exists. It catches the "trainer used a different class index order" bug.

- [ ] **Step 1: Test.**

```python
"""Real-artifact sanity test. Skipped unless the checkpoint is present."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from src.models import mri_dl_2d


REAL_CKPT = Path("data/processed/mri_dl_2d/best_model.pt")


@pytest.mark.skipif(not REAL_CKPT.exists(), reason="real MRI checkpoint not present")
def test_real_checkpoint_loads_and_predicts(tmp_path):
    model = mri_dl_2d.load(REAL_CKPT)
    arr = (np.random.RandomState(0).rand(170, 170, 3) * 255).astype(np.uint8)
    img = tmp_path / "scan.png"
    Image.fromarray(arr).save(str(img))
    result = mri_dl_2d.predict_image(model, img)

    assert result["label_text"] in mri_dl_2d.CLASS_TO_IDX
    # Probabilities sum to 1.
    s = sum(p["probability"] for p in result["probabilities"])
    assert abs(s - 1.0) < 1e-5
```

- [ ] **Step 2:** `pytest tests/models/test_mri_dl_2d_real.py -v` → if no real checkpoint, **skipped** (expected). Once you drop the artifact in, `pytest -q` will run it.

- [ ] **Step 3:** commit: `test(models): real-artifact sanity for MRI DL 2D (skips when absent)`.
---

### Task 5: Streamlit + README

**Files:**
- Modify: `src/frontend/app.py` (the MRI Predict tab — add an `MRI_MODEL_KIND` indicator + accept image upload when 2D is active)
- Modify: `README.md`

- [ ] **Step 1:** In `src/frontend/app.py`, find the MRI predict section (likely near line 1330 — search for `mri_predict_d`). Add a small caption above the existing UI:

```python
mri_kind = os.environ.get("MRI_MODEL_KIND", "volumetric_onnx")
st.caption(f"Active MRI model: `{mri_kind}` (set `MRI_MODEL_KIND` env to switch)")
```

If `mri_kind == "resnet18_2d"`, swap the file-picker hint from `.nii/.nii.gz` to `.png/.jpg`. The existing `target_shape` widgets are irrelevant in 2D mode — wrap them in `if mri_kind == "volumetric_onnx":`.

- [ ] **Step 2:** README update. Add a section under the existing MRI section:

```markdown
### MRI Deep-Learning Backends

The MRI prediction route supports two backends, selected via env:

- `MRI_MODEL_KIND=volumetric_onnx` (default). Loads an ONNX volumetric model from `MRI_MODEL_PATH` (default `data/processed/mri_model.onnx`). Input: `.nii` / `.nii.gz`.
- `MRI_MODEL_KIND=resnet18_2d`. Loads a PyTorch state_dict from `MRI_MODEL_PATH_2D` (default `data/processed/mri_dl_2d/best_model.pt`). Input: 2D image (`.png` / `.jpg`). Classes: `MildDemented`, `ModerateDemented`, `NonDemented`, `VeryMildDemented`.

Switch backends without restarting workers — the env var is read on each request.
```

- [ ] **Step 3:** `pytest -q` → no regressions. (Streamlit code is not unit-tested in this repo — a manual smoke test at the end is fine.)

- [ ] **Step 4:** commit: `feat(frontend): expose MRI_MODEL_KIND in MRI predict tab; doc backends`.
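The README's last claim — switching backends without a worker restart — rests on a deliberate design choice: `MRI_MODEL_KIND` is read per request, never cached at import time. A minimal standalone sketch of why that matters (illustrative only, not the repo's actual selector module):

```python
import os

VALID_KINDS = ("volumetric_onnx", "resnet18_2d")


def current_kind(default: str = "volumetric_onnx") -> str:
    # Read os.environ on every call. If this value were captured once at
    # module import, an operator could not flip backends on a live worker.
    kind = os.environ.get("MRI_MODEL_KIND", default)
    if kind not in VALID_KINDS:
        raise ValueError(f"unknown MRI_MODEL_KIND={kind!r}")
    return kind


os.environ.pop("MRI_MODEL_KIND", None)
print(current_kind())  # → volumetric_onnx (default when unset)
os.environ["MRI_MODEL_KIND"] = "resnet18_2d"
print(current_kind())  # → resnet18_2d (picked up without any "restart")
```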
---

## Self-review checklist

1. **Spec coverage.** Trainer's BEST_PARAMS that matter at inference: `image_size=160`, `model_name=resnet18`, class index map. All locked in via Task 1 constants. The other params (`lr`, `epochs`, `batch_size`) are training-only and intentionally not surfaced.
2. **Independence.** No coupling to BBB. The new module imports only stdlib + torch + torchvision + Pillow + numpy + the existing `src/core/logger`.
3. **Sanity test for class-order drift.** Task 4 runs only when the real checkpoint is dropped in. If the trainer used `ImageFolder`'s alphabetical order (`MildDemented=0, ModerateDemented=1, NonDemented=2, VeryMildDemented=3` — same as ours, by luck), it passes. If they used a different order, the user must update `CLASS_TO_IDX` in `mri_dl_2d.py`.
4. **No placeholders.** Every step contains the full code.
5. **Hackathon-grade.** No XAI / saliency-map / Grad-CAM scope creep. That's a separate sub-plan if the demo demands it.
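Item 3's "same as ours, by luck" claim is cheap to verify without torchvision: `ImageFolder` assigns class indices by sorting the class-directory names, so sorting our four labels must reproduce `CLASS_TO_IDX` exactly:

```python
# torchvision's ImageFolder sorts class directory names and enumerates
# them; replicate that here to confirm our hard-coded map matches.
CLASS_TO_IDX = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}

derived = {name: i for i, name in enumerate(sorted(CLASS_TO_IDX))}
assert derived == CLASS_TO_IDX  # alphabetical order matches — no drift
```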
---

## Execution handoff

Save the plan, then choose an execution style: subagent-driven (recommended) or inline executing-plans.
# OASIS Tabular Classifier — Fusion Integration Plan

> **For agentic workers:** REQUIRED SUB-SKILL: `superpowers:subagent-driven-development`. TDD throughout.

## ⚠️ Important context — read before executing

The user said "I have the pretrained model for eeg, integrate it into the eeg pipeline. its the ipynb file named detecting-early-alzheimers...".

The notebook (`/Users/mertgungor/Downloads/rag/detecting-early-alzheimer-s (1).ipynb`) is **NOT an EEG model**. It is an sklearn ensemble (LogReg / SVM / DT / RF / AdaBoost) trained on the OASIS longitudinal **tabular** dataset — features are MMSE, eTIV, nWBV, ASF, EDUC, SES, M/F, Age. Zero EEG signal processing. Zero saved model artifact (the notebook trains in-memory only).

This plan therefore has **two branches**. Pick one with the user before executing.

### Branch 3a — Train + integrate the OASIS *tabular* classifier as a fusion feature

We re-train the best variant (Random Forest, AUC 84.4% per the notebook) from the OASIS CSV, save a `joblib` artifact, and expose it as a fusion-engine modality named `tabular_oasis`. The fusion engine already handles arbitrary modality keys; this plugs in cleanly.

**Demo value:** When a doctor has only OASIS-style biomarkers (MMSE / eTIV / nWBV / ASF / Age / EDUC / SES / M/F) but no MRI image, the fusion engine still produces an Alzheimer's confidence with attribution.

### Branch 3b — User has a real EEG model elsewhere

If the user can point us to a checkpoint that consumes raw FIF / EDF EEG data (e.g., a `.pt`, `.pth`, `.h5`, `.onnx`, or `.joblib` file) and emits Alzheimer's class probabilities, this plan is rewritten around that artifact: signature, expected input shape, label order. We replace `src/models/eeg_model.py` (currently absent — `eeg_pipeline.py` only does signal processing) with a new module similar to `mri_dl_2d.py`.

**The user must pick a branch** before any task starts. The default below is **Branch 3a**, because the notebook is what's actually on disk.

---

## Branch 3a (default): OASIS tabular classifier as fusion modality

**Goal.** Save a Random Forest trained on OASIS biomarkers; wire it into the fusion engine as a new modality `tabular_oasis`. The doctor enters MMSE/eTIV/nWBV/ASF (fusion already takes MMSE; this extends to the other three) and gets an Alzheimer's signal that flows through the existing logit/sigmoid combiner.

**Architecture.** New module `src/models/tabular_oasis.py` trains-or-loads a `joblib`-pickled `Pipeline(scaler -> RandomForestClassifier)`. The fusion engine grows one entry in `_CLINICAL_FNS` (or, more cleanly, a sibling `_TABULAR_FNS`) so the model's class probability for `Demented=1` becomes a signed signal. New API route `POST /predict/tabular_oasis` lets the frontend call it directly. All optional — if the OASIS CSV is absent, the module degrades gracefully and fusion ignores the modality.
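What "becomes a signed signal" means is left implicit above. One plausible mapping — an assumption about the combiner's contract, not code from the repo — centres the `Demented` probability so 0.5 is neutral, 1.0 is maximal positive evidence, and 0.0 is maximal negative evidence:

```python
def oasis_signal(p_demented: float) -> float:
    """Map P(Demented) in [0, 1] to a signed signal in [-1, +1].

    0.5 -> 0.0 (uninformative), 1.0 -> +1.0, 0.0 -> -1.0. Hypothetical
    contract: the real fusion combiner may expect a different range, or a
    logit instead of a linear rescale.
    """
    if not 0.0 <= p_demented <= 1.0:
        raise ValueError(f"probability out of range: {p_demented}")
    return 2.0 * p_demented - 1.0


print(oasis_signal(0.9))  # → 0.8 (strong Demented evidence)
print(oasis_signal(0.5))  # → 0.0 (no evidence either way)
```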
**Tech stack.** scikit-learn (already in deps), pandas, joblib (likely in deps via sklearn).

---

## Prerequisite (controller blocker)

The OASIS dataset is not in this repo. Two acquisition options:

1. **Download from Kaggle** (https://www.kaggle.com/datasets/jboysen/mri-and-alzheimers, file `oasis_longitudinal.csv`). Save to `data/external/oasis_longitudinal.csv`. Gitignore it (already covered by `data/external_rag/` if you broaden that pattern; otherwise add `data/external/`).

2. **Use a local copy** if the user already downloaded it for the notebook. Same destination.

If the dataset is unavailable, **stop and surface to the user**. The classifier cannot be trained without it; we will not fabricate synthetic OASIS-shaped data for a clinical demo.

---

## File structure

| Path | Responsibility |
|---|---|
| Modify `requirements.txt` | confirm `joblib` (sklearn pulls it in transitively, but an explicit pin is safer) |
| Modify `.gitignore` | ensure `data/external/` is ignored |
| Create `src/models/tabular_oasis.py` | train + persist + load + predict the OASIS RF classifier |
| Create `scripts/train_oasis.py` | one-shot CLI: trains and saves the model artifact |
| Modify `src/fusion/types.py` | extend `ClinicalScores` with `etiv`, `nwbv`, `asf`, `educ`, `ses`, `is_male` |
| Modify `src/fusion/weights.py` | add `tabular_oasis` weight key for `alzheimers` |
| Modify `src/fusion/engine.py` | add `tabular_oasis` to the modality dispatch |
| Modify `src/api/routes.py` | new route `POST /predict/tabular_oasis` |
| Modify `src/api/schemas.py` | request/response for the new route |
| Create `tests/models/test_tabular_oasis.py` | training + persistence + prediction tests |
| Create `tests/fixtures/build_synthetic_oasis.py` | synthetic OASIS-shaped CSV for tests (clearly labelled non-clinical) |
| Create `tests/fusion/test_tabular_oasis_modality.py` | fusion-side integration |
| Create `tests/api/test_tabular_oasis_route.py` | API integration |
| Modify `README.md` | document the modality + how to acquire the OASIS CSV |

---

## Tasks

### Task 0: Deps + ignore

**Files:** `requirements.txt`, `.gitignore`

- [ ] **Step 1:** verify `joblib` and `pandas` are in `requirements.txt`. `pandas` already is (used by every pipeline). Add `joblib>=1.3,<2.0` if not pinned.

- [ ] **Step 2:** `.gitignore` should cover `data/external/`. Add it if needed.

- [ ] **Step 3:** `pytest -q` baseline. Commit: `chore(oasis): pin joblib; gitignore external dataset dir`.

---

### Task 1: Training + persistence module

**Files:**
- Create: `src/models/tabular_oasis.py`
- Create: `scripts/train_oasis.py`
- Create: `tests/fixtures/build_synthetic_oasis.py`
- Create: `tests/models/test_tabular_oasis.py`

- [ ] **Step 1: Synthetic-fixture helper** (clearly synthetic — never to be confused with real clinical data):

`tests/fixtures/build_synthetic_oasis.py`:

```python
"""Build a synthetic OASIS-shaped CSV for tests. NON-CLINICAL data."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd


def build(path: Path, n: int = 200, seed: int = 42) -> Path:
    """Save a synthetic CSV at `path` with the columns the trainer expects."""
    path = Path(path)
    if path.exists():
        return path
    rng = np.random.default_rng(seed)
    n_dem = n // 2

    # Demented half — lower MMSE, higher CDR, smaller nWBV.
    dem = pd.DataFrame({
        "Group": ["Demented"] * n_dem,
        "M/F": rng.choice(["M", "F"], n_dem),
        "Age": rng.integers(70, 95, n_dem),
        "EDUC": rng.integers(8, 18, n_dem),
        "SES": rng.integers(1, 5, n_dem),
        "MMSE": rng.integers(15, 26, n_dem),
        "CDR": rng.choice([0.5, 1.0], n_dem),
        "eTIV": rng.integers(1200, 1700, n_dem),
        "nWBV": rng.uniform(0.65, 0.74, n_dem),
        "ASF": rng.uniform(1.0, 1.4, n_dem),
        "Visit": 1,
        "Hand": "R",
    })
    nondem = pd.DataFrame({
        "Group": ["Nondemented"] * (n - n_dem),
        "M/F": rng.choice(["M", "F"], n - n_dem),
        "Age": rng.integers(60, 90, n - n_dem),
        "EDUC": rng.integers(10, 22, n - n_dem),
        "SES": rng.integers(1, 5, n - n_dem),
        "MMSE": rng.integers(26, 31, n - n_dem),
        "CDR": rng.choice([0.0], n - n_dem),
        "eTIV": rng.integers(1300, 1900, n - n_dem),
        "nWBV": rng.uniform(0.70, 0.83, n - n_dem),
        "ASF": rng.uniform(0.9, 1.5, n - n_dem),
        "Visit": 1,
        "Hand": "R",
    })

    pd.concat([dem, nondem], ignore_index=True).to_csv(path, index=False)
    return path
```
- [ ] **Step 2: Failing test.**

`tests/models/test_tabular_oasis.py`:

```python
"""Tests for src.models.tabular_oasis."""
from __future__ import annotations

from pathlib import Path

import pytest

from src.models import tabular_oasis
from tests.fixtures.build_synthetic_oasis import build as build_synth


class TestTrainAndPredict:
    def test_train_persists_loadable_artifact(self, tmp_path: Path) -> None:
        csv = build_synth(tmp_path / "oasis.csv")
        artifact = tabular_oasis.train_from_csv(csv, tmp_path / "rf.joblib")
        assert artifact.exists()
        loaded = tabular_oasis.load(artifact)
        assert hasattr(loaded, "predict_proba")

    def test_predict_returns_full_dict(self, tmp_path: Path) -> None:
        csv = build_synth(tmp_path / "oasis.csv")
        artifact = tabular_oasis.train_from_csv(csv, tmp_path / "rf.joblib")
        model = tabular_oasis.load(artifact)
        out = tabular_oasis.predict_one(model, {
            "is_male": 1, "age": 80, "educ": 10, "ses": 3.0,
            "mmse": 18.0, "etiv": 1500.0, "nwbv": 0.68, "asf": 1.2,
        })
        assert set(out) == {"label", "label_text", "confidence", "probabilities"}
        assert out["label"] in {0, 1}
        assert out["label_text"] in {"Nondemented", "Demented"}
        assert 0.0 <= out["confidence"] <= 1.0
        probs = out["probabilities"]
        assert len(probs) == 2
        assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5

    def test_predict_with_synthetic_demented_profile_yields_demented_label(self, tmp_path: Path) -> None:
        # The synthetic data has clean separation, so a clearly-demented profile
        # (MMSE=15, low nWBV, age 88) should classify as Demented.
        csv = build_synth(tmp_path / "oasis.csv")
        artifact = tabular_oasis.train_from_csv(csv, tmp_path / "rf.joblib")
        model = tabular_oasis.load(artifact)
        out = tabular_oasis.predict_one(model, {
            "is_male": 1, "age": 88, "educ": 8, "ses": 3.0,
            "mmse": 15.0, "etiv": 1300.0, "nwbv": 0.66, "asf": 1.3,
        })
        assert out["label_text"] == "Demented"

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        with pytest.raises(FileNotFoundError, match="OASIS classifier artifact not found"):
            tabular_oasis.load(tmp_path / "missing.joblib")
```

Run → ImportError.

- [ ] **Step 3: Minimal impl.**

`src/models/tabular_oasis.py`:

```python
"""OASIS tabular Alzheimer's classifier — Random Forest with full pipeline."""
from __future__ import annotations

from pathlib import Path
from typing import Any

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

from src.core.logger import get_logger

logger = get_logger(__name__)

FEATURE_ORDER: tuple[str, ...] = (
    "is_male", "age", "educ", "ses", "mmse", "etiv", "nwbv", "asf",
)
LABEL_NAMES: tuple[str, ...] = ("Nondemented", "Demented")


def _df_from_oasis_csv(csv_path: Path) -> tuple[pd.DataFrame, pd.Series]:
    """Replicate the notebook's preprocessing: first visit only, M/F encoded,
    Converted-as-Demented, drop unused columns, median-impute SES on EDUC."""
    df = pd.read_csv(csv_path)
    df = df.loc[df["Visit"] == 1].reset_index(drop=True)
    df["M/F"] = df["M/F"].replace({"F": 0, "M": 1})
    df["Group"] = df["Group"].replace({"Converted": "Demented"}).replace(
        {"Demented": 1, "Nondemented": 0}
    )
    df = df.drop(columns=[c for c in ("MRI ID", "Visit", "Hand") if c in df.columns])
    df["SES"] = df["SES"].fillna(df.groupby("EDUC")["SES"].transform("median"))

    feature_df = pd.DataFrame({
        "is_male": df["M/F"].astype(float),
        "age": df["Age"].astype(float),
        "educ": df["EDUC"].astype(float),
        "ses": df["SES"].astype(float),
        "mmse": df["MMSE"].astype(float),
        "etiv": df["eTIV"].astype(float),
        "nwbv": df["nWBV"].astype(float),
        "asf": df["ASF"].astype(float),
    })[list(FEATURE_ORDER)]
    return feature_df, df["Group"].astype(int)


def train_from_csv(csv_path: Path, artifact_path: Path) -> Path:
    """Train and persist a MinMaxScaler→RandomForest pipeline. Returns the artifact path."""
    csv_path = Path(csv_path)
    artifact_path = Path(artifact_path)
    if not csv_path.exists():
        raise FileNotFoundError(f"OASIS CSV not found: {csv_path}")

    X, y = _df_from_oasis_csv(csv_path)
    pipeline = Pipeline([
        ("scaler", MinMaxScaler()),
        ("rf", RandomForestClassifier(
            n_estimators=12, max_depth=8, max_features=8,
            n_jobs=4, random_state=0,
        )),
    ])
    pipeline.fit(X, y)
    artifact_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(pipeline, artifact_path)
    logger.info("trained OASIS RF: n=%d, artifact=%s", len(X), artifact_path)
    return artifact_path


def load(artifact_path: Path) -> Pipeline:
    p = Path(artifact_path)
    if not p.exists():
        raise FileNotFoundError(f"OASIS classifier artifact not found: {p}")
    return joblib.load(p)


def predict_one(model: Pipeline, features: dict[str, float]) -> dict[str, Any]:
    """Predict for a single subject. `features` must have all FEATURE_ORDER keys."""
    missing = [k for k in FEATURE_ORDER if k not in features]
    if missing:
        raise ValueError(f"OASIS prediction missing features: {missing}")
    row = pd.DataFrame([{k: float(features[k]) for k in FEATURE_ORDER}])
    probs = np.asarray(model.predict_proba(row))[0]
    label_idx = int(np.argmax(probs))
    return {
        "label": label_idx,
        "label_text": LABEL_NAMES[label_idx],
        "confidence": float(probs[label_idx]),
        "probabilities": [
            {"label": i, "label_text": LABEL_NAMES[i], "probability": float(p)}
            for i, p in enumerate(probs)
        ],
    }
```

`scripts/train_oasis.py`:

```python
"""CLI: train the OASIS RF classifier and save it.

Usage:
    python scripts/train_oasis.py data/external/oasis_longitudinal.csv data/processed/oasis_rf.joblib
"""
from __future__ import annotations

import sys
from pathlib import Path

from src.models.tabular_oasis import train_from_csv


def main() -> None:
    if len(sys.argv) != 3:
        print(__doc__)
        sys.exit(1)
    csv = Path(sys.argv[1])
    out = Path(sys.argv[2])
    train_from_csv(csv, out)
    print(f"saved: {out}")


if __name__ == "__main__":
    main()
```
|
| 337 |
+
|
| 338 |
+
Run tests → 4 passed.
|
| 339 |
+
|
| 340 |
+
- [ ] **Step 4:** commit: `feat(models): OASIS tabular Alzheimer's RF classifier (joblib + train CLI)`.
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
### Task 2: Extend fusion's clinical inputs
|
| 345 |
+
|
| 346 |
+
**Files:**
|
| 347 |
+
- Modify: `src/fusion/types.py` (extend `ClinicalScores`)
|
| 348 |
+
- Modify: `src/fusion/clinical.py` (add normalisers for the new fields)
|
| 349 |
+
- Modify: `tests/fusion/test_types.py` (loosen / extend bound tests)
|
| 350 |
+
- Modify: `tests/fusion/test_clinical.py` (add new normaliser tests)
|
| 351 |
+
|
| 352 |
+
- [ ] **Step 1: Failing test for new ClinicalScores fields.**
|
| 353 |
+
|
| 354 |
+
In `tests/fusion/test_types.py`, append:
|
| 355 |
+
|
| 356 |
+
```python
class TestExtendedClinicalScores:
    def test_etiv_in_range(self) -> None:
        s = ClinicalScores(etiv=1500.0)
        assert s.etiv == pytest.approx(1500.0)

    def test_etiv_out_of_range_rejected(self) -> None:
        with pytest.raises(ValidationError):
            ClinicalScores(etiv=5000.0)

    def test_nwbv_in_range(self) -> None:
        s = ClinicalScores(nwbv=0.72)
        assert s.nwbv == pytest.approx(0.72)
```

- [ ] **Step 2: Update `src/fusion/types.py` ClinicalScores.**

Add fields (preserve existing ones):

```python
class ClinicalScores(BaseModel):
    mmse: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
    moca: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
    updrs: Annotated[float, Field(ge=0.0, le=199.0)] | None = None
    gait_speed_m_s: Annotated[float, Field(ge=0.0, le=2.5)] | None = None
    age_years: Annotated[float, Field(ge=0.0, le=120.0)] | None = None
    # OASIS biomarkers — used by the tabular_oasis modality.
    etiv: Annotated[float, Field(ge=900.0, le=2200.0)] | None = None
    nwbv: Annotated[float, Field(ge=0.5, le=0.95)] | None = None
    asf: Annotated[float, Field(ge=0.5, le=2.0)] | None = None
    educ: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
    ses: Annotated[float, Field(ge=1.0, le=5.0)] | None = None
    is_male: Annotated[int, Field(ge=0, le=1)] | None = None
```

- [ ] **Step 3:** the tests should pass after the type change. `pytest tests/fusion/test_types.py -v`.

- [ ] **Step 4:** commit: `feat(fusion): extend ClinicalScores with OASIS biomarker fields`.

---

### Task 3: Wire `tabular_oasis` modality into the fusion engine

**Files:**
- Modify: `src/fusion/weights.py`
- Modify: `src/fusion/engine.py`
- Create: `tests/fusion/test_tabular_oasis_modality.py`

- [ ] **Step 1: Update weights.**

`src/fusion/weights.py`, in the `alzheimers` table:

```python
"alzheimers": {
    "mri": 0.25,            # was 0.35
    "eeg": 0.15,            # was 0.20
    "tabular_oasis": 0.20,  # new
    "clinical_mmse": 0.20,
    "clinical_moca": 0.10,  # was 0.15
    "clinical_age": 0.10,
},
```

Re-balance so the table still sums to 1.0, and identify which existing tests' tolerances the re-balancing perturbs before updating them.
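
The "sums to 1.0" invariant is easy to check mechanically. A minimal sketch (the weight values mirror the table above and are an assumption until `weights.py` is actually edited):

```python
# Sanity check: each disease's modality weights must sum to 1.0.
# Values here are the proposed re-balanced "alzheimers" table, not
# necessarily what weights.py will finally contain.
weights = {
    "alzheimers": {
        "mri": 0.25,
        "eeg": 0.15,
        "tabular_oasis": 0.20,
        "clinical_mmse": 0.20,
        "clinical_moca": 0.10,
        "clinical_age": 0.10,
    },
}

for disease, table in weights.items():
    total = sum(table.values())
    # Compare with a tolerance: float addition of 0.1/0.2 steps is inexact.
    assert abs(total - 1.0) < 1e-9, f"{disease} weights sum to {total}, not 1.0"
```

A check like this could also live as a parametrised test over every disease table so future re-balances cannot silently break the invariant.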

- [ ] **Step 2: Failing fusion-modality test.**

`tests/fusion/test_tabular_oasis_modality.py`:

```python
"""Tests: tabular_oasis modality contributes to alzheimers fusion score."""
from __future__ import annotations

import os
from pathlib import Path

import pytest

from src.fusion import engine
from src.fusion.types import ClinicalScores, FusionInput
from src.models.tabular_oasis import train_from_csv
from tests.fixtures.build_synthetic_oasis import build as build_synth


@pytest.fixture()
def trained_artifact(tmp_path: Path, monkeypatch) -> Path:
    csv = build_synth(tmp_path / "oasis.csv")
    art = train_from_csv(csv, tmp_path / "rf.joblib")
    monkeypatch.setenv("OASIS_RF_ARTIFACT", str(art))
    return art


class TestTabularOasisModality:
    def test_demented_profile_raises_alzheimers(self, trained_artifact: Path) -> None:
        out = engine.fuse(FusionInput(clinical=ClinicalScores(
            is_male=1, age_years=88, educ=8, ses=3.0,
            mmse=15.0, etiv=1300.0, nwbv=0.66, asf=1.3,
        )))
        alz = next(d for d in out.diseases if d.disease == "alzheimers")
        assert alz.probability > 0.6
        assert any(c.modality == "tabular_oasis" for c in alz.contributions)

    def test_missing_oasis_inputs_skips_modality(self, trained_artifact: Path) -> None:
        # MMSE alone but no etiv/nwbv → tabular_oasis should be skipped, not error.
        out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=12.0)))
        alz = next(d for d in out.diseases if d.disease == "alzheimers")
        names = {c.modality for c in alz.contributions}
        assert "tabular_oasis" not in names
```

- [ ] **Step 3: Update the engine.**

In `src/fusion/engine.py`, add a tabular-modality dispatcher that lazy-loads the joblib artifact once and maps the OASIS classifier's `P(Demented)` to the alzheimers signal via `2*P - 1`:

```python
import os

_oasis_cache: dict[str, Any] = {}


def _signal_for_tabular_oasis(disease: str, clinical: ClinicalScores) -> float | None:
    if disease != "alzheimers":
        return None
    required = ("is_male", "age_years", "educ", "ses", "mmse", "etiv", "nwbv", "asf")
    if any(getattr(clinical, k, None) is None for k in required):
        return None
    artifact = os.environ.get("OASIS_RF_ARTIFACT", "data/processed/oasis_rf.joblib")
    artifact_path = Path(artifact)
    if not artifact_path.exists():
        logger.warning("tabular_oasis artifact missing at %s; skipping modality", artifact_path)
        return None
    if "model" not in _oasis_cache:
        from src.models.tabular_oasis import load
        _oasis_cache["model"] = load(artifact_path)
    from src.models.tabular_oasis import predict_one
    feats = {
        "is_male": int(clinical.is_male),
        "age": float(clinical.age_years),
        "educ": float(clinical.educ),
        "ses": float(clinical.ses),
        "mmse": float(clinical.mmse),
        "etiv": float(clinical.etiv),
        "nwbv": float(clinical.nwbv),
        "asf": float(clinical.asf),
    }
    pred = predict_one(_oasis_cache["model"], feats)
    p_dem = next(p["probability"] for p in pred["probabilities"] if p["label_text"] == "Demented")
    return 2.0 * p_dem - 1.0
```

In `_signal_for_modality`, add the dispatch:

```python
if modality_key == "tabular_oasis":
    return _signal_for_tabular_oasis(disease, clinical)
```
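
The `2*P - 1` transform maps `P(Demented)` from `[0, 1]` onto the fusion signal range `[-1, 1]`, so a 50/50 prediction contributes a neutral 0. A minimal check of that mapping (the helper name is illustrative, not from the engine):

```python
def signal_from_p_demented(p: float) -> float:
    """Map P(Demented) in [0, 1] onto a fusion signal in [-1, 1]."""
    return 2.0 * p - 1.0

assert signal_from_p_demented(0.5) == 0.0    # maximal uncertainty -> neutral
assert signal_from_p_demented(1.0) == 1.0    # certain Demented -> full positive
assert signal_from_p_demented(0.0) == -1.0   # certain Nondemented -> full negative
```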

- [ ] **Step 4:** `pytest tests/fusion/ -v` — expect the re-balancing to perturb a couple of existing thresholds. Either adjust the thresholds in the affected tests (e.g., the disagreement test) so they still hold under the new weights, or adjust the new weights so the existing tests pass within tolerance. Prefer the latter: the existing thresholds were chosen carefully.

- [ ] **Step 5:** commit: `feat(fusion): add tabular_oasis modality with lazy joblib load`.

---

### Task 4: API + Streamlit + README

**Files:**
- Modify: `src/api/routes.py` — add `POST /predict/tabular_oasis`
- Modify: `src/api/schemas.py` — request/response schemas
- Modify: `src/frontend/app.py` — extend the Doctor view's clinical-input form with eTIV / nWBV / ASF / EDUC / SES
- Modify: `README.md` — describe the new modality and the OASIS dataset path

- [ ] **Step 1: New schemas.**

`src/api/schemas.py`:

```python
class TabularOasisRequest(BaseModel):
    is_male: int = Field(..., ge=0, le=1)
    age: float = Field(..., ge=0.0, le=120.0)
    educ: float = Field(..., ge=0.0, le=30.0)
    ses: float = Field(..., ge=1.0, le=5.0)
    mmse: float = Field(..., ge=0.0, le=30.0)
    etiv: float = Field(..., ge=900.0, le=2200.0)
    nwbv: float = Field(..., ge=0.5, le=0.95)
    asf: float = Field(..., ge=0.5, le=2.0)


class TabularOasisProbability(BaseModel):
    label: int
    label_text: str
    probability: float


class TabularOasisResponse(BaseModel):
    label: int
    label_text: str
    confidence: float
    probabilities: list[TabularOasisProbability]
```

- [ ] **Step 2: Route.**

`src/api/routes.py`:

```python
@predict_router.post("/tabular_oasis", response_model=TabularOasisResponse)
def predict_tabular_oasis(req: TabularOasisRequest) -> TabularOasisResponse:
    from src.models.tabular_oasis import load, predict_one

    artifact = Path(os.environ.get("OASIS_RF_ARTIFACT", "data/processed/oasis_rf.joblib"))
    model = load(artifact)
    out = predict_one(model, req.model_dump())
    return TabularOasisResponse(**out)
```

- [ ] **Step 3: Test (`tests/api/test_tabular_oasis_route.py`).**

```python
"""Integration: POST /predict/tabular_oasis."""
from __future__ import annotations

from pathlib import Path

import pytest
from fastapi.testclient import TestClient

from src.api.main import app
from src.models.tabular_oasis import train_from_csv
from tests.fixtures.build_synthetic_oasis import build as build_synth


@pytest.fixture()
def client(monkeypatch, tmp_path):
    csv = build_synth(tmp_path / "oasis.csv")
    artifact = train_from_csv(csv, tmp_path / "rf.joblib")
    monkeypatch.setenv("OASIS_RF_ARTIFACT", str(artifact))
    return TestClient(app)


def test_predict_tabular_oasis_demented_profile(client):
    body = {
        "is_male": 1, "age": 88, "educ": 8, "ses": 3.0,
        "mmse": 15.0, "etiv": 1300.0, "nwbv": 0.66, "asf": 1.3,
    }
    r = client.post("/predict/tabular_oasis", json=body)
    assert r.status_code == 200, r.text
    data = r.json()
    assert data["label_text"] == "Demented"
```

- [ ] **Step 4:** Streamlit form extension. In `src/frontend/app.py`, find the clinical-inputs section the doctor view exposes (likely under a "Clinical scores" expander; if absent, add it under the fusion tab). Add number_input widgets for the seven new fields (`is_male`, `age`, `educ`, `ses`, `etiv`, `nwbv`, `asf`) that flow into the existing `/fusion/predict` payload's `clinical` block.
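
The widget values must land in the payload's `clinical` block under the `ClinicalScores` field names; note the form's `age` maps to the schema's `age_years`. A hypothetical helper sketching that mapping (the function name and call-site are illustrative, not from the actual app):

```python
def clinical_block(is_male: int, age: float, educ: float, ses: float,
                   mmse: float, etiv: float, nwbv: float, asf: float) -> dict:
    """Assemble the `clinical` block of the /fusion/predict payload.

    Keys must match ClinicalScores field names, so the form's `age`
    widget value is sent as `age_years`.
    """
    return {
        "is_male": is_male,
        "age_years": age,
        "educ": educ,
        "ses": ses,
        "mmse": mmse,
        "etiv": etiv,
        "nwbv": nwbv,
        "asf": asf,
    }

# Example payload built from form values (same profile as the tests above).
payload = {"clinical": clinical_block(1, 88.0, 8.0, 3.0, 15.0, 1300.0, 0.66, 1.3)}
```

Keeping the mapping in one helper makes the `age` → `age_years` rename impossible to miss when the form grows more fields.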

- [ ] **Step 5:** README. Append (the outer fence uses four backticks so the inner `bash` block nests cleanly):

````markdown
### OASIS Tabular Alzheimer's Classifier

A scikit-learn Random Forest trained on the OASIS longitudinal dataset (https://www.oasis-brains.org/) classifies Demented vs Nondemented from 8 biomarkers (sex, age, education, SES, MMSE, eTIV, nWBV, ASF). It contributes to the fusion engine as modality `tabular_oasis` (weight 0.20 for Alzheimer's).

To use: download `oasis_longitudinal.csv` from Kaggle, save it to `data/external/oasis_longitudinal.csv`, then:

```bash
python scripts/train_oasis.py data/external/oasis_longitudinal.csv data/processed/oasis_rf.joblib
export OASIS_RF_ARTIFACT=data/processed/oasis_rf.joblib
```

The fusion engine and `POST /predict/tabular_oasis` will pick it up. If the artifact is missing, the modality is skipped — fusion still works.
````

- [ ] **Step 6:** commit: `feat(oasis): /predict/tabular_oasis route + Streamlit form + README`.

---

## Self-review checklist

1. **Independence.** OASIS classifier and fusion remain decoupled when the artifact is absent (`OASIS_RF_ARTIFACT` unset → modality skipped). ✓
2. **No real-data fabrication.** Tests use a clearly-labelled synthetic CSV. The real OASIS dataset is never committed. ✓
3. **Backward compatibility.** Existing `ClinicalScores` fields untouched. New fields are all `Optional`. ✓
4. **Branch 3a vs 3b.** This plan is Branch 3a. If the user picks Branch 3b, this plan is replaced wholesale.

---

## Execution handoff

Save the plan, then choose an execution mode: subagent-driven (recommended) or inline executing-plans.

**Reminder to controller:** before starting any task, confirm with the user: "Do you have a real EEG checkpoint I'm missing, or shall I proceed with Branch 3a (OASIS tabular Alzheimer's classifier)?"
# TF-IDF Clinical RAG Integration Plan

> **For agentic workers:** REQUIRED SUB-SKILL: `superpowers:subagent-driven-development`. TDD throughout.

**Goal.** Integrate the user's pre-built TF-IDF RAG corpus (14 medical PDFs covering Alzheimer's, Parkinson's, lifestyle, nutrition, exercise; Turkish + English query expansion; pre-built `rag_index.pkl`) into the platform alongside the existing FAISS+fastembed RAG. Both run side-by-side; the agent picks per-query.

**Architecture.** A new `src/rag/clinical/` sub-package wraps the user's `rag.py` script as an importable module. The existing `retrieve_context` agent tool grows a `corpus` parameter (`"reference"` for FAISS — current behaviour, kept as the default — `"clinical"` for the new TF-IDF index). A new module-level retriever object is constructed once at startup and reused. Pure addition: existing tests and behaviour stay green.

**Tech stack.** scikit-learn (likely already in deps via the existing pipelines — verify, and add it if not), `pypdf`, `numpy`. No FAISS or new embedding model. The user's pickle deserialises with stdlib `pickle` and `sklearn.feature_extraction.text.TfidfVectorizer`.

---

## Prerequisite (controller blocker)

The corpus and index live at `/Users/mertgungor/Downloads/rag/`. Before any task starts:

1. **Source PDFs.** Copy `Downloads/rag/HACKATHON/*.pdf` to `data/external_rag/clinical_pdfs/` in this repo. **Do NOT commit the PDFs to git** — they are external research papers, possibly copyrighted, and large (~33 MB total). Add `data/external_rag/` to `.gitignore` if not already covered (the existing repo gitignores `data/processed/` — check whether that also covers `data/external_rag/`).

2. **Pre-built index.** Copy `Downloads/rag/index/rag_index.pkl` to `data/external_rag/index/rag_index.pkl`. Also gitignored.

3. **Verify the pickle loads.** `python -c "import pickle; print(list(pickle.load(open('data/external_rag/index/rag_index.pkl','rb')).keys()))"`. Expect: `['created_at', 'source_dir', 'chunk_words', 'overlap_words', 'chunks', 'vectorizer', 'matrix']`.

If the pickle fails to load (sklearn version mismatch, missing module), Task 1 has a regenerate fallback: rebuild the index from the PDFs using the same parameters as `Downloads/rag/rag.py`.

---

## File structure

| Path | Responsibility |
|---|---|
| Modify `requirements.txt` | confirm `scikit-learn` and `pypdf` are present (sklearn likely is; pypdf may not be) |
| Modify `.gitignore` | ensure `data/external_rag/` is ignored |
| Create `src/rag/clinical/__init__.py` | package marker |
| Create `src/rag/clinical/types.py` | `ClinicalChunk` dataclass + `ClinicalRetrievalResult` pydantic model |
| Create `src/rag/clinical/loader.py` | unpickle the index; handle the user's payload schema; rebuild from PDFs as fallback |
| Create `src/rag/clinical/retrieve.py` | TF-IDF query + Turkish/English query expansion + sentence-level evidence picking |
| Modify `src/agents/tools.py` | `retrieve_context` accepts `corpus: Literal["reference", "clinical"]` |
| Modify `src/agents/prompts.py` | one-line update describing the corpus parameter |
| Create `tests/rag/test_clinical_loader.py` | unpickle + rebuild tests using a tiny fixture corpus |
| Create `tests/rag/test_clinical_retrieve.py` | retrieval correctness tests |
| Create `tests/agents/test_tools_clinical_corpus.py` | end-to-end agent-tool routing |
| Create `tests/fixtures/build_tiny_clinical_index.py` | builds a 2-page synthetic PDF index for tests |
| Modify `README.md` | document the dual-corpus surface |

---

## Tasks

### Task 0: Deps + gitignore + asset copy verification

**Files:** `requirements.txt`, `.gitignore`

- [ ] **Step 1:** check `requirements.txt`. If `pypdf` is absent, add `pypdf>=4.0,<6.0`. `scikit-learn` should already be there from the existing pipelines — verify with `pip show scikit-learn`.

- [ ] **Step 2:** open `.gitignore`. If `data/external_rag/` (or a parent that covers it) is not ignored, add a single line: `data/external_rag/`.

- [ ] **Step 3:** verify the asset transfer. `ls data/external_rag/clinical_pdfs/*.pdf | wc -l` should print `14`, and `ls -lh data/external_rag/index/rag_index.pkl` should show ~12.9 MB.

- [ ] **Step 4:** if pypdf was added, run `pip install -r requirements.txt`.

- [ ] **Step 5:** `pytest -q` baseline. Commit: `chore(rag): pin pypdf; gitignore external rag corpus`.

---

### Task 1: Loader for the pre-built TF-IDF index

**Files:**
- Create: `src/rag/clinical/__init__.py`
- Create: `src/rag/clinical/types.py`
- Create: `src/rag/clinical/loader.py`
- Create: `tests/rag/test_clinical_loader.py`
- Create: `tests/fixtures/build_tiny_clinical_index.py`

- [ ] **Step 1: Build the tiny-fixture helper.**

`tests/fixtures/build_tiny_clinical_index.py`:

```python
"""Build a synthetic TF-IDF clinical-RAG index for tests.

Avoids needing real PDFs. Constructs the same payload schema the user's
rag.py produces so the loader can be tested independently of pypdf.
"""
from __future__ import annotations

import pickle
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

from sklearn.feature_extraction.text import TfidfVectorizer


# Same schema the user's rag.py produces. We define our own dataclass here
# so the test fixture is self-contained.
@dataclass(frozen=True)
class _Chunk:
    chunk_id: int
    source: str
    page_start: int
    page_end: int
    text: str


def build(path: Path) -> Path:
    """Save a tiny TF-IDF index at `path`."""
    path = Path(path)
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)

    chunks = [
        _Chunk(0, "alzheimers_lifestyle.pdf", 1, 1,
               "Aerobic exercise and Mediterranean diet are associated with reduced cognitive decline in older adults at risk for Alzheimer's disease."),
        _Chunk(1, "parkinsons_motor.pdf", 1, 1,
               "Levodopa remains the most effective symptomatic treatment for motor symptoms of Parkinson's disease."),
        _Chunk(2, "alzheimers_mci.pdf", 2, 2,
               "Mild cognitive impairment may progress to dementia; MMSE and MoCA are standard screening tools."),
        _Chunk(3, "parkinsons_nutrition.pdf", 1, 1,
               "Dietary patterns rich in antioxidants and omega-3 fatty acids are linked to lower Parkinson's risk."),
    ]

    vectorizer = TfidfVectorizer(lowercase=True, ngram_range=(1, 2), min_df=1, norm="l2")
    matrix = vectorizer.fit_transform([c.text for c in chunks])

    payload = {
        "created_at": datetime.now().isoformat(timespec="seconds"),
        "source_dir": str(path.parent),
        "chunk_words": 220,
        "overlap_words": 45,
        "chunks": chunks,
        "vectorizer": vectorizer,
        "matrix": matrix,
    }
    with path.open("wb") as f:
        pickle.dump(payload, f)
    return path
```

- [ ] **Step 2: Failing test.**

`tests/rag/test_clinical_loader.py`:

```python
"""Tests for src.rag.clinical.loader."""
from __future__ import annotations

from pathlib import Path

import pytest

from src.rag.clinical import loader
from tests.fixtures.build_tiny_clinical_index import build as build_tiny


class TestLoadIndex:
    def test_load_returns_payload_with_expected_keys(self, tmp_path: Path) -> None:
        idx_path = build_tiny(tmp_path / "tiny.pkl")
        payload = loader.load_index(idx_path)
        assert {"chunks", "vectorizer", "matrix"} <= set(payload)
        assert len(payload["chunks"]) == 4

    def test_missing_index_raises(self, tmp_path: Path) -> None:
        with pytest.raises(FileNotFoundError, match="clinical RAG index not found"):
            loader.load_index(tmp_path / "nope.pkl")

    def test_unique_sources(self, tmp_path: Path) -> None:
        idx_path = build_tiny(tmp_path / "tiny.pkl")
        payload = loader.load_index(idx_path)
        sources = {c.source for c in payload["chunks"]}
        assert sources == {
            "alzheimers_lifestyle.pdf", "parkinsons_motor.pdf",
            "alzheimers_mci.pdf", "parkinsons_nutrition.pdf",
        }
```

Run → ImportError on `src.rag.clinical.loader`.

- [ ] **Step 3: Minimal impl.**

`src/rag/clinical/__init__.py`: empty.

`src/rag/clinical/types.py`:

```python
"""Types shared across clinical-RAG modules."""
from __future__ import annotations

from dataclasses import dataclass

from pydantic import BaseModel, Field


# The user's rag.py uses a frozen dataclass with this exact name and fields.
# We re-export the same shape so the loader can read pickles produced by
# either the user's script or our test fixture without translation.
@dataclass(frozen=True)
class ClinicalChunk:
    chunk_id: int
    source: str
    page_start: int
    page_end: int
    text: str


class ClinicalEvidence(BaseModel):
    sentence: str
    source: str
    page_start: int
    page_end: int
    score: float = Field(..., ge=0.0)


class ClinicalRetrievalResult(BaseModel):
    query: str
    evidence: list[ClinicalEvidence]
    summary_text: str = Field(..., description="Pre-formatted RAG feedback string for the agent")
```

`src/rag/clinical/loader.py`:

```python
"""Load (or rebuild) the TF-IDF clinical RAG index."""
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any

from src.core.logger import get_logger

logger = get_logger(__name__)


def load_index(path: Path) -> dict[str, Any]:
    """Unpickle a TF-IDF index produced by the user's rag.py."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"clinical RAG index not found: {path}")
    with path.open("rb") as f:
        payload = pickle.load(f)
    if "chunks" not in payload or "vectorizer" not in payload or "matrix" not in payload:
        raise ValueError(f"clinical RAG index missing expected keys: {sorted(payload)}")
    logger.info("loaded clinical RAG index: %d chunks from %s", len(payload["chunks"]), path)
    return payload
```

Note: the pickle stores the chunk dataclass by its module path, so that class must be importable at load time. Sklearn tolerates version drift across minor patches; if a major-version drift causes a deserialise error, the user's `rag.py` rebuild is the recovery path (Task 2 wraps that).

Run tests → 3 passed.

- [ ] **Step 4:** `pytest -q` no regressions.

- [ ] **Step 5:** commit: `feat(rag): clinical TF-IDF index loader`.

---

### Task 2: Retrieval (TF-IDF query + Turkish/English expansion + evidence picking)

**Files:**
- Create: `src/rag/clinical/retrieve.py`
- Create: `tests/rag/test_clinical_retrieve.py`

- [ ] **Step 1: Failing test.**

`tests/rag/test_clinical_retrieve.py`:

```python
"""Tests for src.rag.clinical.retrieve."""
from __future__ import annotations

from pathlib import Path

import pytest

from src.rag.clinical.retrieve import retrieve_clinical
from src.rag.clinical.loader import load_index
from tests.fixtures.build_tiny_clinical_index import build as build_tiny


class TestRetrieve:
    def test_alzheimer_query_picks_alzheimer_chunks(self, tmp_path: Path) -> None:
        payload = load_index(build_tiny(tmp_path / "tiny.pkl"))
        result = retrieve_clinical(payload, query="exercise and Alzheimer's", top_k=2)
        sources = {ev.source for ev in result.evidence}
        assert any("alzheimers" in s for s in sources)

    def test_parkinson_query_picks_parkinson_chunks(self, tmp_path: Path) -> None:
        payload = load_index(build_tiny(tmp_path / "tiny.pkl"))
        result = retrieve_clinical(payload, query="Parkinson levodopa", top_k=2)
        sources = {ev.source for ev in result.evidence}
        assert any("parkinsons" in s for s in sources)

    def test_turkish_keyword_routes_via_expansion(self, tmp_path: Path) -> None:
        # User's rag.py expands "egzersiz" -> "exercise physical activity ...".
        # Our retrieve must honour the same expansion table so Turkish queries
        # hit English chunks.
        payload = load_index(build_tiny(tmp_path / "tiny.pkl"))
        result = retrieve_clinical(payload, query="egzersiz Alzheimer", top_k=2)
        # Turkish "egzersiz" + "alzheimer" should pick the lifestyle PDF.
        assert any("alzheimers_lifestyle" in ev.source for ev in result.evidence)

    def test_summary_text_contains_citations(self, tmp_path: Path) -> None:
        payload = load_index(build_tiny(tmp_path / "tiny.pkl"))
        result = retrieve_clinical(payload, query="diet and Parkinson", top_k=2)
        # Summary should embed source filenames so the LLM has citations.
        assert any(ev.source in result.summary_text for ev in result.evidence)

    def test_empty_query_returns_empty_evidence(self, tmp_path: Path) -> None:
        payload = load_index(build_tiny(tmp_path / "tiny.pkl"))
        result = retrieve_clinical(payload, query="", top_k=2)
        assert result.evidence == []
```
|
| 314 |
+
|
| 315 |
+
Run → ImportError.
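
The tests lean on a `tests/fixtures/build_tiny_clinical_index.py` helper that this plan never spells out. A minimal sketch, assuming the payload shape `{"vectorizer", "matrix", "chunks"}` that `retrieve_clinical` reads, with chunk objects carrying `text`/`source`/`page_start`/`page_end` (the `Chunk` dataclass and the three sample sentences here are illustrative, not prescribed):

```python
"""Hypothetical fixture sketch: build a tiny pickled TF-IDF index for tests."""
from __future__ import annotations

import pickle
from dataclasses import dataclass
from pathlib import Path

from sklearn.feature_extraction.text import TfidfVectorizer


@dataclass
class Chunk:
    text: str
    source: str
    page_start: int
    page_end: int


def build(path: Path) -> Path:
    # Three tiny "documents" standing in for the 14 clinical PDFs.
    chunks = [
        Chunk("Regular aerobic exercise and physical activity may slow cognitive decline in Alzheimer's disease.",
              "alzheimers_lifestyle.pdf", 1, 1),
        Chunk("Levodopa remains the most effective treatment for Parkinson's motor symptoms such as tremor.",
              "parkinsons_treatment.pdf", 2, 3),
        Chunk("A Mediterranean diet is associated with lower dementia risk in several cohort studies.",
              "alzheimers_nutrition.pdf", 4, 4),
    ]
    vectorizer = TfidfVectorizer()
    matrix = vectorizer.fit_transform([c.text for c in chunks])
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as fh:
        pickle.dump({"vectorizer": vectorizer, "matrix": matrix, "chunks": chunks}, fh)
    return path
```

The fixture must pickle the same three keys the loader expects, so the tests exercise the real `load_index` code path rather than a mock.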
- [ ] **Step 2: Minimal impl.**

`src/rag/clinical/retrieve.py`:

```python
"""TF-IDF retrieval over the clinical-paper corpus, with Turkish→English query expansion."""
from __future__ import annotations

import re
from textwrap import shorten
from typing import Any

import numpy as np

from src.core.logger import get_logger
from src.rag.clinical.types import ClinicalEvidence, ClinicalRetrievalResult

logger = get_logger(__name__)

# Mirrors the table in /Users/mertgungor/Downloads/rag/rag.py so the same
# Turkish keyword set produces the same expansion in both pipelines.
_QUERY_EXPANSIONS: dict[str, str] = {
    "alzheimer": "alzheimer dementia cognitive impairment mild cognitive impairment mci memory",
    "demans": "dementia alzheimer cognitive impairment memory cognition",
    "unutkanlik": "memory impairment cognitive decline dementia alzheimer",
    "parkinson": "parkinson disease movement disorder tremor motor symptoms non motor symptoms",
    "titreme": "tremor parkinson motor symptoms movement disorder",
    "egzersiz": "exercise physical activity training aerobic resistance cognition",
    "beslenme": "nutrition diet lifestyle metabolic risk factors",
    "risk": "risk factors lifestyle metabolic nutrition prevention",
    "tani": "diagnosis diagnostic criteria assessment screening",
    "tedavi": "treatment management therapy intervention",
}


def _expand_query(query: str) -> str:
    normalized = query.casefold()
    extras = [exp for key, exp in _QUERY_EXPANSIONS.items() if key in normalized]
    return f"{query} {' '.join(extras)}" if extras else query


def _split_sentences(text: str) -> list[str]:
    sentences = re.split(r"(?<=[.!?])\s+", text)
    return [s.strip() for s in sentences if len(s.split()) >= 6]


def _query_terms(expanded: str) -> set[str]:
    return {t for t in re.findall(r"[A-Za-z0-9]+", expanded.lower()) if len(t) >= 4}


def retrieve_clinical(
    payload: dict[str, Any],
    query: str,
    top_k: int = 5,
    evidence_limit: int = 5,
) -> ClinicalRetrievalResult:
    """Run TF-IDF search over the clinical corpus, return evidence + a feedback summary."""
    if not query.strip():
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    vectorizer = payload["vectorizer"]
    matrix = payload["matrix"]
    chunks = payload["chunks"]

    expanded = _expand_query(query)
    qv = vectorizer.transform([expanded])
    scores = (matrix @ qv.T).toarray().ravel()
    if not np.any(scores):
        return ClinicalRetrievalResult(query=query, evidence=[], summary_text="")

    top_indices = np.argsort(scores)[::-1][:top_k]
    top_chunks = [(chunks[int(i)], float(scores[int(i)])) for i in top_indices if scores[int(i)] > 0]

    # Sentence-level evidence picking: pick the highest-overlap sentences first.
    terms = _query_terms(expanded)
    candidates: list[tuple[float, str, Any, float]] = []
    for chunk, chunk_score in top_chunks:
        for sentence in _split_sentences(chunk.text):
            sent_terms = set(re.findall(r"[A-Za-z0-9]+", sentence.lower()))
            overlap = len(terms & sent_terms)
            if overlap == 0:
                continue
            candidates.append((overlap + chunk_score, sentence, chunk, chunk_score))

    candidates.sort(key=lambda item: item[0], reverse=True)
    seen: set[str] = set()
    evidence: list[ClinicalEvidence] = []
    for _, sent, chunk, sc in candidates:
        fp = sent[:120].lower()
        if fp in seen:
            continue
        seen.add(fp)
        evidence.append(ClinicalEvidence(
            sentence=shorten(sent, width=420, placeholder="..."),
            source=chunk.source,
            page_start=chunk.page_start,
            page_end=chunk.page_end,
            score=sc,
        ))
        if len(evidence) >= evidence_limit:
            break

    if not evidence:
        # Fall back to chunk-level evidence if no sentence overlapped.
        for chunk, sc in top_chunks[:evidence_limit]:
            evidence.append(ClinicalEvidence(
                sentence=shorten(chunk.text, width=420, placeholder="..."),
                source=chunk.source,
                page_start=chunk.page_start,
                page_end=chunk.page_end,
                score=sc,
            ))

    lines = ["Clinical RAG evidence (not a medical diagnosis):"]
    for ev in evidence:
        page = (
            f"p.{ev.page_start}" if ev.page_start == ev.page_end
            else f"pp.{ev.page_start}-{ev.page_end}"
        )
        lines.append(f"- {ev.sentence} [{ev.source}, {page} | score={ev.score:.3f}]")
    summary = "\n".join(lines)

    return ClinicalRetrievalResult(query=query, evidence=evidence, summary_text=summary)
```

Run tests → 5 passed.
- [ ] **Step 3:** commit: `feat(rag): TF-IDF clinical retrieval with Turkish/English query expansion`.

---

### Task 3: Wire into the agent's `retrieve_context` tool

**Files:**
- Modify: `src/agents/tools.py`
- Modify: `src/agents/prompts.py`
- Create: `tests/agents/test_tools_clinical_corpus.py`

- [ ] **Step 1: Failing test.**

`tests/agents/test_tools_clinical_corpus.py`:

```python
"""Tests: retrieve_context tool dispatches by `corpus`."""
from __future__ import annotations

from pathlib import Path

from src.agents.tools import build_default_tools
from tests.fixtures.build_tiny_clinical_index import build as build_tiny


class TestClinicalCorpus:
    def test_corpus_default_is_reference(self, tmp_path: Path) -> None:
        clinical_idx = build_tiny(tmp_path / "tiny.pkl")
        tools = {t.name: t for t in build_default_tools(
            rag_index_dir=None,
            clinical_rag_index_path=clinical_idx,
        )}
        tool = tools["retrieve_context"]
        # `corpus` not provided → defaults to "reference".
        out = tool.execute(tool.input_model.model_validate({"query": "test"}))
        # With rag_index_dir=None, reference returns empty.
        assert hasattr(out, "chunks")

    def test_clinical_corpus_returns_evidence(self, tmp_path: Path) -> None:
        clinical_idx = build_tiny(tmp_path / "tiny.pkl")
        tools = {t.name: t for t in build_default_tools(
            rag_index_dir=None,
            clinical_rag_index_path=clinical_idx,
        )}
        tool = tools["retrieve_context"]
        out = tool.execute(tool.input_model.model_validate({
            "query": "exercise and Alzheimer",
            "corpus": "clinical",
        }))
        assert len(out.chunks) > 0
        # Each returned chunk has source + page metadata.
        for c in out.chunks:
            assert "source" in c and "text" in c
```

Run → fails (signature mismatch).
- [ ] **Step 2: Wire the tool.**

In `src/agents/tools.py`, the existing `RetrieveContextInput`/`RetrieveContextOutput` need the `corpus` field. Find the schemas (likely in `src/agents/schemas.py`) and add:

```python
from typing import Literal

class RetrieveContextInput(BaseModel):
    query: str
    k: int = 4
    corpus: Literal["reference", "clinical"] = "reference"
```

`RetrieveContextOutput.chunks` already accepts dicts; no change needed there.
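
The `Literal` default is what keeps old call sites on the FAISS path. A dependency-free sketch of the same validation semantics (the real schema is a pydantic `BaseModel`; the dataclass below is only an illustrative stand-in):

```python
from dataclasses import dataclass
from typing import Literal, get_args

_Corpus = Literal["reference", "clinical"]


@dataclass
class RetrieveContextInputSketch:
    """Illustrative stand-in for the pydantic model: same fields, same default."""
    query: str
    k: int = 4
    corpus: str = "reference"

    def __post_init__(self) -> None:
        # pydantic rejects unknown Literal values at validation time; mimic that here.
        if self.corpus not in get_args(_Corpus):
            raise ValueError(f"corpus must be one of {get_args(_Corpus)}")


# Old call sites that never mention `corpus` keep hitting the reference index:
assert RetrieveContextInputSketch(query="test").corpus == "reference"
```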

In `src/agents/tools.py`, add a `clinical_rag_index_path: Path | None = None` parameter to `build_default_tools`. Update `_make_retrieve_executor` to take both index sources:

```python
def _make_retrieve_executor(
    rag_index_dir: Path | None,
    clinical_rag_index_path: Path | None,
) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
    # Lazily load the clinical payload at first use, cache for subsequent calls.
    clinical_cache: dict[str, Any] = {}

    def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
        if inp.corpus == "clinical":
            if clinical_rag_index_path is None:
                logger.warning("retrieve_context corpus=clinical but no index path configured")
                return RetrieveContextOutput(chunks=[])
            if "payload" not in clinical_cache:
                from src.rag.clinical.loader import load_index
                clinical_cache["payload"] = load_index(clinical_rag_index_path)
            from src.rag.clinical.retrieve import retrieve_clinical
            result = retrieve_clinical(clinical_cache["payload"], inp.query, top_k=inp.k)
            return RetrieveContextOutput(chunks=[
                {
                    "source": ev.source,
                    "page_start": ev.page_start,
                    "page_end": ev.page_end,
                    "text": ev.sentence,
                    "score": ev.score,
                }
                for ev in result.evidence
            ])

        # corpus == "reference" — existing FAISS path. Keep current behaviour.
        ... (preserve existing executor body) ...

    return execute
```
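
The closure-held `clinical_cache` dict gives lazy, load-once behaviour without module globals. The pattern in isolation, with a hypothetical `load_index_stub` counting invocations (names here are illustrative, not the real loader API):

```python
from typing import Any, Callable

load_calls = 0


def load_index_stub(path: str) -> dict[str, Any]:
    """Hypothetical stand-in for src.rag.clinical.loader.load_index."""
    global load_calls
    load_calls += 1
    return {"loaded_from": path}


def make_executor(index_path: str) -> Callable[[], dict[str, Any]]:
    cache: dict[str, Any] = {}  # closed over; survives between calls

    def execute() -> dict[str, Any]:
        if "payload" not in cache:  # first call pays the load cost...
            cache["payload"] = load_index_stub(index_path)
        return cache["payload"]     # ...subsequent calls reuse it

    return execute


executor = make_executor("rag_index.pkl")
executor(); executor(); executor()
assert load_calls == 1  # loaded exactly once despite three calls
```

Because the cache lives in the closure rather than at module scope, each `build_default_tools` call gets an independent cache, which keeps tests hermetic.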

In `src/api/routes.py`, where `build_default_tools(...)` is called inside `_build_orchestrator()` (around line 577), pass the new path:

```python
clinical_idx = Path(os.environ.get(
    "CLINICAL_RAG_INDEX_PATH",
    "data/external_rag/index/rag_index.pkl",
))
tools = build_default_tools(
    rag_index_dir=rag_dir if rag_status["exists"] else None,
    clinical_rag_index_path=clinical_idx if clinical_idx.exists() else None,
)
```

Update the `retrieve_context` tool description in `src/agents/prompts.py` (it already mentions FAISS) to:

```
- retrieve_context: retrieve up to k passages from a knowledge base. Pass corpus="clinical" for medical-paper evidence (Alzheimer's / Parkinson's / lifestyle / nutrition; supports Turkish keywords); default corpus="reference" for the curated FAISS index.
```

- [ ] **Step 3:** `pytest -q` → the 2 new tests pass, with no regressions in the existing baseline and retrieve tests.

- [ ] **Step 4:** commit: `feat(agents): retrieve_context corpus dispatch (reference vs clinical)`.

---

### Task 4: README + CLI sanity

**Files:**
- Modify: `README.md`
- Create: `scripts/clinical_rag_smoke.py`

- [ ] **Step 1:** add a small CLI tool to query the corpus from the terminal:

```python
"""Smoke: ask the clinical corpus a question from the terminal.

Usage:
    python scripts/clinical_rag_smoke.py "egzersiz Alzheimer feedback"
"""
from __future__ import annotations

import sys
from pathlib import Path

from src.rag.clinical.loader import load_index
from src.rag.clinical.retrieve import retrieve_clinical


def main() -> None:
    if len(sys.argv) < 2:
        print(__doc__)
        sys.exit(1)
    query = " ".join(sys.argv[1:])
    payload = load_index(Path("data/external_rag/index/rag_index.pkl"))
    result = retrieve_clinical(payload, query, top_k=5, evidence_limit=5)
    print(result.summary_text or "(no matches)")


if __name__ == "__main__":
    main()
```

- [ ] **Step 2:** README addition (find the existing RAG section, append):

````markdown
### Clinical Corpus (TF-IDF, Turkish + English)

A second, lightweight RAG index covers 14 medical PDFs (Alzheimer's, Parkinson's, lifestyle, nutrition, exercise) using TF-IDF + sklearn. Source PDFs live at `data/external_rag/clinical_pdfs/` (gitignored — copy from your team's shared drive). Pre-built index at `data/external_rag/index/rag_index.pkl`.

Agent invocation:

```python
retrieve_context(query="egzersiz Alzheimer feedback", corpus="clinical", k=5)
```

Local CLI smoke:

```bash
python scripts/clinical_rag_smoke.py "egzersiz Alzheimer feedback"
```

The Turkish keywords `alzheimer`, `parkinson`, `egzersiz`, `beslenme`, `tani`, `tedavi`, `risk`, `unutkanlik`, `titreme`, `demans` auto-expand to English equivalents so Turkish queries hit English chunks.
````


- [ ] **Step 3:** commit: `docs(rag): document clinical TF-IDF corpus + add CLI smoke`.

---

## Self-review

1. **Spec coverage.** The user asked to "fully integrate the new RAG folder". The wrapper imports the user's exact pickle schema, mirrors the Turkish expansion table verbatim, and surfaces the same evidence semantics (one citation per sentence, source + page tags). ✓
2. **Backward compatibility.** Default corpus is `"reference"`, preserving existing FAISS behaviour. ✓
3. **No XAI / re-embedding.** The user's TF-IDF index is used as-is. We don't re-embed or fine-tune. ✓
4. **Independence.** No coupling to BBB, fusion, or MRI modules. ✓
5. **No placeholders.** Every step has the full code or a full diff direction. ✓

---

## Execution handoff

Save this plan, then choose an execution mode: `superpowers:subagent-driven-development` (recommended) or inline `superpowers:executing-plans`.
|