mekosotto Claude Sonnet 4.6 commited on
Commit
6d2aa47
·
1 Parent(s): 2091a1b

fix(agents/tools): parameterize processed_dir + translate HTTPException → ValueError

Browse files

Replace top-level _execute_bbb/eeg/mri functions with closure factories
(_make_bbb/eeg/mri_executor) that (a) accept processed_dir so output paths
are no longer hard-coded relative strings, and (b) catch HTTPException from
route calls and re-raise as ValueError with a clean message.

Add build_default_tools(processed_dir=Path("data/processed")) keyword arg
for backwards-compatible parameterization.

Add 4 new tests: processed_dir threading, retrieve short-circuit with no
index, backwards-compat default, and HTTPException → ValueError translation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. src/agents/tools.py +78 -60
  2. tests/agents/test_tools.py +33 -0
src/agents/tools.py CHANGED
@@ -73,66 +73,81 @@ class Tool:
73
  # ---------------------------------------------------------------------------
74
 
75
 
76
- def _execute_bbb(inp: BBBPipelineInput) -> BBBPipelineOutput:
77
- """Predict + SHAP for a single SMILES, reusing the existing model surface."""
78
- from src.api import routes as api_routes
79
- from src.api.schemas import BBBPredictRequest
80
-
81
- response = api_routes.predict_bbb(
82
- BBBPredictRequest(smiles=inp.smiles, top_k=inp.top_k)
83
- )
84
- return BBBPipelineOutput(
85
- smiles=inp.smiles,
86
- label=response.label,
87
- label_text=response.label_text,
88
- confidence=response.confidence,
89
- top_features=[f.model_dump() for f in response.top_features],
90
- drift_z=response.drift_z,
91
- )
92
-
93
-
94
- def _execute_eeg(inp: EEGPipelineInput) -> EEGPipelineOutput:
95
- """Run the EEG pipeline via the existing route function (run_eeg)."""
96
- from src.api.schemas import EEGRequest
97
- from src.api import routes as api_routes
98
-
99
- out_path = Path("data/processed/eeg_features.parquet")
100
- response = api_routes.run_eeg(
101
- EEGRequest(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  input_path=inp.input_path,
103
- output_path=str(out_path),
104
- epoch_duration_s=inp.epoch_duration_s,
 
 
105
  )
106
- )
107
- return EEGPipelineOutput(
108
- input_path=inp.input_path,
109
- output_path=response.output_path,
110
- rows=response.rows,
111
- columns=response.columns,
112
- duration_sec=response.duration_sec,
113
- )
114
-
115
-
116
- def _execute_mri(inp: MRIPipelineInput) -> MRIPipelineOutput:
117
- """Run the MRI pipeline via the existing route function (run_mri)."""
118
- from src.api.schemas import MRIRequest
119
- from src.api import routes as api_routes
120
-
121
- out_path = Path("data/processed/mri_features.parquet")
122
- response = api_routes.run_mri(
123
- MRIRequest(
 
 
 
124
  input_dir=inp.input_dir,
125
- sites_csv=inp.sites_csv,
126
- output_path=str(out_path),
 
 
127
  )
128
- )
129
- return MRIPipelineOutput(
130
- input_dir=inp.input_dir,
131
- output_path=response.output_path,
132
- rows=response.rows,
133
- columns=response.columns,
134
- duration_sec=response.duration_sec,
135
- )
136
 
137
 
138
  def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
@@ -151,7 +166,10 @@ def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveCon
151
  return execute
152
 
153
 
154
- def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
 
 
 
155
  """Return the 4 tools the orchestrator gets by default."""
156
  return [
157
  Tool(
@@ -164,7 +182,7 @@ def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
164
  ),
165
  input_model=BBBPipelineInput,
166
  output_model=BBBPipelineOutput,
167
- execute=_execute_bbb,
168
  ),
169
  Tool(
170
  name="run_eeg_pipeline",
@@ -176,7 +194,7 @@ def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
176
  ),
177
  input_model=EEGPipelineInput,
178
  output_model=EEGPipelineOutput,
179
- execute=_execute_eeg,
180
  ),
181
  Tool(
182
  name="run_mri_pipeline",
@@ -187,7 +205,7 @@ def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
187
  ),
188
  input_model=MRIPipelineInput,
189
  output_model=MRIPipelineOutput,
190
- execute=_execute_mri,
191
  ),
