feat(agents): register run_fusion tool for multi-modal disease confidence
Browse files- src/agents/prompts.py +2 -1
- src/agents/tools.py +17 -2
- tests/agents/test_tools.py +3 -2
- tests/agents/test_tools_fusion.py +20 -0
src/agents/prompts.py
CHANGED
|
@@ -7,12 +7,13 @@ from __future__ import annotations
|
|
| 7 |
|
| 8 |
|
| 9 |
ORCHESTRATOR_SYSTEM_PROMPT = """\
|
| 10 |
-
You are the NeuroBridge clinical-ML orchestrator. You have
|
| 11 |
|
| 12 |
- run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
|
| 13 |
- run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
|
| 14 |
- run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
|
| 15 |
- retrieve_context(query, k=4) → for grounding chunks from the knowledge base
|
|
|
|
| 16 |
|
| 17 |
Workflow — follow exactly:
|
| 18 |
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
ORCHESTRATOR_SYSTEM_PROMPT = """\
|
| 10 |
+
You are the NeuroBridge clinical-ML orchestrator. You have five tools:
|
| 11 |
|
| 12 |
- run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
|
| 13 |
- run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
|
| 14 |
- run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
|
| 15 |
- retrieve_context(query, k=4) → for grounding chunks from the knowledge base
|
| 16 |
+
- run_fusion: combine MRI/EEG/clinical-test scores into a per-disease confidence with attribution. Use when the doctor has more than one modality available.
|
| 17 |
|
| 18 |
Workflow — follow exactly:
|
| 19 |
|
src/agents/tools.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
|
| 2 |
function-callable tool the orchestrator can invoke.
|
| 3 |
|
| 4 |
-
Public entry: `build_default_tools(rag_index_dir)` returns the
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
@@ -21,6 +21,8 @@ from src.agents.schemas import (
|
|
| 21 |
RetrieveContextInput,
|
| 22 |
RetrieveContextOutput,
|
| 23 |
)
|
|
|
|
|
|
|
| 24 |
from src.core.logger import get_logger
|
| 25 |
|
| 26 |
logger = get_logger(__name__)
|
|
@@ -175,7 +177,7 @@ def build_default_tools(
|
|
| 175 |
rag_index_dir: Path | None,
|
| 176 |
processed_dir: Path = Path("data/processed"),
|
| 177 |
) -> list[Tool]:
|
| 178 |
-
"""Return the
|
| 179 |
return [
|
| 180 |
Tool(
|
| 181 |
name="run_bbb_pipeline",
|
|
@@ -225,4 +227,17 @@ def build_default_tools(
|
|
| 225 |
output_model=RetrieveContextOutput,
|
| 226 |
execute=_make_retrieve_executor(rag_index_dir),
|
| 227 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
]
|
|
|
|
| 1 |
"""Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
|
| 2 |
function-callable tool the orchestrator can invoke.
|
| 3 |
|
| 4 |
+
Public entry: `build_default_tools(rag_index_dir)` returns the 5 tools.
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
|
|
| 21 |
RetrieveContextInput,
|
| 22 |
RetrieveContextOutput,
|
| 23 |
)
|
| 24 |
+
from src.fusion.engine import fuse as fuse_engine
|
| 25 |
+
from src.fusion.types import FusionInput, FusionOutput
|
| 26 |
from src.core.logger import get_logger
|
| 27 |
|
| 28 |
logger = get_logger(__name__)
|
|
|
|
| 177 |
rag_index_dir: Path | None,
|
| 178 |
processed_dir: Path = Path("data/processed"),
|
| 179 |
) -> list[Tool]:
|
| 180 |
+
"""Return the 5 tools the orchestrator gets by default."""
|
| 181 |
return [
|
| 182 |
Tool(
|
| 183 |
name="run_bbb_pipeline",
|
|
|
|
| 227 |
output_model=RetrieveContextOutput,
|
| 228 |
execute=_make_retrieve_executor(rag_index_dir),
|
| 229 |
),
|
| 230 |
+
Tool(
|
| 231 |
+
name="run_fusion",
|
| 232 |
+
description=(
|
| 233 |
+
"Combine MRI prediction, EEG prediction, and clinical-test "
|
| 234 |
+
"scores (MMSE, MoCA, UPDRS, gait, age) into per-disease "
|
| 235 |
+
"(Alzheimer's, Parkinson's, other) confidence with full "
|
| 236 |
+
"attribution. Pass whichever modalities are available; "
|
| 237 |
+
"missing ones are skipped, not imputed. Does NOT use BBB."
|
| 238 |
+
),
|
| 239 |
+
input_model=FusionInput,
|
| 240 |
+
output_model=FusionOutput,
|
| 241 |
+
execute=lambda inp: fuse_engine(inp),
|
| 242 |
+
),
|
| 243 |
]
|
tests/agents/test_tools.py
CHANGED
|
@@ -79,6 +79,7 @@ class TestBuildDefaultTools:
|
|
| 79 |
"run_eeg_pipeline",
|
| 80 |
"run_mri_pipeline",
|
| 81 |
"retrieve_context",
|
|
|
|
| 82 |
}
|
| 83 |
|
| 84 |
def test_each_tool_has_pydantic_input_model(self) -> None:
|
|
@@ -115,8 +116,8 @@ class TestBuildDefaultTools:
|
|
| 115 |
def test_default_processed_dir_when_omitted(self) -> None:
|
| 116 |
# backwards-compat: omitting processed_dir keeps existing behavior
|
| 117 |
tools = build_default_tools(rag_index_dir=None)
|
| 118 |
-
# just ensure no exception and
|
| 119 |
-
assert len(tools) ==
|
| 120 |
|
| 121 |
def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
|
| 122 |
from fastapi import HTTPException
|
|
|
|
| 79 |
"run_eeg_pipeline",
|
| 80 |
"run_mri_pipeline",
|
| 81 |
"retrieve_context",
|
| 82 |
+
"run_fusion",
|
| 83 |
}
|
| 84 |
|
| 85 |
def test_each_tool_has_pydantic_input_model(self) -> None:
|
|
|
|
| 116 |
def test_default_processed_dir_when_omitted(self) -> None:
|
| 117 |
# backwards-compat: omitting processed_dir keeps existing behavior
|
| 118 |
tools = build_default_tools(rag_index_dir=None)
|
| 119 |
+
# just ensure no exception and 5 tools returned
|
| 120 |
+
assert len(tools) == 5
|
| 121 |
|
| 122 |
def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
|
| 123 |
from fastapi import HTTPException
|
tests/agents/test_tools_fusion.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the run_fusion agent tool."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from src.agents.tools import build_default_tools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestRunFusionTool:
|
| 8 |
+
def test_fusion_tool_is_registered(self) -> None:
|
| 9 |
+
tools = build_default_tools(rag_index_dir=None)
|
| 10 |
+
names = [t.name for t in tools]
|
| 11 |
+
assert "run_fusion" in names
|
| 12 |
+
|
| 13 |
+
def test_fusion_tool_executes_with_only_clinical(self) -> None:
|
| 14 |
+
tools = {t.name: t for t in build_default_tools(rag_index_dir=None)}
|
| 15 |
+
tool = tools["run_fusion"]
|
| 16 |
+
out = tool.execute(tool.input_model.model_validate({
|
| 17 |
+
"clinical": {"mmse": 12.0, "age_years": 78.0},
|
| 18 |
+
}))
|
| 19 |
+
assert out.top_disease in {"alzheimers", "parkinsons", "other"}
|
| 20 |
+
assert any(d.disease == "alzheimers" for d in out.diseases)
|