"""Pydantic input/output schemas for orchestrator tools and the agent result.

These schemas double as OpenAI function-calling parameter definitions
(via `model_json_schema()`) and as runtime validation gates. Keep field
names lowercase + snake_case so prompts and JSON outputs align.
"""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
# --- Pipeline tool inputs ---------------------------------------------------
class BBBPipelineInput(BaseModel):
    """Input for `run_bbb_pipeline` — a single SMILES string."""

    # Required. min_length=1 makes the validation gate reject "" instead of
    # forwarding an empty molecule string to the pipeline.
    smiles: str = Field(
        ...,
        min_length=1,
        description="A single molecular SMILES string, e.g. 'CCO'",
    )
    # How many SHAP attributions to return; inclusive range 1..20.
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP attributions to return")
class EEGPipelineInput(BaseModel):
    """Input for `run_eeg_pipeline` — path to an EEG file (.fif or .edf)."""

    # Required; min_length=1 rejects an empty path at the validation gate.
    input_path: str = Field(
        ...,
        min_length=1,
        description="Path to EEG recording file (.fif or .edf)",
    )
    # This model is exported as function-calling parameters, so the bounded
    # field needs a description like its siblings. Valid range is
    # 0.1 < value <= 60.0 (the lower bound is exclusive: gt, not ge).
    epoch_duration_s: float = Field(
        2.0,
        gt=0.1,
        le=60.0,
        description="Epoch length in seconds (must be > 0.1 and <= 60)",
    )
class MRIPipelineInput(BaseModel):
    """Input for `run_mri_pipeline` — directory of NIfTI files + sites CSV."""

    # Required; min_length=1 rejects an empty directory string, which would
    # also make the documented "<input_dir>/sites.csv" default meaningless.
    input_dir: str = Field(
        ...,
        min_length=1,
        description="Directory containing .nii.gz volumes",
    )
    # Optional. The fallback to <input_dir>/sites.csv is only described here;
    # it must be applied by the tool, since this model leaves the value None.
    sites_csv: str | None = Field(
        None,
        description="CSV mapping subject_id → site; defaults to <input_dir>/sites.csv",
    )
class BBBPermeabilityMapInput(BaseModel):
    """Input for `compute_bbb_leakage_score` — MRI input + scoring mode."""

    # NOTE(review): the docstring names `compute_bbb_leakage_score` while the
    # class is named BBBPermeabilityMap*; confirm the registered tool name.

    # Required; min_length=1 rejects an empty path at the validation gate.
    # The expected file kind depends on `mode` (see the description).
    input_path: str = Field(
        ...,
        min_length=1,
        description="Path to MRI input (2D image for heuristic_proxy; 4D NIfTI for dce_onnx).",
    )
    # Closed set of scoring modes; any other value fails validation.
    mode: Literal["heuristic_proxy", "dce_onnx"] = Field(
        "heuristic_proxy",
        description="'heuristic_proxy' (default) | 'dce_onnx' (real DCE artifact)",
    )
class BBBPermeabilityMapOutput(BaseModel):
    # Result model paired with BBBPermeabilityMapInput. Intentionally has no
    # docstring: pydantic would publish one as the schema "description".

    # Scalar score reported by the tool (range not constrained here).
    permeability_score: float
    # Human-readable reading of the score.
    interpretation: str
    # Which scoring method produced the result.
    method: str
    # Flag reported by the tool for whether a voxel-level map exists.
    voxel_map_available: bool
class DrugDoseAdjustmentInput(BaseModel):
    """Input for `adjust_drug_dose` — baseline + patient + drug profile."""

    # Every field is exposed to the LLM through model_json_schema(), so each
    # one carries a description (previously only `smiles` had one).
    baseline_dose_mg: float = Field(
        ...,
        gt=0.0,
        description="Baseline dose in milligrams; must be > 0.",
    )
    bbb_permeability_score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Patient BBB permeability score in [0, 1].",
    )
    # None means "not specified"; per the `smiles` description, the tool can
    # resolve this flag from the SMILES string when one is supplied.
    drug_bbb_permeable: bool | None = Field(
        None,
        description=(
            "Whether the drug is BBB-permeable; leave null to auto-resolve "
            "from `smiles` when given."
        ),
    )
    smiles: str | None = Field(
        None, description="Optional SMILES; auto-resolves drug_bbb_permeable when given.",
    )
