# hackathon / src/api/schemas.py
# Last commit 327b23d (mekosotto):
#   feat(researcher): DCE-MRI BBB permeability bridge + drug-dose adjuster
"""Pydantic request / response models for the NeuroBridge FastAPI surface.
Each pipeline accepts its own request schema (BBBRequest / EEGRequest /
MRIRequest) but they all return a unified PipelineResponse — the dashboard
can render a single result card regardless of modality.
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
class BBBRequest(BaseModel):
    """Batch BBB pipeline request: a SMILES CSV goes in, a Parquet file comes out."""

    input_path: str = Field(description="CSV path with a 'smiles' column")
    output_path: str = Field(description="Parquet output path")
    # Column of the input CSV that holds the SMILES strings.
    smiles_col: str = Field(default="smiles")
    # Fingerprint size / radius forwarded to the featuriser
    # (presumably Morgan/ECFP parameters — confirm in the pipeline).
    n_bits: int = Field(default=2048)
    radius: int = Field(default=2)
class EEGRequest(BaseModel):
    """Request body for the EEG pipeline.

    Every field name is kept identical to the corresponding keyword argument
    of ``eeg_pipeline.run_pipeline`` so the body can be forwarded as-is.
    """

    input_path: str = Field(description="FIF or EDF file")
    output_path: str = Field(description="Parquet output path")
    # Length of each epoch, in seconds.
    epoch_duration_s: float = Field(default=2.0)
    # Optional EOG channel name; None when the caller does not specify one.
    eog_ch_name: str | None = Field(default=None)
    # Presumably ICA settings (component count / seed) — confirm in eeg_pipeline.
    n_components: int = Field(default=15)
    random_state: int = Field(default=97)
class MRIRequest(BaseModel):
    """Request body for the MRI pipeline: a NIfTI directory plus a site map."""

    input_dir: str = Field(description="Directory of .nii.gz files")
    sites_csv: str = Field(description="CSV mapping subject_id → site")
    output_path: str = Field(description="Parquet output path")
class PipelineResponse(BaseModel):
    """Single result shape shared by all pipeline routes, whatever the modality."""

    status: str
    output_path: str
    # Dimensions of the produced output table.
    rows: int
    columns: int
    # Wall-clock run time in seconds.
    duration_sec: float
    # Present only when the run was logged to MLflow.
    mlflow_run_id: str | None = Field(default=None)
class HealthResponse(BaseModel):
    """Health-check payload: service status plus the list of pipeline names."""

    status: str
    pipelines: list[str]
class BBBPredictRequest(BaseModel):
    """Single-molecule BBB-permeability prediction request."""

    # min_length=1 rejects an empty SMILES at validation time (422), in line
    # with AgentRunRequest.user_input. Without it "" would reach the
    # featuriser, and RDKit's MolFromSmiles("") returns a zero-atom molecule
    # rather than None, so a `mol is None` guard would not catch it and the
    # model would silently score an all-zero fingerprint.
    # NOTE(review): assumes the featuriser is RDKit-based — confirm.
    smiles: str = Field(
        ...,
        min_length=1,
        description="SMILES string; e.g. 'CCO' for ethanol",
    )
    # Number of SHAP attributions to return, bounded to [1, 20].
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP features to return")
class FeatureAttribution(BaseModel):
    """One SHAP attribution: a fingerprint column and its signed contribution."""

    feature: str = Field(description="Fingerprint column name, e.g. 'fp_1234'")
    shap_value: float = Field(
        description="Signed SHAP value for the predicted class (positive pushed model toward, negative away)"
    )
class CalibrationContext(BaseModel):
    """The precision-at-threshold bin that a single prediction's confidence falls into."""

    threshold: float = Field(
        description="Lowest confidence threshold this bin covers (0.0-1.0)"
    )
    precision: float = Field(
        description="Precision on the held-out test set among predictions ≥ threshold"
    )
    support: int = Field(
        description="Number of held-out predictions falling in this bin"
    )
