feat(agents): Tool dataclass + registry + 4 tool wrappers (3 pipelines + RAG)
Browse files- src/agents/__init__.py +0 -0
- src/agents/schemas.py +87 -0
- src/agents/tools.py +205 -0
- tests/agents/__init__.py +0 -0
- tests/agents/test_tools.py +95 -0
src/agents/__init__.py
ADDED
|
File without changes
|
src/agents/schemas.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic input/output schemas for orchestrator tools and the agent result.
|
| 2 |
+
|
| 3 |
+
These schemas double as OpenAI function-calling parameter definitions
|
| 4 |
+
(via `model_json_schema()`) and as runtime validation gates. Keep field
|
| 5 |
+
names lowercase + snake_case so prompts and JSON outputs align.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# --- Pipeline tool inputs ---------------------------------------------------
|
| 15 |
+
|
| 16 |
+
class BBBPipelineInput(BaseModel):
    """Input for `run_bbb_pipeline` — a single SMILES string.

    Field names are surfaced verbatim in the function-calling schema, so
    they must stay lowercase snake_case (see module docstring).
    """

    # Required: the molecule to score.
    smiles: str = Field(..., description="A single molecular SMILES string, e.g. 'CCO'")
    # Bounded so a model cannot request an unbounded SHAP explanation.
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP attributions to return")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class EEGPipelineInput(BaseModel):
    """Input for `run_eeg_pipeline` — path to an EEG file (.fif or .edf)."""

    input_path: str = Field(..., description="Path to EEG recording file (.fif or .edf)")
    # Description added for consistency with the other tool-input fields:
    # it is forwarded into the function-calling schema the LLM reads.
    epoch_duration_s: float = Field(
        2.0,
        gt=0.1,
        le=60.0,
        description="Epoch window length in seconds (default 2.0)",
    )
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MRIPipelineInput(BaseModel):
    """Input for `run_mri_pipeline` — directory of NIfTI files + sites CSV."""

    input_dir: str = Field(..., description="Directory containing .nii.gz volumes")
    # CSV that pairs each subject with its acquisition site (used for harmonization).
    sites_csv: str = Field(..., description="CSV mapping subject_id → site")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class RetrieveContextInput(BaseModel):
    """Input for `retrieve_context` — natural-language query into the KB."""

    # min_length=2 rejects degenerate one-character queries.
    query: str = Field(..., min_length=2, description="Search query for the knowledge base")
    k: int = Field(4, ge=1, le=10, description="Number of chunks to return")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# --- Pipeline tool outputs --------------------------------------------------
|
| 41 |
+
|
| 42 |
+
class BBBPipelineOutput(BaseModel):
    """Structured result of `run_bbb_pipeline`."""

    # Echo of the input molecule.
    smiles: str
    label: int
    label_text: str
    confidence: float
    # Serialized SHAP attributions (one dict per feature, via model_dump()).
    top_features: list[dict[str, Any]]
    # Drift z-score from the prediction response; None when not reported.
    drift_z: float | None = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class EEGPipelineOutput(BaseModel):
    """Structured result of `run_eeg_pipeline`: where the features landed + size."""

    input_path: str
    # Parquet file the pipeline wrote its feature table to.
    output_path: str
    rows: int
    columns: int
    # Wall-clock pipeline runtime in seconds.
    duration_sec: float
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MRIPipelineOutput(BaseModel):
    """Structured result of `run_mri_pipeline`: where the features landed + size."""

    input_dir: str
    # Parquet file the pipeline wrote its feature table to.
    output_path: str
    rows: int
    columns: int
    # Wall-clock pipeline runtime in seconds.
    duration_sec: float
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RetrieveContextOutput(BaseModel):
    """Result of `retrieve_context`: the query plus the retrieved KB chunks."""

    query: str
    # Raw hit dicts as returned by the retriever's search(); empty when
    # no index is available.
    chunks: list[dict[str, Any]]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# --- Agent result -----------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