class DrugDoseAdjustmentOutput(BaseModel):
    # Result model paired with DrugDoseAdjustmentInput. No docstring on
    # purpose: pydantic would export it as the schema "description".

    # Dose the tool recommends, in milligrams.
    recommended_dose_mg: float
    # Multiplier the tool reports as applied to the baseline dose.
    adjustment_factor: float
    # Free-form risk label; not restricted to a fixed set here.
    risk_level: str
    rationale: str
    # Echo of the permeability flag; None when it was not determined.
    drug_bbb_permeable: bool | None = None
class RetrieveContextInput(BaseModel):
    """Input for `retrieve_context` — natural-language query into the KB."""

    # Required (no default given); at least 2 characters.
    query: str = Field(
        min_length=2,
        description="Search query for the knowledge base",
    )
    # Chunk count, inclusive range 1..10.
    k: int = Field(
        default=4,
        ge=1,
        le=10,
        description="Number of chunks to return",
    )
    # Selects the backing index; only these two literals validate.
    corpus: Literal["reference", "clinical"] = Field(
        default="reference",
        description=(
            "Which corpus to query. "
            "'reference' = curated FAISS index (default). "
            "'clinical' = TF-IDF index over peer-reviewed "
            "Alzheimer's/Parkinson's papers with Turkish+English query expansion."
        ),
    )
# --- Pipeline tool outputs --------------------------------------------------
class BBBPipelineOutput(BaseModel):
    # Result of `run_bbb_pipeline`. No docstring on purpose: pydantic would
    # export it as the schema "description".

    # The SMILES string that was scored.
    smiles: str
    # Predicted class as an integer, plus its text label.
    label: int
    label_text: str
    confidence: float
    # Attribution records; the per-item dict schema is defined by the tool.
    top_features: list[dict[str, Any]]
    # Optional drift statistic; None when not reported.
    drift_z: float | None = None
class EEGPipelineOutput(BaseModel):
    # Result of `run_eeg_pipeline`; same shape as MRIPipelineOutput except the
    # source is a single file path rather than a directory.

    # Source recording and the path of the produced output.
    input_path: str
    output_path: str
    # Dimensions of the produced table.
    rows: int
    columns: int
    # Elapsed time reported by the tool, in seconds.
    duration_sec: float
class MRIPipelineOutput(BaseModel):
    # Result of `run_mri_pipeline`; same shape as EEGPipelineOutput except the
    # source is a directory rather than a single file.

    # Source directory and the path of the produced output.
    input_dir: str
    output_path: str
    # Dimensions of the produced table.
    rows: int
    columns: int
    # Elapsed time reported by the tool, in seconds.
    duration_sec: float
class RetrieveContextOutput(BaseModel):
    # Result of `retrieve_context`: the query as received plus the hits.

    query: str
    # Retrieved chunks; per-chunk dict keys are defined by the retriever.
    chunks: list[dict[str, Any]]
# --- Agent result -----------------------------------------------------------
class ToolTraceItem(BaseModel):
    """One step in the orchestrator's tool-call trace."""

    # Tool that was invoked and the arguments it was called with.
    name: str
    args: dict[str, Any]
    # Outcome slots, both optional. Presumably one of them is filled per
    # step, but this model does not enforce that — confirm with the caller.
    result: dict[str, Any] | None = None
    error: str | None = None
class AgentResult(BaseModel):
    """Final orchestrator response: synthesized text + full trace."""

    # Synthesized answer text.
    text: str
    # Ordered tool-call steps; default_factory gives each instance its own list.
    trace: list[ToolTraceItem] = Field(default_factory=list)
    # Model identifier, when known.
    model: str | None = None
    # Documented values: complete | max_steps | error. Typed as plain str, so
    # other values are not rejected.
    finish_reason: str = "complete"
|