class ModelProvenance(BaseModel):
    """Audit trail describing which BBB model produced a prediction."""

    # `model_version` starts with the `model_` prefix Pydantic v2 reserves; an
    # empty protected_namespaces tuple silences the resulting UserWarning,
    # which the DoD gate would otherwise turn into an error.
    model_config = ConfigDict(protected_namespaces=())

    mlflow_run_id: str | None = Field(
        default=None,
        description="MLflow run id of the most recent training run, if any",
    )
    model_version: str = Field(
        default="v1",
        description="Manually-bumped model version label",
    )
    train_date: str | None = Field(
        default=None,
        description="ISO 8601 train timestamp from MLflow run start_time",
    )
    n_examples: int | None = Field(
        default=None,
        description='Training set size (from model._neurobridge_train_stats["n_train"])',
    )
class BBBPredictResponse(BaseModel):
    """BBB decision payload: label, confidence, SHAP explanation, calibration,
    drift signal and model provenance."""

    label: int
    label_text: str = Field(description="'permeable' or 'non-permeable'")
    confidence: float
    top_features: list[FeatureAttribution]

    # Optional enrichment blocks; each stays None when unavailable.
    calibration: CalibrationContext | None = Field(
        default=None,
        description="Statistical context: how often the model is right when this confident on held-out data.",
    )
    drift_z: float | None = Field(
        default=None,
        description=(
            "Z-score of the trailing-100 confidence median against the "
            "train-time median; None when warming up (<10 samples) or "
            "when the model lacks _neurobridge_train_stats."
        ),
    )
    rolling_n: int = Field(
        default=0,
        description=(
            "Number of confidence samples currently buffered in the worker's "
            "rolling window (max 100). Zero on a fresh worker."
        ),
    )
    provenance: ModelProvenance | None = Field(
        default=None,
        description="Auditing metadata (MLflow run id, train date, n_examples).",
    )
class BBBPermeabilityMapRequest(BaseModel):
    """Request for a per-patient BBB permeability score derived from MRI."""

    # The expected file type depends on `mode` (see description).
    input_path: str = Field(
        description=(
            "Path to MRI input. heuristic_proxy mode: 2D image (.png/.jpg) "
            "consumed by the resnet18 4-class Alzheimer's classifier. "
            "dce_onnx mode: 4D NIfTI (X,Y,Z,T) for the DCE Ktrans pipeline."
        ),
    )
    # Kept as a free-form str; the values listed in the description are the
    # documented options.
    mode: str = Field(
        default="heuristic_proxy",
        description="'heuristic_proxy' (default, demo-ready) | 'dce_onnx' (real DCE artifact)",
    )
class BBBPermeabilityMapResponse(BaseModel):
    """BBB leakage result for the researcher persona."""

    # Validated to lie in the closed interval [0, 1].
    permeability_score: float = Field(
        ge=0.0,
        le=1.0,
        description="Scalar in [0,1]; 0=intact, 1=fully leaky.",
    )
    interpretation: str = Field(
        description="'BBB intact' | 'mild leakage' | 'moderate leakage' | 'severe leakage'"
    )
    method: str = Field(description="'heuristic_proxy' | 'dce_onnx'")
    voxel_map_available: bool = Field(default=False)
class DrugDoseAdjustmentRequest(BaseModel):
    """Researcher-persona request to revise a dose from the patient's BBB state
    and the drug's BBB profile."""

    smiles: str | None = Field(
        default=None,
        description=(
            "Optional SMILES. When provided, the route auto-resolves "
            "drug_bbb_permeable via the BBB classifier (overrides any "
            "explicit value below)."
        ),
    )
    # Must be strictly positive.
    baseline_dose_mg: float = Field(gt=0.0, description="Standard adult dose in mg.")
    # Patient permeability score, validated to [0, 1].
    bbb_permeability_score: float = Field(ge=0.0, le=1.0)
    drug_bbb_permeable: bool | None = Field(
        default=None,
        description="If known, whether the drug crosses the BBB. Auto-resolved when smiles is given.",
    )
class DrugDoseAdjustmentResponse(BaseModel):
    """Suggested dose plus the reasoning behind it. NOT medical advice."""

    recommended_dose_mg: float
    # Multiplier validated to [0, 1].
    adjustment_factor: float = Field(ge=0.0, le=1.0)
    risk_level: str = Field(description="'low' | 'moderate' | 'high'")
    rationale: str
    drug_bbb_permeable: bool | None = Field(
        default=None,
        description="Echoed back; reflects what was used in the calculation (auto-resolved if smiles was given).",
    )
