mekosotto Claude Sonnet 4.6 commited on
Commit
4fff9d2
·
1 Parent(s): 5d4dc71

feat(agents): register run_fusion tool for multi-modal disease confidence

Browse files
src/agents/prompts.py CHANGED
@@ -7,12 +7,13 @@ from __future__ import annotations
7
 
8
 
9
  ORCHESTRATOR_SYSTEM_PROMPT = """\
10
- You are the NeuroBridge clinical-ML orchestrator. You have four tools:
11
 
12
  - run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
13
  - run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
14
  - run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
15
  - retrieve_context(query, k=4) → for grounding chunks from the knowledge base
 
16
 
17
  Workflow — follow exactly:
18
 
 
7
 
8
 
9
  ORCHESTRATOR_SYSTEM_PROMPT = """\
10
+ You are the NeuroBridge clinical-ML orchestrator. You have five tools:
11
 
12
  - run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
13
  - run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
14
  - run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
15
  - retrieve_context(query, k=4) → for grounding chunks from the knowledge base
16
+ - run_fusion: combine MRI/EEG/clinical-test scores into a per-disease confidence with attribution. Use when the doctor has more than one modality available.
17
 
18
  Workflow — follow exactly:
19
 
src/agents/tools.py CHANGED
@@ -1,7 +1,7 @@
1
  """Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
2
  function-callable tool the orchestrator can invoke.
3
 
4
- Public entry: `build_default_tools(rag_index_dir)` returns the 4 tools.
5
  """
6
  from __future__ import annotations
7
 
@@ -21,6 +21,8 @@ from src.agents.schemas import (
21
  RetrieveContextInput,
22
  RetrieveContextOutput,
23
  )
 
 
24
  from src.core.logger import get_logger
25
 
26
  logger = get_logger(__name__)
@@ -175,7 +177,7 @@ def build_default_tools(
175
  rag_index_dir: Path | None,
176
  processed_dir: Path = Path("data/processed"),
177
  ) -> list[Tool]:
178
- """Return the 4 tools the orchestrator gets by default."""
179
  return [
180
  Tool(
181
  name="run_bbb_pipeline",
@@ -225,4 +227,17 @@ def build_default_tools(
225
  output_model=RetrieveContextOutput,
226
  execute=_make_retrieve_executor(rag_index_dir),
227
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  ]
 
1
  """Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
2
  function-callable tool the orchestrator can invoke.
3
 
4
+ Public entry: `build_default_tools(rag_index_dir)` returns the 5 tools.
5
  """
6
  from __future__ import annotations
7
 
 
21
  RetrieveContextInput,
22
  RetrieveContextOutput,
23
  )
24
+ from src.fusion.engine import fuse as fuse_engine
25
+ from src.fusion.types import FusionInput, FusionOutput
26
  from src.core.logger import get_logger
27
 
28
  logger = get_logger(__name__)
 
177
  rag_index_dir: Path | None,
178
  processed_dir: Path = Path("data/processed"),
179
  ) -> list[Tool]:
180
+ """Return the 5 tools the orchestrator gets by default."""
181
  return [
182
  Tool(
183
  name="run_bbb_pipeline",
 
227
  output_model=RetrieveContextOutput,
228
  execute=_make_retrieve_executor(rag_index_dir),
229
  ),
230
+ Tool(
231
+ name="run_fusion",
232
+ description=(
233
+ "Combine MRI prediction, EEG prediction, and clinical-test "
234
+ "scores (MMSE, MoCA, UPDRS, gait, age) into per-disease "
235
+ "(Alzheimer's, Parkinson's, other) confidence with full "
236
+ "attribution. Pass whichever modalities are available; "
237
+ "missing ones are skipped, not imputed. Does NOT use BBB."
238
+ ),
239
+ input_model=FusionInput,
240
+ output_model=FusionOutput,
241
+ execute=lambda inp: fuse_engine(inp),
242
+ ),
243
  ]
tests/agents/test_tools.py CHANGED
@@ -79,6 +79,7 @@ class TestBuildDefaultTools:
79
  "run_eeg_pipeline",
80
  "run_mri_pipeline",
81
  "retrieve_context",
 
82
  }
83
 
84
  def test_each_tool_has_pydantic_input_model(self) -> None:
@@ -115,8 +116,8 @@ class TestBuildDefaultTools:
115
  def test_default_processed_dir_when_omitted(self) -> None:
116
  # backwards-compat: omitting processed_dir keeps existing behavior
117
  tools = build_default_tools(rag_index_dir=None)
118
- # just ensure no exception and 4 tools returned
119
- assert len(tools) == 4
120
 
121
  def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
122
  from fastapi import HTTPException
 
79
  "run_eeg_pipeline",
80
  "run_mri_pipeline",
81
  "retrieve_context",
82
+ "run_fusion",
83
  }
84
 
85
  def test_each_tool_has_pydantic_input_model(self) -> None:
 
116
  def test_default_processed_dir_when_omitted(self) -> None:
117
  # backwards-compat: omitting processed_dir keeps existing behavior
118
  tools = build_default_tools(rag_index_dir=None)
119
+ # just ensure no exception and 5 tools returned
120
+ assert len(tools) == 5
121
 
122
  def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
123
  from fastapi import HTTPException
tests/agents/test_tools_fusion.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the run_fusion agent tool."""
2
+ from __future__ import annotations
3
+
4
+ from src.agents.tools import build_default_tools
5
+
6
+
7
+ class TestRunFusionTool:
8
+ def test_fusion_tool_is_registered(self) -> None:
9
+ tools = build_default_tools(rag_index_dir=None)
10
+ names = [t.name for t in tools]
11
+ assert "run_fusion" in names
12
+
13
+ def test_fusion_tool_executes_with_only_clinical(self) -> None:
14
+ tools = {t.name: t for t in build_default_tools(rag_index_dir=None)}
15
+ tool = tools["run_fusion"]
16
+ out = tool.execute(tool.input_model.model_validate({
17
+ "clinical": {"mmse": 12.0, "age_years": 78.0},
18
+ }))
19
+ assert out.top_disease in {"alzheimers", "parkinsons", "other"}
20
+ assert any(d.disease == "alzheimers" for d in out.diseases)