File size: 6,533 Bytes
fae874a d69f171 fae874a ae883d4 42366a8 28ca4f9 d69f171 28ca4f9 ae883d4 c26a55c ae883d4 42366a8 c26a55c 28ca4f9 985240b 5e9f487 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """Pydantic request / response models for the NeuroBridge FastAPI surface.
Each pipeline accepts its own request schema (BBBRequest / EEGRequest /
MRIRequest) but they all return a unified PipelineResponse — the dashboard
can render a single result card regardless of modality.
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
class BBBRequest(BaseModel):
    """Request body for the batch BBB pipeline route: SMILES CSV in, Parquet out.

    ``n_bits`` and ``radius`` are bounded so nonsensical values are rejected
    at request validation instead of surfacing later inside the pipeline.
    """
    input_path: str = Field(..., description="CSV path with a 'smiles' column")
    output_path: str = Field(..., description="Parquet output path")
    # Name of the CSV column that holds the SMILES strings.
    smiles_col: str = "smiles"
    # Fingerprint width; a zero/negative width cannot hold any bits.
    # NOTE(review): presumably forwarded to a Morgan-style fingerprinter — confirm.
    n_bits: int = Field(2048, ge=1)
    # Fingerprint radius; 0 is left valid, only negative radii are rejected.
    radius: int = Field(2, ge=0)
class EEGRequest(BaseModel):
    """Field names mirror eeg_pipeline.run_pipeline kwargs exactly.

    Numeric knobs are bounded here so invalid values fail request validation
    rather than deep inside the pipeline run.
    """
    input_path: str = Field(..., description="FIF or EDF file")
    output_path: str = Field(..., description="Parquet output path")
    # Epoch length in seconds; a zero or negative duration is meaningless.
    epoch_duration_s: float = Field(2.0, gt=0)
    # Optional EOG channel name; None when the caller does not supply one.
    eog_ch_name: str | None = None
    # Number of components to fit; at least one is required.
    n_components: int = Field(15, ge=1)
    # NOTE(review): presumably the seed for the pipeline's randomized steps — confirm.
    random_state: int = 97
class MRIRequest(BaseModel):
    """Request body for the MRI pipeline route: NIfTI directory + site map in, Parquet out."""
    # Directory expected to contain the .nii.gz inputs.
    input_dir: str = Field(description="Directory of .nii.gz files")
    # Table assigning every subject_id to its site.
    sites_csv: str = Field(description="CSV mapping subject_id → site")
    # Where the resulting table is written.
    output_path: str = Field(description="Parquet output path")
class PipelineResponse(BaseModel):
    """Result summary shared by every pipeline route, regardless of modality."""
    status: str
    output_path: str
    # Presumably the shape of the table written to output_path.
    rows: int
    columns: int
    # Run duration in seconds.
    duration_sec: float
    # Left as None when no MLflow run id is supplied.
    mlflow_run_id: str | None = Field(default=None)
class HealthResponse(BaseModel):
    """Health-check payload: a status string plus a list of pipeline names."""
    # Both fields are required.
    status: str = Field()
    pipelines: list[str] = Field()
class BBBPredictRequest(BaseModel):
    """Single-molecule BBB-permeability prediction request."""
    # min_length=1 rejects "" up front: an empty string names no molecule, and
    # some SMILES parsers accept it as a zero-atom molecule instead of failing,
    # which would yield a meaningless prediction.
    # NOTE(review): confirm against the parser used by the predict route.
    smiles: str = Field(
        ...,
        min_length=1,
        description="SMILES string; e.g. 'CCO' for ethanol",
    )
    # Number of SHAP attributions to return, bounded to 1..20.
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP features to return")
class FeatureAttribution(BaseModel):
    """One SHAP attribution: a fingerprint column and its signed contribution."""
    feature: str = Field(description="Fingerprint column name, e.g. 'fp_1234'")
    # Sign carries direction: positive toward the predicted class, negative away.
    shap_value: float = Field(
        description="Signed SHAP value for the predicted class (positive pushed model toward, negative away)"
    )
class CalibrationContext(BaseModel):
    """Precision-at-confidence-threshold bin matched to a single prediction.

    Bounds enforce the documented ranges. This model is also accepted from
    clients inside BBBExplainRequest, so without them out-of-range values
    would pass straight through.
    """
    # Documented as 0.0-1.0.
    threshold: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Lowest confidence threshold this bin covers (0.0-1.0)",
    )
    # Precision is a ratio, so it lies in [0, 1].
    # NOTE(review): a NaN precision (e.g. from an empty bin) now fails
    # validation — confirm the producer never emits one.
    precision: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Precision on the held-out test set among predictions ≥ threshold",
    )
    # A count of predictions, so never negative.
    support: int = Field(..., ge=0, description="Number of held-out predictions falling in this bin")
class ModelProvenance(BaseModel):
    """Audit trail describing the BBB model behind a prediction."""
    # Pydantic v2 reserves the `model_` prefix and emits a UserWarning for
    # fields like `model_version`; the DoD gate escalates that warning to an
    # error, so the protected-namespace check is switched off.
    model_config = ConfigDict(protected_namespaces=tuple())
    mlflow_run_id: str | None = Field(
        default=None,
        description="MLflow run id of the most recent training run, if any",
    )
    model_version: str = Field(
        default="v1",
        description="Manually-bumped model version label",
    )
    train_date: str | None = Field(
        default=None,
        description="ISO 8601 train timestamp from MLflow run start_time",
    )
    n_examples: int | None = Field(
        default=None,
        description='Training set size (from model._neurobridge_train_stats["n_train"])',
    )
class BBBPredictResponse(BaseModel):
    """Full decision payload for one molecule.

    Combines the prediction and its confidence with the SHAP explanation, an
    optional calibration bin, a drift signal, and model provenance.
    """
    # Core prediction.
    label: int
    label_text: str = Field(description="'permeable' or 'non-permeable'")
    confidence: float
    # SHAP attributions backing the prediction.
    top_features: list[FeatureAttribution]
    # Optional context: calibration/drift_z/provenance default to None,
    # rolling_n defaults to 0.
    calibration: CalibrationContext | None = Field(
        default=None,
        description="Statistical context: how often the model is right when this confident on held-out data.",
    )
    drift_z: float | None = Field(
        default=None,
        description=(
            "Z-score of the trailing-100 confidence median against the "
            "train-time median; None when warming up (<10 samples) or "
            "when the model lacks _neurobridge_train_stats."
        ),
    )
    rolling_n: int = Field(
        default=0,
        description=(
            "Number of confidence samples currently buffered in the worker's "
            "rolling window (max 100). Zero on a fresh worker."
        ),
    )
    provenance: ModelProvenance | None = Field(
        default=None,
        description="Auditing metadata (MLflow run id, train date, n_examples).",
    )
class MRIDiagnosticsRequest(BaseModel):
    """Body for /pipeline/mri/diagnostics: the MRIRequest inputs, without output_path."""
    # Both fields are required and mirror their MRIRequest counterparts.
    input_dir: str = Field(description="Directory of .nii.gz files")
    sites_csv: str = Field(description="CSV mapping subject_id → site")
class HarmonizationRow(BaseModel):
    """One long-format record: a subject's value for one feature in one harmonization state."""
    # Who / where the measurement comes from.
    subject_id: str
    site: str
    # What was measured and its value.
    feature: str
    feature_value: float
    # NOTE(review): presumably a pre-/post-ComBat label; exact values are set by the producer.
    harmonization_state: str