class EEGPredictRequest(BaseModel):
    """EEG prediction request for one subject, given a precomputed feature vector."""

    # At least one value is required.
    features: list[float] = Field(
        min_length=1,
        description="EEG features matching the classifier's training-time feature count.",
    )
class EEGClassProbability(BaseModel):
    """Probability assigned by the EEG model to a single class."""

    label: int
    label_text: str
    probability: float
class EEGPredictResponse(BaseModel):
    """EEG prediction payload.

    Same fields as MRIPredictResponse minus ``input_path`` and ``model_path``:
    the predicted label, its confidence and the per-class probabilities.
    """

    label: int
    label_text: str
    confidence: float
    probabilities: list[EEGClassProbability]
class MRIPredictRequest(BaseModel):
    """MRI prediction request for one subject image or volume."""

    # The accepted file type depends on the server's MRI_MODEL_KIND setting.
    input_path: str = Field(
        description=(
            "Path to MRI input. With MRI_MODEL_KIND=volumetric_onnx (default), "
            "expects a .nii/.nii.gz volume. With MRI_MODEL_KIND=resnet18_2d, "
            "expects a 2D image (.png/.jpg)."
        ),
    )
    target_shape: tuple[int, int, int] = Field(
        default=(64, 64, 64),
        description="Model preprocessing resize target as (D, H, W)",
    )
    label_names: list[str] | None = Field(
        default=None,
        description="Optional class labels matching ONNX output order",
    )
class MRIClassProbability(BaseModel):
    """Probability assigned by the MRI model to a single class."""

    label: int
    label_text: str
    probability: float
class MRIPredictResponse(BaseModel):
    """MRI DL decision payload.

    Produced by whichever MRI model the server is configured with (the
    volumetric ONNX model by default, or the 2D ResNet18 model when
    MRI_MODEL_KIND=resnet18_2d). Echoes the input and model paths so that
    results can be traced back to their source.
    """

    # `model_path` starts with Pydantic v2's reserved `model_` prefix;
    # disabling protected namespaces suppresses the resulting UserWarning.
    model_config = ConfigDict(protected_namespaces=())

    label: int
    label_text: str
    confidence: float
    probabilities: list[MRIClassProbability]
    input_path: str
    model_path: str
class MRIDiagnosticsRequest(BaseModel):
    """Body for /pipeline/mri/diagnostics: identical to MRIRequest but without output_path."""

    input_dir: str = Field(description="Directory of .nii.gz files")
    sites_csv: str = Field(description="CSV mapping subject_id → site")
class HarmonizationRow(BaseModel):
    """One long-format record: a subject's feature value at one harmonization state."""

    subject_id: str
    site: str
    feature: str
    feature_value: float
    # Presumably a pre/post-ComBat tag — confirm the exact values with the route.
    harmonization_state: str
class MRIDiagnosticsResponse(BaseModel):
    """Pre- and post-ComBat data in long format, plus site-gap summary statistics."""

    rows: list[HarmonizationRow]
    site_gap_pre: float = Field(description="Range of per-site means before ComBat")
    site_gap_post: float = Field(description="Range of per-site means after ComBat")
    reduction_factor: float = Field(
        description="site_gap_pre / max(site_gap_post, eps)"
    )
class BBBExplainRequest(BaseModel):
    """Day-7 T3B: body for POST /explain/bbb, the chat-style BBB explainer.

    Carries a prior prediction back to the server so it can be explained.
    """

    smiles: str = Field(description="SMILES string of the molecule")
    label: int = Field(description="Predicted label (0 = non-permeable, 1 = permeable)")
    label_text: str = Field(description="'permeable' or 'non-permeable'")
    # Validated to [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
    top_features: list[FeatureAttribution] = Field(
        min_length=1,
        description="Non-empty list of SHAP attributions; an empty list returns 400.",
    )
    calibration: CalibrationContext | None = Field(default=None)
    drift_z: float | None = Field(default=None)
    user_question: str | None = Field(
        default=None,
        description="Optional question from the user; passed to the LLM prompt only.",
    )
class BBBExplainResponse(BaseModel):
    """Day-7 T3B: explanation returned by POST /explain/bbb."""

    rationale: str = Field(description="2-4 sentence natural-language explanation")
    source: str = Field(description="'llm' or 'template'")
    model: str | None = Field(
        default=None,
        description="LLM model name when source='llm'; None when source='template'",
    )
