File size: 8,575 Bytes
fae874a d69f171 fae874a ae883d4 42366a8 28ca4f9 d69f171 28ca4f9 ae883d4 c26a55c ae883d4 42366a8 c26a55c 28ca4f9 985240b 5e9f487 3f348a3 d4000ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """Pydantic request / response models for the NeuroBridge FastAPI surface.
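Each pipeline route accepts its own request schema (BBBRequest / EEGRequest /
MRIRequest), but all of them return a unified PipelineResponse, so the dashboard
can render a single result card regardless of modality. The module also holds the
schemas for single-molecule BBB prediction, the /explain/* routes, MRI
harmonization diagnostics, and the MLflow experiments views.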
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
class BBBRequest(BaseModel):
    """Request body for the BBB pipeline route: SMILES CSV in, Parquet out."""

    input_path: str = Field(..., description="CSV path with a 'smiles' column")
    output_path: str = Field(..., description="Parquet output path")
    smiles_col: str = "smiles"
    # Fingerprint parameters. A fingerprint needs at least one bit, and a
    # negative radius is meaningless, so bad values are rejected at request
    # validation instead of failing deep inside the pipeline.
    n_bits: int = Field(2048, ge=1)
    radius: int = Field(2, ge=0)
class EEGRequest(BaseModel):
    """Request body for the EEG pipeline route.

    Field names mirror eeg_pipeline.run_pipeline kwargs exactly.
    """

    input_path: str = Field(..., description="FIF or EDF file")
    output_path: str = Field(..., description="Parquet output path")
    # A zero or negative epoch length cannot yield any epochs.
    epoch_duration_s: float = Field(2.0, gt=0.0)
    eog_ch_name: str | None = None
    # Number of components for the decomposition step (presumably ICA, given
    # the paired random_state — confirm in eeg_pipeline); must be >= 1.
    n_components: int = Field(15, ge=1)
    random_state: int = 97
class MRIRequest(BaseModel):
    """Request body for the MRI pipeline route.

    Points at a directory of NIfTI volumes plus a CSV assigning each subject
    to a site; the result is written to a Parquet file.
    """

    input_dir: str = Field(description="Directory of .nii.gz files")
    sites_csv: str = Field(description="CSV mapping subject_id → site")
    output_path: str = Field(description="Parquet output path")
class PipelineResponse(BaseModel):
    """Shared response shape returned by every pipeline route.

    The BBB, EEG and MRI routes each take their own request model but all
    answer with this one, so the dashboard needs only one result card.
    """

    status: str = Field()
    output_path: str = Field()
    # Presumably the shape of the table written to output_path — confirm
    # against the route handlers.
    rows: int = Field()
    columns: int = Field()
    duration_sec: float = Field()
    mlflow_run_id: str | None = Field(default=None)
class HealthResponse(BaseModel):
    """Health-check payload: an overall status plus a list of pipeline names."""

    status: str = Field()
    pipelines: list[str] = Field()
class BBBPredictRequest(BaseModel):
    """Single-molecule BBB-permeability prediction request."""

    # min_length=1: an empty SMILES string encodes no molecule (some parsers
    # even accept "" as a zero-atom molecule and yield a meaningless all-zero
    # fingerprint), so reject it during request validation.
    smiles: str = Field(
        ...,
        min_length=1,
        description="SMILES string; e.g. 'CCO' for ethanol",
    )
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP features to return")
class FeatureAttribution(BaseModel):
    """One SHAP attribution: a fingerprint bit and its signed contribution."""

    feature: str = Field(description="Fingerprint column name, e.g. 'fp_1234'")
    shap_value: float = Field(
        description=(
            "Signed SHAP value for the predicted class "
            "(positive pushed model toward, negative away)"
        ),
    )
class CalibrationContext(BaseModel):
    """Precision-at-confidence-threshold bin matched to a single prediction.

    This model is also accepted from clients (inside BBBExplainRequest), so
    its documented ranges are enforced: threshold and precision are fractions
    in [0, 1] and support is a non-negative count.
    """

    threshold: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Lowest confidence threshold this bin covers (0.0-1.0)",
    )
    precision: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Precision on the held-out test set among predictions ≥ threshold",
    )
    support: int = Field(
        ...,
        ge=0,
        description="Number of held-out predictions falling in this bin",
    )
class ModelProvenance(BaseModel):
    """Audit trail for the BBB model that produced a prediction."""

    # Pydantic v2 reserves the ``model_`` prefix and emits a UserWarning for
    # fields such as ``model_version``. Our DoD gate promotes that warning to
    # an error, so the protected-namespace check is disabled for this model.
    model_config = ConfigDict(protected_namespaces=())

    mlflow_run_id: str | None = Field(
        default=None,
        description="MLflow run id of the most recent training run, if any",
    )
    model_version: str = Field(
        default="v1",
        description="Manually-bumped model version label",
    )
    train_date: str | None = Field(
        default=None,
        description="ISO 8601 train timestamp from MLflow run start_time",
    )
    n_examples: int | None = Field(
        default=None,
        description='Training set size (from model._neurobridge_train_stats["n_train"])',
    )
class BBBPredictResponse(BaseModel):
    """Decision-system payload: prediction + uncertainty + explanation + drift.

    label / label_text / confidence / top_features / calibration / drift_z
    share names with the fields of BBBExplainRequest. ``confidence`` uses the
    same [0, 1] bounds here, so an out-of-range value is caught at the source.
    """

    label: int = Field(..., description="Predicted label (0 = non-permeable, 1 = permeable)")
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float = Field(..., ge=0.0, le=1.0)
    top_features: list[FeatureAttribution]
    calibration: CalibrationContext | None = Field(
        None,
        description="Statistical context: how often the model is right when this confident on held-out data.",
    )
    drift_z: float | None = Field(
        None,
        description=(
            "Z-score of the trailing-100 confidence median against the "
            "train-time median; None when warming up (<10 samples) or "
            "when the model lacks _neurobridge_train_stats."
        ),
    )
    # A buffer length can never be negative.
    rolling_n: int = Field(
        0,
        ge=0,
        description=(
            "Number of confidence samples currently buffered in the worker's "
            "rolling window (max 100). Zero on a fresh worker."
        ),
    )
    provenance: ModelProvenance | None = Field(
        None,
        description="Auditing metadata (MLflow run id, train date, n_examples).",
    )
class MRIDiagnosticsRequest(BaseModel):
    """Request body for /pipeline/mri/diagnostics.

    Identical to MRIRequest minus output_path.
    """

    input_dir: str = Field(description="Directory of .nii.gz files")
    sites_csv: str = Field(description="CSV mapping subject_id → site")
class HarmonizationRow(BaseModel):
    """One long-format record: a single feature value for one subject at one site."""

    subject_id: str = Field()
    site: str = Field()
    feature: str = Field()
    feature_value: float = Field()
    # Tags the value as coming from before or after ComBat harmonization
    # (presumably "pre"/"post" — confirm against the diagnostics route).
    harmonization_state: str = Field()
class MRIDiagnosticsResponse(BaseModel):
    """Long-format pre/post ComBat data for visualization.

    ``rows`` carries the per-subject values. The three scalars summarize how
    far apart the per-site means are before and after ComBat.
    """

    rows: list[HarmonizationRow] = Field()
    site_gap_pre: float = Field(description="Range of per-site means before ComBat")
    site_gap_post: float = Field(description="Range of per-site means after ComBat")
    reduction_factor: float = Field(description="site_gap_pre / max(site_gap_post, eps)")
class BBBExplainRequest(BaseModel):
    """Day-7 T3B: payload for POST /explain/bbb (chat-style explainer).

    This is client-supplied input, so the documented ranges are enforced:
    label is 0 or 1, confidence is in [0, 1], and smiles / top_features
    must be non-empty.
    """

    smiles: str = Field(..., min_length=1, description="SMILES string of the molecule")
    label: int = Field(
        ...,
        ge=0,
        le=1,
        description="Predicted label (0 = non-permeable, 1 = permeable)",
    )
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float = Field(..., ge=0.0, le=1.0)
    # min_length=1 is checked during request validation, before the route
    # runs. FastAPI answers validation failures with 422 by default, so the
    # description no longer promises a specific status code.
    top_features: list[FeatureAttribution] = Field(
        ...,
        min_length=1,
        description="Non-empty list of SHAP attributions; an empty list fails request validation.",
    )
    calibration: CalibrationContext | None = None
    drift_z: float | None = None
    user_question: str | None = Field(
        None,
        description="Optional question from the user; passed to the LLM prompt only.",
    )
class BBBExplainResponse(BaseModel):
    """Day-7 T3B: body returned by POST /explain/bbb.

    ``source`` says whether the rationale came from the LLM or from the
    template path; ``model`` is filled in only for LLM output.
    """

    rationale: str = Field(description="2-4 sentence natural-language explanation")
    source: str = Field(description="'llm' or 'template'")
    model: str | None = Field(
        default=None,
        description="LLM model name when source='llm'; None when source='template'",
    )
class EEGExplainRequest(BaseModel):
    """Day-8 T1B: payload for POST /explain/eeg.

    Reuses the rows / columns / duration_sec / mlflow_run_id names from
    PipelineResponse and adds an optional user question.
    """

    rows: int = Field(ge=0, description="Number of epochs produced")
    columns: int = Field(ge=0, description="Number of features per epoch")
    duration_sec: float = Field(ge=0.0, description="Pipeline wall-clock seconds")
    mlflow_run_id: str | None = Field(default=None, description="MLflow run id, if available")
    user_question: str | None = Field(
        default=None,
        description="Optional user question for the LLM prompt",
    )
class EEGExplainResponse(BaseModel):
    """Day-8 T1B: body returned by POST /explain/eeg."""

    rationale: str = Field()
    # Presumably 'llm' or 'template', as in BBBExplainResponse — confirm.
    source: str = Field()
    model: str | None = Field(default=None)
class MRIExplainRequest(BaseModel):
    """Day-8 T1B: payload for POST /explain/mri.

    The gap / reduction numbers share names with MRIDiagnosticsResponse. A
    subject count and an optional user question are also accepted.
    """

    site_gap_pre: float = Field(ge=0.0)
    site_gap_post: float = Field(ge=0.0)
    reduction_factor: float = Field(ge=0.0)
    n_subjects: int = Field(ge=0)
    user_question: str | None = Field(default=None)
class MRIExplainResponse(BaseModel):
    """Day-8 T1B: body returned by POST /explain/mri."""

    rationale: str = Field()
    # Presumably 'llm' or 'template', as in BBBExplainResponse — confirm.
    source: str = Field()
    model: str | None = Field(default=None)
class MLflowRunSummary(BaseModel):
    """Summary of one MLflow run, rendered as a row of the Experiments tab table."""

    run_id: str = Field()
    experiment_name: str = Field()
    # Run start time, serialized as an ISO 8601 string.
    start_time: str = Field()
    status: str = Field()
    # Each instance gets its own empty dict when these are omitted.
    metrics: dict[str, float] = Field(default_factory=dict)
    params: dict[str, str] = Field(default_factory=dict)
class MLflowRunsResponse(BaseModel):
    """Body returned by GET /experiments/runs: one summary per MLflow run."""

    runs: list[MLflowRunSummary] = Field()
class RunDiffRequest(BaseModel):
    """Body of POST /experiments/diff: the two MLflow run ids to compare."""

    run_id_a: str = Field()
    run_id_b: str = Field()
class RunDiffRow(BaseModel):
    """A single metric or param key with its value in each of the two runs."""

    key: str = Field()
    # Either "metric" or "param".
    kind: str = Field()
    # Required but nullable: callers must send both keys, though either value
    # may be None (presumably when the key is absent from that run — confirm).
    value_a: str | None = Field()
    value_b: str | None = Field()
    differs: bool = Field()
class RunDiffResponse(BaseModel):
    """Body returned by POST /experiments/diff: metric/param rows side by side."""

    rows: list[RunDiffRow] = Field()
|