mekosotto Claude Sonnet 4.6 committed on
Commit
460fcc2
·
1 Parent(s): 978f645

feat(agents): Tool dataclass + registry + 4 tool wrappers (3 pipelines + RAG)

Browse files
src/agents/__init__.py ADDED
File without changes
src/agents/schemas.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic input/output schemas for orchestrator tools and the agent result.
2
+
3
+ These schemas double as OpenAI function-calling parameter definitions
4
+ (via `model_json_schema()`) and as runtime validation gates. Keep field
5
+ names lowercase + snake_case so prompts and JSON outputs align.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
+ # --- Pipeline tool inputs ---------------------------------------------------
15
+
16
class BBBPipelineInput(BaseModel):
    """Input for `run_bbb_pipeline` — a single SMILES string."""
    # The molecule to score; chemical validity is checked downstream, not here.
    smiles: str = Field(..., description="A single molecular SMILES string, e.g. 'CCO'")
    # Bounded to [1, 20] so the SHAP attribution list stays prompt-sized.
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP attributions to return")
20
+
21
+
22
class EEGPipelineInput(BaseModel):
    """Input for `run_eeg_pipeline` — path to an EEG file (.fif or .edf)."""
    # Existence/readability of the path is validated by the pipeline itself.
    input_path: str = Field(..., description="Path to EEG recording file (.fif or .edf)")
    # Epoch window length in seconds, constrained to (0.1, 60].
    epoch_duration_s: float = Field(2.0, gt=0.1, le=60.0)
26
+
27
+
28
class MRIPipelineInput(BaseModel):
    """Input for `run_mri_pipeline` — directory of NIfTI files + sites CSV."""
    # Directory scanned for volumes by the pipeline; not validated here.
    input_dir: str = Field(..., description="Directory containing .nii.gz volumes")
    # Site labels drive ComBat harmonization — presumably one row per subject;
    # confirm the expected CSV schema against the pipeline implementation.
    sites_csv: str = Field(..., description="CSV mapping subject_id → site")
32
+
33
+
34
class RetrieveContextInput(BaseModel):
    """Input for `retrieve_context` — natural-language query into the KB."""
    # min_length=2 rejects empty/single-character queries at the validation gate.
    query: str = Field(..., min_length=2, description="Search query for the knowledge base")
    # Bounded to [1, 10] so retrieved context stays within prompt budget.
    k: int = Field(4, ge=1, le=10, description="Number of chunks to return")
38
+
39
+
40
+ # --- Pipeline tool outputs --------------------------------------------------
41
+
42
class BBBPipelineOutput(BaseModel):
    """Output of `run_bbb_pipeline` for one molecule."""
    # Echo of the input SMILES so the result is self-describing.
    smiles: str
    # Predicted class index; label_text is its human-readable form.
    label: int
    label_text: str
    # Model confidence — presumably a probability in [0, 1]; confirm upstream.
    confidence: float
    # Per-feature SHAP attributions, already dumped to plain dicts.
    top_features: list[dict[str, Any]]
    # Drift z-score when available; None when drift monitoring didn't run.
    drift_z: float | None = None
49
+
50
+
51
class EEGPipelineOutput(BaseModel):
    """Output of `run_eeg_pipeline` — where features landed + run stats."""
    # Echo of the input recording path.
    input_path: str
    # Path of the written feature table (parquet, per the tool wrapper).
    output_path: str
    # Shape of the feature table.
    rows: int
    columns: int
    # Wall-clock duration of the pipeline run, in seconds.
    duration_sec: float
57
+
58
+
59
class MRIPipelineOutput(BaseModel):
    """Output of `run_mri_pipeline` — where features landed + run stats."""
    # Echo of the input volume directory.
    input_dir: str
    # Path of the written feature table (parquet, per the tool wrapper).
    output_path: str
    # Shape of the feature table.
    rows: int
    columns: int
    # Wall-clock duration of the pipeline run, in seconds.
    duration_sec: float