class ToolTraceItem(BaseModel):
    """One step in the orchestrator's tool-call trace."""

    # Tool name as exposed in the function-calling schema.
    name: str
    # Arguments the model supplied for the call.
    args: dict[str, Any]
    # Tool output dict; None presumably when the call errored — confirm in orchestrator.
    result: dict[str, Any] | None = None
    # Error text; None on success.
    error: str | None = None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class AgentResult(BaseModel):
    """Final orchestrator response: synthesized text + full tool-call trace.

    `finish_reason` records how the run ended; see inline comment for the
    expected values.
    """

    text: str
    trace: list[ToolTraceItem] = Field(default_factory=list)
    # Model identifier used for the run, when known.
    model: str | None = None
    finish_reason: str = "complete"  # complete | max_steps | error
|
src/agents/tools.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool dataclass + registry. Wraps each pipeline + the RAG retriever as a
|
| 2 |
+
function-callable tool the orchestrator can invoke.
|
| 3 |
+
|
| 4 |
+
Public entry: `build_default_tools(rag_index_dir)` returns the 4 tools.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Callable
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, ValidationError
|
| 13 |
+
|
| 14 |
+
from src.agents.schemas import (
|
| 15 |
+
BBBPipelineInput,
|
| 16 |
+
BBBPipelineOutput,
|
| 17 |
+
EEGPipelineInput,
|
| 18 |
+
EEGPipelineOutput,
|
| 19 |
+
MRIPipelineInput,
|
| 20 |
+
MRIPipelineOutput,
|
| 21 |
+
RetrieveContextInput,
|
| 22 |
+
RetrieveContextOutput,
|
| 23 |
+
)
|
| 24 |
+
from src.core.logger import get_logger
|
| 25 |
+
|
| 26 |
+
# Module-level logger, named after this module per the project convention.
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class Tool:
    """A single orchestrator-callable tool.

    Contract: `execute` maps a validated input-model instance to an
    output-model instance. `invoke` is the dict-in/dict-out wrapper the
    orchestrator actually calls: it validates the raw argument dict,
    runs `execute`, and serializes the result.
    """

    name: str
    description: str
    input_model: type[BaseModel]
    output_model: type[BaseModel]
    execute: Callable[[Any], BaseModel]

    def openai_schema(self) -> dict[str, Any]:
        """Render this tool as an OpenAI/OpenRouter function-calling entry."""
        raw = self.input_model.model_json_schema()
        # Some clients reject top-level $defs / title in parameter schemas,
        # so only the structural keys are forwarded. NOTE(review): nested
        # models would lose their $defs here — fine for these flat inputs.
        parameters: dict[str, Any] = {"type": "object"}
        parameters["properties"] = raw.get("properties", {})
        parameters["required"] = raw.get("required", [])
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": parameters,
        }
        return {"type": "function", "function": function_spec}

    def invoke(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate `args`, run the tool, and return the output as a plain dict.

        Raises:
            ValueError: when `args` fails input-model validation.
        """
        try:
            validated = self.input_model.model_validate(args)
        except ValidationError as e:
            raise ValueError(f"invalid input for {self.name}: {e}") from e
        return self.execute(validated).model_dump()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
# Tool implementations — thin wrappers around existing pipelines + RAG.
|
| 72 |
+
# Heavy work stays in the underlying modules; these only adapt I/O.
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _execute_bbb(inp: BBBPipelineInput) -> BBBPipelineOutput:
    """Single-SMILES BBB prediction + SHAP, delegated to the API route layer."""
    # Imported lazily so building the tool registry stays cheap.
    from src.api import routes as api_routes
    from src.api.schemas import BBBPredictRequest

    request = BBBPredictRequest(smiles=inp.smiles, top_k=inp.top_k)
    prediction = api_routes.predict_bbb(request)
    features = [feat.model_dump() for feat in prediction.top_features]
    return BBBPipelineOutput(
        smiles=inp.smiles,
        label=prediction.label,
        label_text=prediction.label_text,
        confidence=prediction.confidence,
        top_features=features,
        drift_z=prediction.drift_z,
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _execute_eeg(inp: EEGPipelineInput) -> EEGPipelineOutput:
    """EEG feature extraction, delegated to the existing `run_eeg` route."""
    # Imported lazily so building the tool registry stays cheap.
    from src.api import routes as api_routes
    from src.api.schemas import EEGRequest

    # Fixed output location; the route reports back the path it actually wrote.
    destination = Path("data/processed/eeg_features.parquet")
    request = EEGRequest(
        input_path=inp.input_path,
        output_path=str(destination),
        epoch_duration_s=inp.epoch_duration_s,
    )
    response = api_routes.run_eeg(request)
    return EEGPipelineOutput(
        input_path=inp.input_path,
        output_path=response.output_path,
        rows=response.rows,
        columns=response.columns,
        duration_sec=response.duration_sec,
    )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _execute_mri(inp: MRIPipelineInput) -> MRIPipelineOutput:
    """MRI ComBat-harmonization, delegated to the existing `run_mri` route."""
    # Imported lazily so building the tool registry stays cheap.
    from src.api import routes as api_routes
    from src.api.schemas import MRIRequest

    # Fixed output location; the route reports back the path it actually wrote.
    destination = Path("data/processed/mri_features.parquet")
    request = MRIRequest(
        input_dir=inp.input_dir,
        sites_csv=inp.sites_csv,
        output_path=str(destination),
    )
    response = api_routes.run_mri(request)
    return MRIPipelineOutput(
        input_dir=inp.input_dir,
        output_path=response.output_path,
        rows=response.rows,
        columns=response.columns,
        duration_sec=response.duration_sec,
    )
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _make_retrieve_executor(
    rag_index_dir: Path | None,
) -> Callable[[RetrieveContextInput], RetrieveContextOutput]:
    """Build the `retrieve_context` executor; the retriever loads lazily."""
    retriever: Any = None  # populated on the first call that finds an index

    def execute(inp: RetrieveContextInput) -> RetrieveContextOutput:
        nonlocal retriever
        # Degrade gracefully when no index was built: return zero chunks.
        if rag_index_dir is None or not (rag_index_dir / "index.bin").exists():
            return RetrieveContextOutput(query=inp.query, chunks=[])
        if retriever is None:
            from src.rag.retrieve import RAGRetriever

            retriever = RAGRetriever.load(rag_index_dir)
        hits = retriever.search(inp.query, k=inp.k)
        return RetrieveContextOutput(query=inp.query, chunks=hits)

    return execute
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def build_default_tools(rag_index_dir: Path | None) -> list[Tool]:
    """Assemble the default four-tool registry for the orchestrator.

    `rag_index_dir` may be None; the retrieval tool then returns empty
    results instead of failing.
    """
    bbb_tool = Tool(
        name="run_bbb_pipeline",
        description=(
            "Predict blood-brain-barrier permeability for a SINGLE SMILES "
            "string. Use this when the user input looks like a molecule "
            "(short alphanumeric string with no file extension, e.g. 'CCO', "
            "'c1ccccc1'). Returns label, confidence, top SHAP features, drift."
        ),
        input_model=BBBPipelineInput,
        output_model=BBBPipelineOutput,
        execute=_execute_bbb,
    )
    eeg_tool = Tool(
        name="run_eeg_pipeline",
        description=(
            "Run the EEG signal-processing pipeline (bandpass + ICA + "
            "epoching + feature extraction) on an EEG recording file. Use "
            "when input_path ends in .fif or .edf. Returns row/column "
            "counts + duration."
        ),
        input_model=EEGPipelineInput,
        output_model=EEGPipelineOutput,
        execute=_execute_eeg,
    )
    mri_tool = Tool(
        name="run_mri_pipeline",
        description=(
            "Run the multi-site MRI ComBat-harmonization pipeline. Use "
            "when input is a directory containing .nii.gz volumes paired "
            "with a sites.csv. Returns row/column counts + duration."
        ),
        input_model=MRIPipelineInput,
        output_model=MRIPipelineOutput,
        execute=_execute_mri,
    )
    rag_tool = Tool(
        name="retrieve_context",
        description=(
            "Retrieve up to k passages from the curated reference knowledge "
            "base. Use AFTER a pipeline tool returns, to ground your final "
            "synthesis in cited literature. Formulate a focused query "
            "based on the pipeline output (e.g., 'BBB permeability of "
            "small lipophilic molecules' or 'ComBat site harmonization')."
        ),
        input_model=RetrieveContextInput,
        output_model=RetrieveContextOutput,
        execute=_make_retrieve_executor(rag_index_dir),
    )
    return [bbb_tool, eeg_tool, mri_tool, rag_tool]
|
tests/agents/__init__.py
ADDED
|
File without changes
|
tests/agents/test_tools.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.agents.tools — Tool dataclass + registry + 4 tool wrappers."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
from src.agents.tools import (
|
| 10 |
+
Tool,
|
| 11 |
+
build_default_tools,
|
| 12 |
+
BBBPipelineInput,
|
| 13 |
+
EEGPipelineInput,
|
| 14 |
+
MRIPipelineInput,
|
| 15 |
+
RetrieveContextInput,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _DummyInput(BaseModel):
    """Minimal input model: one required field, one defaulted (optional) field."""

    x: int
    y: str = "default"  # defaulted — must NOT appear in the schema's `required`
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _DummyOutput(BaseModel):
    """Minimal output model used to check `invoke` serialization."""

    result: int
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TestTool:
    """Unit tests for the Tool dataclass: schema export + invoke round-trip."""

    @staticmethod
    def _dummy_tool(description: str) -> Tool:
        """Tool that doubles `x`; shared fixture for the tests below."""
        return Tool(
            name="dummy",
            description=description,
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )

    def test_openai_schema_shape(self) -> None:
        schema = self._dummy_tool("A dummy tool").openai_schema()
        assert schema["type"] == "function"
        function_spec = schema["function"]
        assert function_spec["name"] == "dummy"
        assert function_spec["description"] == "A dummy tool"
        params = function_spec["parameters"]
        assert params["type"] == "object"
        assert "x" in params["properties"]
        assert "x" in params["required"]
        assert "y" not in params["required"]  # has default

    def test_invoke_validates_and_returns_dict(self) -> None:
        out = self._dummy_tool("d").invoke({"x": 5})
        assert out == {"result": 10}

    def test_invoke_invalid_input_raises(self) -> None:
        with pytest.raises(ValueError, match="invalid input"):
            self._dummy_tool("d").invoke({"y": "missing-x"})
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class TestBuildDefaultTools:
    """Tests for the default registry returned by `build_default_tools`."""

    # Fix: the first test previously requested the `tmp_path` fixture but
    # never used it (the call passes rag_index_dir=None) — dropped.

    def test_default_set_has_four_tools(self) -> None:
        # Tools are only built here, never invoked, so no RAG index is needed.
        tools = build_default_tools(rag_index_dir=None)
        names = {t.name for t in tools}
        assert names == {
            "run_bbb_pipeline",
            "run_eeg_pipeline",
            "run_mri_pipeline",
            "retrieve_context",
        }

    def test_each_tool_has_pydantic_input_model(self) -> None:
        tools = build_default_tools(rag_index_dir=None)
        for t in tools:
            assert issubclass(t.input_model, BaseModel)
            assert issubclass(t.output_model, BaseModel)

    def test_input_models_have_smiles_paths(self) -> None:
        # Verify the field names the downstream system prompt depends on.
        assert "smiles" in BBBPipelineInput.model_fields
        assert "input_path" in EEGPipelineInput.model_fields
        assert "input_dir" in MRIPipelineInput.model_fields
        assert "sites_csv" in MRIPipelineInput.model_fields
        assert "query" in RetrieveContextInput.model_fields
        assert "k" in RetrieveContextInput.model_fields