192
  Tool(
193
  name="retrieve_context",
 
73
  # ---------------------------------------------------------------------------
74
 
75
 
76
+ def _make_bbb_executor() -> Callable[[BBBPipelineInput], BBBPipelineOutput]:
77
+ """Closure factory: BBB permeability prediction + SHAP, translates HTTPException."""
78
+ def execute(inp: BBBPipelineInput) -> BBBPipelineOutput:
79
+ from src.api import routes as api_routes
80
+ from src.api.schemas import BBBPredictRequest
81
+ from fastapi import HTTPException
82
+ try:
83
+ response = api_routes.predict_bbb(
84
+ BBBPredictRequest(smiles=inp.smiles, top_k=inp.top_k)
85
+ )
86
+ except HTTPException as e:
87
+ raise ValueError(f"bbb tool failed: {e.detail}") from e
88
+ return BBBPipelineOutput(
89
+ smiles=inp.smiles,
90
+ label=response.label,
91
+ label_text=response.label_text,
92
+ confidence=response.confidence,
93
+ top_features=[f.model_dump() for f in response.top_features],
94
+ drift_z=response.drift_z,
95
+ )
96
+ return execute
97
+
98
+
99
+ def _make_eeg_executor(processed_dir: Path) -> Callable[[EEGPipelineInput], EEGPipelineOutput]:
100
+ """Closure factory: EEG pipeline, writes output under processed_dir."""
101
+ def execute(inp: EEGPipelineInput) -> EEGPipelineOutput:
102
+ from src.api.schemas import EEGRequest
103
+ from src.api import routes as api_routes
104
+ from fastapi import HTTPException
105
+ out_path = processed_dir / "eeg_features.parquet"
106
+ try:
107
+ response = api_routes.run_eeg(
108
+ EEGRequest(
109
+ input_path=inp.input_path,
110
+ output_path=str(out_path),
111
+ epoch_duration_s=inp.epoch_duration_s,
112
+ )
113
+ )
114
+ except HTTPException as e:
115
+ raise ValueError(f"eeg tool failed: {e.detail}") from e
116
+ return EEGPipelineOutput(
117
  input_path=inp.input_path,
118
+ output_path=response.output_path,
119
+ rows=response.rows,
120
+ columns=response.columns,
121
+ duration_sec=response.duration_sec,
122
  )
123
+ return execute
124
+
125
+
126
+ def _make_mri_executor(processed_dir: Path) -> Callable[[MRIPipelineInput], MRIPipelineOutput]:
127
+ """Closure factory: MRI pipeline, writes output under processed_dir."""
128
+ def execute(inp: MRIPipelineInput) -> MRIPipelineOutput:
129
+ from src.api.schemas import MRIRequest
130
+ from src.api import routes as api_routes
131
+ from fastapi import HTTPException
132
+ out_path = processed_dir / "mri_features.parquet"
133
+ try:
134
+ response = api_routes.run_mri(
135
+ MRIRequest(
136
+ input_dir=inp.input_dir,
137
+ sites_csv=inp.sites_csv,
138
+ output_path=str(out_path),
139
+ )
140
+ )
141
+ except HTTPException as e:
142
+ raise ValueError(f"mri tool failed: {e.detail}") from e
143
+ return MRIPipelineOutput(
144
  input_dir=inp.input_dir,
145
+ output_path=response.output_path,
146
+ rows=response.rows,
147
+ columns=response.columns,
148
+ duration_sec=response.duration_sec,
149
  )
150
+ return execute
 
 
 
 
 
 
 
151
 
152
 
153
  def _make_retrieve_executor(rag_index_dir: Path | None) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
 
166
  return execute
167
 
168
 
169
+ def build_default_tools(
170
+ rag_index_dir: Path | None,
171
+ processed_dir: Path = Path("data/processed"),
172
+ ) -> list[Tool]:
173
  """Return the 4 tools the orchestrator gets by default."""
174
  return [
175
  Tool(
 
182
  ),
183
  input_model=BBBPipelineInput,
184
  output_model=BBBPipelineOutput,
185
+ execute=_make_bbb_executor(),
186
  ),
187
  Tool(
188
  name="run_eeg_pipeline",
 
194
  ),
195
  input_model=EEGPipelineInput,
196
  output_model=EEGPipelineOutput,
197
+ execute=_make_eeg_executor(processed_dir),
198
  ),
199
  Tool(
200
  name="run_mri_pipeline",
 
205
  ),
206
  input_model=MRIPipelineInput,
207
  output_model=MRIPipelineOutput,
208
+ execute=_make_mri_executor(processed_dir),
209
  ),
210
  Tool(
211
  name="retrieve_context",
tests/agents/test_tools.py CHANGED
@@ -93,3 +93,36 @@ class TestBuildDefaultTools:
93
  assert "sites_csv" in MRIPipelineInput.model_fields
94
  assert "query" in RetrieveContextInput.model_fields
95
  assert "k" in RetrieveContextInput.model_fields
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  assert "sites_csv" in MRIPipelineInput.model_fields
94
  assert "query" in RetrieveContextInput.model_fields
95
  assert "k" in RetrieveContextInput.model_fields
96
+
97
+ def test_retrieve_context_short_circuits_when_no_index(self) -> None:
98
+ tools = build_default_tools(rag_index_dir=None)
99
+ retrieve = next(t for t in tools if t.name == "retrieve_context")
100
+ out = retrieve.invoke({"query": "anything", "k": 3})
101
+ assert out == {"query": "anything", "chunks": []}
102
+
103
+ def test_processed_dir_parameter_threads_to_executors(self, tmp_path: Path) -> None:
104
+ # build_default_tools should accept processed_dir; executors should
105
+ # eventually write under it (we don't invoke the pipelines here, just
106
+ # verify the parameter is accepted and tools are built).
107
+ tools = build_default_tools(rag_index_dir=None, processed_dir=tmp_path)
108
+ names = {t.name for t in tools}
109
+ assert "run_eeg_pipeline" in names
110
+ assert "run_mri_pipeline" in names
111
+
112
+ def test_default_processed_dir_when_omitted(self) -> None:
113
+ # backwards-compat: omitting processed_dir keeps existing behavior
114
+ tools = build_default_tools(rag_index_dir=None)
115
+ # just ensure no exception and 4 tools returned
116
+ assert len(tools) == 4
117
+
118
+ def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
119
+ from unittest.mock import patch
120
+ from fastapi import HTTPException
121
+
122
+ tools = build_default_tools(rag_index_dir=None)
123
+ bbb = next(t for t in tools if t.name == "run_bbb_pipeline")
124
+
125
+ with patch("src.api.routes.predict_bbb",
126
+ side_effect=HTTPException(status_code=503, detail="model missing")):
127
+ with pytest.raises(ValueError, match="bbb tool failed"):
128
+ bbb.invoke({"smiles": "CCO"})