class MRIDiagnosticsResponse(BaseModel):
    """Pre- and post-ComBat values in long format, plus site-gap summary statistics, for plotting."""
    rows: list[HarmonizationRow]
    # Site-gap summaries before and after harmonization, and their ratio.
    site_gap_pre: float = Field(description="Range of per-site means before ComBat")
    site_gap_post: float = Field(description="Range of per-site means after ComBat")
    reduction_factor: float = Field(description="site_gap_pre / max(site_gap_post, eps)")
class BBBExplainRequest(BaseModel):
    """Day-7 T3B: payload for POST /explain/bbb (chat-style explainer).

    Carries a prediction (typically echoed back from the predict route) plus
    an optional user question. ``label`` is bounded to its documented 0/1
    values and ``smiles`` must be non-empty, matching BBBPredictRequest.
    """
    smiles: str = Field(..., min_length=1, description="SMILES string of the molecule")
    # Only the two documented class labels are accepted.
    label: int = Field(
        ...,
        ge=0,
        le=1,
        description="Predicted label (0 = non-permeable, 1 = permeable)",
    )
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float = Field(..., ge=0.0, le=1.0)
    # NOTE(review): FastAPI's default handler reports body-validation failures
    # (including this min_length) as 422, not 400 — confirm an app-level
    # exception handler maps it to 400 as the description states.
    top_features: list[FeatureAttribution] = Field(
        ..., min_length=1,
        description="Non-empty list of SHAP attributions; an empty list returns 400.",
    )
    calibration: CalibrationContext | None = None
    drift_z: float | None = None
    # Free-form client text that is forwarded into an LLM prompt.
    # NOTE(review): unbounded length; consider a max_length if prompt size matters.
    user_question: str | None = Field(
        None,
        description="Optional question from the user; passed to the LLM prompt only.",
    )
class BBBExplainResponse(BaseModel):
    """Day-7 T3B: what POST /explain/bbb returns — a short rationale and where it came from."""
    rationale: str = Field(description="2-4 sentence natural-language explanation")
    # Origin of the rationale; `model` is only populated for the 'llm' origin.
    source: str = Field(description="'llm' or 'template'")
    model: str | None = Field(
        default=None,
        description="LLM model name when source='llm'; None when source='template'",
    )
|