File size: 4,329 Bytes
460fcc2
 
 
 
 
 
 
 
91dde0d
460fcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0a7163
 
 
 
460fcc2
 
327b23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460fcc2
 
 
 
91dde0d
 
 
 
 
 
 
 
460fcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Pydantic input/output schemas for orchestrator tools and the agent result.

These schemas double as OpenAI function-calling parameter definitions
(via `model_json_schema()`) and as runtime validation gates. Keep field
names lowercase + snake_case so prompts and JSON outputs align.
"""
from __future__ import annotations

from typing import Any, Literal

from pydantic import BaseModel, Field


# --- Pipeline tool inputs ---------------------------------------------------

class BBBPipelineInput(BaseModel):
    """Input for `run_bbb_pipeline` — a single SMILES string."""
    # Raw SMILES text for one molecule; no canonicalization is enforced here.
    smiles: str = Field(..., description="A single molecular SMILES string, e.g. 'CCO'")
    # Bounded 1-20 so the SHAP attribution list stays small enough to embed in a prompt.
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP attributions to return")


class EEGPipelineInput(BaseModel):
    """Input for `run_eeg_pipeline` — path to an EEG file (.fif or .edf)."""
    # Path is passed through as-is; existence/format checks happen in the pipeline.
    input_path: str = Field(..., description="Path to EEG recording file (.fif or .edf)")
    # Fix: this was the only constrained field in the module without a
    # `description`, so the generated OpenAI function schema left the model
    # guessing at its meaning. Constraints (0.1 s < d <= 60 s) are unchanged.
    epoch_duration_s: float = Field(
        2.0,
        gt=0.1,
        le=60.0,
        description="Epoch length in seconds used to segment the recording (0.1 < d <= 60; default 2.0)",
    )


class MRIPipelineInput(BaseModel):
    """Input for `run_mri_pipeline` — directory of NIfTI files + sites CSV."""
    # Directory the pipeline scans for .nii.gz volumes.
    input_dir: str = Field(..., description="Directory containing .nii.gz volumes")
    # None means the pipeline falls back to <input_dir>/sites.csv (per the description).
    sites_csv: str | None = Field(
        None,
        description="CSV mapping subject_id → site; defaults to <input_dir>/sites.csv",
    )


class BBBPermeabilityMapInput(BaseModel):
    """Input for `compute_bbb_leakage_score` — MRI input + scoring mode."""
    # Expected file format depends on `mode`: 2D image vs 4D NIfTI (see description).
    input_path: str = Field(..., description="Path to MRI input (2D image for heuristic_proxy; 4D NIfTI for dce_onnx).")
    # Literal restricts the JSON schema (and runtime validation) to the two supported back-ends.
    mode: Literal["heuristic_proxy", "dce_onnx"] = Field(
        "heuristic_proxy",
        description="'heuristic_proxy' (default) | 'dce_onnx' (real DCE artifact)",
    )


class BBBPermeabilityMapOutput(BaseModel):
    """Output of `compute_bbb_leakage_score` (pairs with BBBPermeabilityMapInput)."""
    # Presumably normalized like DrugDoseAdjustmentInput.bbb_permeability_score
    # ([0, 1]) — no constraint here, TODO confirm against the tool implementation.
    permeability_score: float
    # Human-readable reading of the score.
    interpretation: str
    # Scoring method used — presumably echoes the input `mode`; verify.
    method: str
    # Whether a per-voxel permeability map artifact was produced.
    voxel_map_available: bool


class DrugDoseAdjustmentInput(BaseModel):
    """Input for `adjust_drug_dose` — baseline + patient + drug profile."""
    # Baseline dose in milligrams; strictly positive.
    baseline_dose_mg: float = Field(..., gt=0.0)
    # Patient-level permeability score, validated into [0, 1].
    bbb_permeability_score: float = Field(..., ge=0.0, le=1.0)
    # Tri-state: True/False when known; None lets the tool resolve it (see `smiles`).
    drug_bbb_permeable: bool | None = None
    smiles: str | None = Field(
        None, description="Optional SMILES; auto-resolves drug_bbb_permeable when given.",
    )


class DrugDoseAdjustmentOutput(BaseModel):
    """Output of `adjust_drug_dose` (pairs with DrugDoseAdjustmentInput)."""
    # Adjusted dose in milligrams.
    recommended_dose_mg: float
    # Multiplier applied to the baseline — presumably
    # recommended_dose_mg = baseline_dose_mg * adjustment_factor; TODO confirm.
    adjustment_factor: float
    # Free-form risk category; the value set is defined by the tool, not validated here.
    risk_level: str
    # Human-readable justification for the recommendation.
    rationale: str
    # Echo of the (possibly auto-resolved) permeability flag; None when unknown.
    drug_bbb_permeable: bool | None = None


class RetrieveContextInput(BaseModel):
    """Input for `retrieve_context` — natural-language query into the KB."""
    # min_length=2 rejects empty/one-character queries at the validation gate.
    query: str = Field(..., min_length=2, description="Search query for the knowledge base")
    # Small cap (<=10) keeps the retrieved context within prompt budget.
    k: int = Field(4, ge=1, le=10, description="Number of chunks to return")
    # Two-corpus routing; the description documents each backend for the LLM.
    corpus: Literal["reference", "clinical"] = Field(
        "reference",
        description=(
            "Which corpus to query. 'reference' = curated FAISS index (default). "
            "'clinical' = TF-IDF index over peer-reviewed Alzheimer's/Parkinson's "
            "papers with Turkish+English query expansion."
        ),
    )


# --- Pipeline tool outputs --------------------------------------------------

class BBBPipelineOutput(BaseModel):
    """Output of `run_bbb_pipeline` (pairs with BBBPipelineInput)."""
    # Echo of the input SMILES.
    smiles: str
    # Predicted class index; `label_text` is its human-readable form.
    label: int
    label_text: str
    # Confidence for the predicted label — presumably a probability in [0, 1]; confirm.
    confidence: float
    # One dict per SHAP attribution, at most `top_k` entries (see BBBPipelineInput).
    top_features: list[dict[str, Any]]
    # Drift z-score; None presumably means drift monitoring was unavailable — verify.
    drift_z: float | None = None


class EEGPipelineOutput(BaseModel):
    """Output of `run_eeg_pipeline`: artifact location, shape, and runtime."""
    # Echo of the input recording path.
    input_path: str
    # Path of the artifact written by the pipeline.
    output_path: str
    # rows/columns presumably describe the written table's shape — TODO confirm.
    rows: int
    columns: int
    # Presumably wall-clock pipeline runtime in seconds; verify against the pipeline.
    duration_sec: float


class MRIPipelineOutput(BaseModel):
    """Output of `run_mri_pipeline`: artifact location, shape, and runtime."""
    # Echo of the input directory.
    input_dir: str
    # Path of the artifact written by the pipeline.
    output_path: str
    # rows/columns presumably describe the written table's shape — TODO confirm.
    rows: int
    columns: int
    # Presumably wall-clock pipeline runtime in seconds; verify against the pipeline.
    duration_sec: float


class RetrieveContextOutput(BaseModel):
    """Output of `retrieve_context`: the query plus the retrieved chunks."""
    # Echo of the search query.
    query: str
    # Retrieved chunks; dict schema is defined by the retriever (not validated here).
    chunks: list[dict[str, Any]]


# --- Agent result -----------------------------------------------------------

class ToolTraceItem(BaseModel):
    """One step in the orchestrator's tool-call trace."""
    # Tool name as exposed to the model.
    name: str
    # Arguments the model supplied for the call.
    args: dict[str, Any]
    # Tool output on success; None when the call produced no result.
    result: dict[str, Any] | None = None
    # Error message on failure; presumably mutually exclusive with `result` — confirm.
    error: str | None = None


class AgentResult(BaseModel):
    """Final orchestrator response: synthesized text + full trace."""
    # Synthesized natural-language answer returned to the caller.
    text: str
    # Ordered tool-call trace; empty when the model answered without tools.
    # default_factory avoids a shared mutable default across instances.
    trace: list[ToolTraceItem] = Field(default_factory=list)
    # Name of the model that produced the answer, when known.
    model: str | None = None
    # Fix: the closed value set lived only in a trailing comment
    # ("complete | max_steps | error"); since this module's schemas are
    # explicitly runtime validation gates, encode it as a Literal so invalid
    # values are rejected. Default "complete" is unchanged.
    finish_reason: Literal["complete", "max_steps", "error"] = "complete"