65
+
66
+
67
class RetrieveContextOutput(BaseModel):
    """Output of `retrieve_context` — the query plus retrieved chunks."""
    # Echo of the search query.
    query: str
    # Retrieved passages as plain dicts; empty when no index is available.
    chunks: list[dict[str, Any]]
70
+
71
+
72
+ # --- Agent result -----------------------------------------------------------
73
+
74
class ToolTraceItem(BaseModel):
    """One step in the orchestrator's tool-call trace."""
    # Registered tool name that was invoked.
    name: str
    # Arguments the model supplied for the call.
    args: dict[str, Any]
    # Tool output dict — presumably set only on success; confirm in orchestrator.
    result: dict[str, Any] | None = None
    # Error message — presumably set only on failure (mutually exclusive
    # with `result`); confirm in orchestrator.
    error: str | None = None
80
+
81
+
82
class AgentResult(BaseModel):
    """Final orchestrator response: synthesized text + full trace."""
    # The synthesized answer presented to the user.
    text: str
    # Ordered tool calls made while producing `text`; may be empty.
    trace: list[ToolTraceItem] = Field(default_factory=list)
    # Identifier of the LLM that produced the answer, when known.
    model: str | None = None
    finish_reason: str = "complete"  # complete | max_steps | error
src/agents/tools.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
2
+ function-callable tool the orchestrator can invoke.
3
+
4
+ Public entry: `build_default_tools(rag_index_dir)` returns the 4 tools.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Callable
11
+
12
+ from pydantic import BaseModel, ValidationError
13
+
14
+ from src.agents.schemas import (
15
+ BBBPipelineInput,
16
+ BBBPipelineOutput,
17
+ EEGPipelineInput,
18
+ EEGPipelineOutput,
19
+ MRIPipelineInput,
20
+ MRIPipelineOutput,
21
+ RetrieveContextInput,
22
+ RetrieveContextOutput,
23
+ )
24
+ from src.core.logger import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
@dataclass
class Tool:
    """A single orchestrator-callable tool.

    Contract: `execute(input_model_instance) -> output_model_instance`.
    `invoke(args_dict)` validates the raw dict against `input_model`, runs
    `execute`, and returns the output serialized to a plain dict.
    """
    name: str
    description: str
    input_model: type[BaseModel]
    output_model: type[BaseModel]
    execute: Callable[[Any], BaseModel]

    def openai_schema(self) -> dict[str, Any]:
        """Render this tool as an OpenAI/OpenRouter function-calling schema."""
        raw_schema = self.input_model.model_json_schema()
        # Some clients reject top-level $defs / title, so only the essential
        # keys are forwarded: type, properties, required.
        parameters: dict[str, Any] = {
            "type": "object",
            "properties": raw_schema.get("properties", {}),
            "required": raw_schema.get("required", []),
        }
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": parameters,
        }
        return {"type": "function", "function": function_spec}

    def invoke(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate `args`, run the tool, and dump its result to a dict.

        Raises:
            ValueError: when `args` fails input-model validation.
        """
        try:
            validated = self.input_model.model_validate(args)
        except ValidationError as err:
            raise ValueError(f"invalid input for {self.name}: {err}") from err
        return self.execute(validated).model_dump()
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Tool implementations — thin wrappers around existing pipelines + RAG.
72
+ # Heavy work stays in the underlying modules; these only adapt I/O.
73
+ # ---------------------------------------------------------------------------
74
+
75
+
76
def _execute_bbb(inp: BBBPipelineInput) -> BBBPipelineOutput:
    """Predict + SHAP for a single SMILES, reusing the existing model surface."""
    from src.api import routes as api_routes
    from src.api.schemas import BBBPredictRequest

    # Delegate the heavy lifting to the API route; this wrapper only adapts I/O.
    request = BBBPredictRequest(smiles=inp.smiles, top_k=inp.top_k)
    resp = api_routes.predict_bbb(request)
    top_features = [feature.model_dump() for feature in resp.top_features]
    return BBBPipelineOutput(
        smiles=inp.smiles,
        label=resp.label,
        label_text=resp.label_text,
        confidence=resp.confidence,
        top_features=top_features,
        drift_z=resp.drift_z,
    )
92
+
93
+
94
def _execute_eeg(inp: EEGPipelineInput) -> EEGPipelineOutput:
    """Run the EEG pipeline via the existing route function (run_eeg).

    Args:
        inp: Validated input — EEG recording path + epoch window length.

    Returns:
        EEGPipelineOutput with the written parquet location and run stats.
    """
    # Import order made consistent with the other _execute_* wrappers
    # (routes module first, then request schema).
    from src.api import routes as api_routes
    from src.api.schemas import EEGRequest

    # NOTE(review): output path is hard-coded, so concurrent invocations
    # would overwrite each other's features — consider parameterizing.
    out_path = Path("data/processed/eeg_features.parquet")
    response = api_routes.run_eeg(
        EEGRequest(
            input_path=inp.input_path,
            output_path=str(out_path),
            epoch_duration_s=inp.epoch_duration_s,
        )
    )
    return EEGPipelineOutput(
        input_path=inp.input_path,
        output_path=response.output_path,
        rows=response.rows,
        columns=response.columns,
        duration_sec=response.duration_sec,
    )
114
+
115
+
116
def _execute_mri(inp: MRIPipelineInput) -> MRIPipelineOutput:
    """Run the MRI pipeline via the existing route function (run_mri).

    Args:
        inp: Validated input — NIfTI directory + site-mapping CSV.

    Returns:
        MRIPipelineOutput with the written parquet location and run stats.
    """
    # Import order made consistent with the other _execute_* wrappers
    # (routes module first, then request schema).
    from src.api import routes as api_routes
    from src.api.schemas import MRIRequest

    # NOTE(review): output path is hard-coded, so concurrent invocations
    # would overwrite each other's features — consider parameterizing.
    out_path = Path("data/processed/mri_features.parquet")
    response = api_routes.run_mri(
        MRIRequest(
            input_dir=inp.input_dir,
            sites_csv=inp.sites_csv,
            output_path=str(out_path),
        )
    )
    return MRIPipelineOutput(
        input_dir=inp.input_dir,
        output_path=response.output_path,
        rows=response.rows,
        columns=response.columns,
        duration_sec=response.duration_sec,
    )
136
+
137
+
138
def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
    """Build a retrieval executor; the retriever itself is loaded lazily."""
    retriever: Any = None  # populated on first successful call

    def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
        nonlocal retriever
        # No configured index (or index file missing) → degrade gracefully
        # to an empty result set instead of raising.
        if rag_index_dir is None or not (rag_index_dir / "index.bin").exists():
            return RetrieveContextOutput(query=inp.query, chunks=[])
        if retriever is None:
            from src.rag.retrieve import RAGRetriever
            retriever = RAGRetriever.load(rag_index_dir)
        hits = retriever.search(inp.query, k=inp.k)
        return RetrieveContextOutput(query=inp.query, chunks=hits)

    return execute
152
+
153
+
154
def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
    """Assemble the orchestrator's default toolbox: 3 pipelines + RAG."""
    bbb_tool = Tool(
        name="run_bbb_pipeline",
        description=(
            "Predict blood-brain-barrier permeability for a SINGLE SMILES "
            "string. Use this when the user input looks like a molecule "
            "(short alphanumeric string with no file extension, e.g. 'CCO', "
            "'c1ccccc1'). Returns label, confidence, top SHAP features, drift."
        ),
        input_model=BBBPipelineInput,
        output_model=BBBPipelineOutput,
        execute=_execute_bbb,
    )
    eeg_tool = Tool(
        name="run_eeg_pipeline",
        description=(
            "Run the EEG signal-processing pipeline (bandpass + ICA + "
            "epoching + feature extraction) on an EEG recording file. Use "
            "when input_path ends in .fif or .edf. Returns row/column "
            "counts + duration."
        ),
        input_model=EEGPipelineInput,
        output_model=EEGPipelineOutput,
        execute=_execute_eeg,
    )
    mri_tool = Tool(
        name="run_mri_pipeline",
        description=(
            "Run the multi-site MRI ComBat-harmonization pipeline. Use "
            "when input is a directory containing .nii.gz volumes paired "
            "with a sites.csv. Returns row/column counts + duration."
        ),
        input_model=MRIPipelineInput,
        output_model=MRIPipelineOutput,
        execute=_execute_mri,
    )
    rag_tool = Tool(
        name="retrieve_context",
        description=(
            "Retrieve up to k passages from the curated reference knowledge "
            "base. Use AFTER a pipeline tool returns, to ground your final "
            "synthesis in cited literature. Formulate a focused query "
            "based on the pipeline output (e.g., 'BBB permeability of "
            "small lipophilic molecules' or 'ComBat site harmonization')."
        ),
        input_model=RetrieveContextInput,
        output_model=RetrieveContextOutput,
        # The RAG executor is a closure so the index loads lazily on first use.
        execute=_make_retrieve_executor(rag_index_dir),
    )
    return [bbb_tool, eeg_tool, mri_tool, rag_tool]
tests/agents/__init__.py ADDED
File without changes
tests/agents/test_tools.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.agents.tools — Tool dataclass + registry + 4 tool wrappers."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+ from pydantic import BaseModel
8
+
9
+ from src.agents.tools import (
10
+ Tool,
11
+ build_default_tools,
12
+ BBBPipelineInput,
13
+ EEGPipelineInput,
14
+ MRIPipelineInput,
15
+ RetrieveContextInput,
16
+ )
17
+
18
+
19
class _DummyInput(BaseModel):
    """Minimal input schema for Tool tests: one required field, one defaulted."""
    # Required — must appear in the generated schema's `required` list.
    x: int
    # Defaulted — must NOT appear in `required`.
    y: str = "default"
22
+
23
+
24
class _DummyOutput(BaseModel):
    """Minimal output schema for Tool tests — a single integer result."""
    result: int
26
+
27
+
28
class TestTool:
    """Unit tests for the Tool dataclass: schema export + invoke contract."""

    @staticmethod
    def _make_tool(description: str = "A dummy tool") -> Tool:
        """Build the doubling dummy tool shared by all tests (was triplicated)."""
        return Tool(
            name="dummy",
            description=description,
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )

    def test_openai_schema_shape(self) -> None:
        schema = self._make_tool().openai_schema()
        assert schema["type"] == "function"
        assert schema["function"]["name"] == "dummy"
        assert schema["function"]["description"] == "A dummy tool"
        params = schema["function"]["parameters"]
        assert params["type"] == "object"
        assert "x" in params["properties"]
        assert "x" in params["required"]
        assert "y" not in params["required"]  # has default

    def test_invoke_validates_and_returns_dict(self) -> None:
        assert self._make_tool().invoke({"x": 5}) == {"result": 10}

    def test_invoke_invalid_input_raises(self) -> None:
        # Missing the required `x` field → ValidationError surfaced as ValueError.
        with pytest.raises(ValueError, match="invalid input"):
            self._make_tool().invoke({"y": "missing-x"})
68
+
69
+
70
class TestBuildDefaultTools:
    """Tests for the default tool registry returned by build_default_tools."""

    def test_default_set_has_four_tools(self) -> None:
        # fix: dropped the `tmp_path` fixture — it was requested but never used.
        tools = build_default_tools(rag_index_dir=None)
        names = {t.name for t in tools}
        assert names == {
            "run_bbb_pipeline",
            "run_eeg_pipeline",
            "run_mri_pipeline",
            "retrieve_context",
        }

    def test_each_tool_has_pydantic_input_model(self) -> None:
        for t in build_default_tools(rag_index_dir=None):
            assert issubclass(t.input_model, BaseModel)
            assert issubclass(t.output_model, BaseModel)

    def test_input_models_have_smiles_paths(self) -> None:
        # Verify the field names the downstream system prompt depends on.
        assert "smiles" in BBBPipelineInput.model_fields
        assert "input_path" in EEGPipelineInput.model_fields
        assert "input_dir" in MRIPipelineInput.model_fields
        assert "sites_csv" in MRIPipelineInput.model_fields
        assert "query" in RetrieveContextInput.model_fields
        assert "k" in RetrieveContextInput.model_fields