class EEGExplainRequest(BaseModel):
    """Day-8 T1B: body for POST /explain/eeg (summary of an EEG pipeline run)."""

    # Counts and duration are validated to be non-negative.
    rows: int = Field(ge=0, description="Number of epochs produced")
    columns: int = Field(ge=0, description="Number of features per epoch")
    duration_sec: float = Field(ge=0.0, description="Pipeline wall-clock seconds")
    mlflow_run_id: str | None = Field(
        default=None, description="MLflow run id, if available"
    )
    user_question: str | None = Field(
        default=None, description="Optional user question for the LLM prompt"
    )
class EEGExplainResponse(BaseModel):
    """Day-8 T1B: explanation returned by POST /explain/eeg."""

    rationale: str
    source: str
    # None when no LLM model name applies.
    model: str | None = Field(default=None)
class MRIExplainRequest(BaseModel):
    """Day-8 T1B: body for POST /explain/mri (ComBat harmonization summary)."""

    # All numeric inputs are validated to be non-negative.
    site_gap_pre: float = Field(ge=0.0)
    site_gap_post: float = Field(ge=0.0)
    reduction_factor: float = Field(ge=0.0)
    n_subjects: int = Field(ge=0)
    user_question: str | None = Field(default=None)
class MRIExplainResponse(BaseModel):
    """Day-8 T1B: explanation returned by POST /explain/mri."""

    rationale: str
    source: str
    # None when no LLM model name applies.
    model: str | None = Field(default=None)
class MLflowRunSummary(BaseModel):
    """Summary of a single MLflow run, shown as one row of the Experiments table."""

    run_id: str
    experiment_name: str
    start_time: str  # ISO 8601 timestamp string
    status: str
    # Independent empty dicts per instance when not supplied.
    metrics: dict[str, float] = Field(default_factory=dict)
    params: dict[str, str] = Field(default_factory=dict)
class MLflowRunsResponse(BaseModel):
    """Payload of GET /experiments/runs: the list of run summaries."""

    runs: list[MLflowRunSummary]
class RunDiffRequest(BaseModel):
    """Body for POST /experiments/diff: the two MLflow run ids to compare."""

    run_id_a: str
    run_id_b: str
class RunDiffRow(BaseModel):
    """One key of a run-vs-run comparison, with each run's value side by side."""

    key: str
    kind: str  # either "metric" or "param"
    # Required but nullable; None presumably means the key is absent from that run.
    value_a: str | None
    value_b: str | None
    differs: bool
class RunDiffResponse(BaseModel):
    """Payload of POST /experiments/diff: metric and param rows compared side by side."""

    rows: list[RunDiffRow]
# --- Agent surface (orchestrator + RAG) ------------------------------------
class AgentRunRequest(BaseModel):
    """Input the user sends to the orchestrator agent."""

    # Must contain at least one character.
    user_input: str = Field(
        min_length=1, description="SMILES, file path, or directory path"
    )
    user_question: str | None = Field(
        default=None,
        description="Optional natural-language question to language-match the response",
    )
    sites_csv: str | None = Field(
        default=None,
        description="Optional MRI sites CSV. Defaults to <user_input>/sites.csv for directory inputs.",
    )
class AgentToolTraceItem(BaseModel):
    """Record of one tool call inside an agent run trace."""

    name: str
    # A fresh empty dict per instance when not supplied.
    args: dict = Field(default_factory=dict)
    result: dict | None = Field(default=None)
    # Presumably populated when the tool call failed — confirm with the orchestrator.
    error: str | None = Field(default=None)
class AgentRunResponse(BaseModel):
    """Orchestrator reply: the final text plus the trace of tool calls made."""

    text: str
    trace: list[AgentToolTraceItem] = Field(default_factory=list)
    model: str | None = Field(default=None)
    finish_reason: str = Field(default="complete")
# --- Fusion engine surface --------------------------------------------------
# Re-export the fusion types so the API surface lives in one file but the
# implementation stays in src/fusion. This keeps `from src.api.schemas import *`
# style imports stable for the frontend layer.
from src.fusion.types import ( # noqa: E402,F401
ClinicalScores as FusionClinicalScores,
FusionInput as FusionRequest,
FusionOutput as FusionResponse,
ModalityPrediction as FusionModalityPrediction,
)