File size: 5,880 Bytes
460fcc2
 
 
 
c0a7163
 
460fcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327b23d
460fcc2
 
 
 
 
 
 
 
4fff9d2
327b23d
 
460fcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
c0a7163
460fcc2
 
6d2aa47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327b23d
 
6d2aa47
 
 
 
 
 
 
 
 
 
 
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Tests for src.agents.tools — Tool dataclass + registry + 4 tool wrappers."""
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from src.agents.tools import (
    Tool,
    build_default_tools,
    BBBPipelineInput,
    EEGPipelineInput,
    MRIPipelineInput,
    RetrieveContextInput,
)


class _DummyInput(BaseModel):
    """Minimal test input model: required ``x`` and optional ``y`` (defaults to ``"default"``)."""

    x: int
    y: str = "default"


class _DummyOutput(BaseModel):
    """Minimal test output model carrying a single integer ``result``."""

    result: int


class TestTool:
    """Unit tests for the generic ``Tool`` wrapper: schema export and ``invoke``."""

    @staticmethod
    def _doubling_tool(description: str = "d") -> Tool:
        """Build a ``Tool`` named ``dummy`` over the dummy models whose executor doubles ``x``."""
        return Tool(
            name="dummy",
            description=description,
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )

    def test_openai_schema_shape(self) -> None:
        schema = self._doubling_tool(description="A dummy tool").openai_schema()
        assert schema["type"] == "function"

        function_spec = schema["function"]
        assert function_spec["name"] == "dummy"
        assert function_spec["description"] == "A dummy tool"

        parameters = function_spec["parameters"]
        assert parameters["type"] == "object"
        assert "x" in parameters["properties"]
        # ``x`` has no default so it must be required; ``y`` has one so it must not be.
        assert "x" in parameters["required"]
        assert "y" not in parameters["required"]

    def test_invoke_validates_and_returns_dict(self) -> None:
        doubled = self._doubling_tool().invoke({"x": 5})
        assert doubled == {"result": 10}

    def test_invoke_invalid_input_raises(self) -> None:
        tool = self._doubling_tool()
        # Payload lacks the required ``x`` field, so validation must reject it.
        with pytest.raises(ValueError, match="invalid input"):
            tool.invoke({"y": "missing-x"})


class TestBuildDefaultTools:
    """Tests for ``build_default_tools``: registry contents, input schemas, and executors."""

    # Exact set of tool names the default registry is expected to expose.
    _EXPECTED_TOOL_NAMES = frozenset({
        "run_bbb_pipeline",
        "run_eeg_pipeline",
        "run_mri_pipeline",
        "retrieve_context",
        "run_fusion",
        "compute_bbb_leakage_score",
        "adjust_drug_dose",
    })

    @staticmethod
    def _tool(tools, name: str) -> Tool:
        """Return the single tool called *name*.

        Fails with a readable message if the tool is missing or registered
        more than once, instead of surfacing a bare ``StopIteration``.
        """
        matches = [t for t in tools if t.name == name]
        assert len(matches) == 1, (
            f"expected exactly one tool named {name!r}, found {len(matches)}"
        )
        return matches[0]

    def test_default_set_has_seven_tools(self) -> None:
        # build with placeholder paths; tools won't be invoked here
        tools = build_default_tools(rag_index_dir=None)
        names = [t.name for t in tools]
        # A set comparison alone would hide a duplicated name, which would make
        # lookup-by-name ambiguous; check uniqueness explicitly first.
        assert len(names) == len(set(names)), f"duplicate tool names: {names}"
        assert set(names) == self._EXPECTED_TOOL_NAMES

    def test_each_tool_has_pydantic_input_model(self) -> None:
        tools = build_default_tools(rag_index_dir=None)
        for t in tools:
            assert issubclass(t.input_model, BaseModel)
            assert issubclass(t.output_model, BaseModel)

    def test_input_models_have_smiles_paths(self) -> None:
        # verify the field names downstream system prompt depends on
        assert "smiles" in BBBPipelineInput.model_fields
        assert "input_path" in EEGPipelineInput.model_fields
        assert "input_dir" in MRIPipelineInput.model_fields
        assert "sites_csv" in MRIPipelineInput.model_fields
        # sites_csv must be optional in the JSON schema exposed to the model.
        assert "sites_csv" not in MRIPipelineInput.model_json_schema().get("required", [])
        assert "query" in RetrieveContextInput.model_fields
        assert "k" in RetrieveContextInput.model_fields

    def test_retrieve_context_short_circuits_when_no_index(self) -> None:
        tools = build_default_tools(rag_index_dir=None)
        retrieve = self._tool(tools, "retrieve_context")
        out = retrieve.invoke({"query": "anything", "k": 3})
        assert out == {"query": "anything", "chunks": []}

    def test_processed_dir_parameter_threads_to_executors(self, tmp_path: Path) -> None:
        # build_default_tools should accept processed_dir; executors should
        # eventually write under it (we don't invoke the pipelines here, just
        # verify the parameter is accepted and tools are built).
        tools = build_default_tools(rag_index_dir=None, processed_dir=tmp_path)
        names = {t.name for t in tools}
        assert "run_eeg_pipeline" in names
        assert "run_mri_pipeline" in names

    def test_default_processed_dir_when_omitted(self) -> None:
        # backwards-compat: omitting processed_dir keeps existing behavior
        tools = build_default_tools(rag_index_dir=None)
        # just ensure no exception and the full tool set is returned
        assert len(tools) == len(self._EXPECTED_TOOL_NAMES)

    def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
        # NOTE(review): imported locally, presumably so fastapi is only needed
        # when this test runs — confirm before hoisting to module level.
        from fastapi import HTTPException

        tools = build_default_tools(rag_index_dir=None)
        bbb = self._tool(tools, "run_bbb_pipeline")

        with patch("src.api.routes.predict_bbb",
                   side_effect=HTTPException(status_code=503, detail="model missing")):
            with pytest.raises(ValueError, match="bbb tool failed"):
                bbb.invoke({"smiles": "CCO"})

    def test_mri_executor_defaults_sites_csv_to_input_dir_sites_csv(self, tmp_path: Path) -> None:
        tools = build_default_tools(rag_index_dir=None, processed_dir=tmp_path / "processed")
        mri = self._tool(tools, "run_mri_pipeline")
        input_dir = tmp_path / "mri"
        input_dir.mkdir()

        with patch(
            "src.api.routes.run_mri",
            return_value=SimpleNamespace(
                output_path=str(tmp_path / "processed" / "mri_features.parquet"),
                rows=2,
                columns=3,
                duration_sec=0.1,
            ),
        ) as run_mri:
            out = mri.invoke({"input_dir": str(input_dir)})

        assert out["rows"] == 2
        # The executor is expected to pass a request object as the first positional arg.
        req = run_mri.call_args.args[0]
        assert req.input_dir == str(input_dir)
        # With no sites_csv supplied, it should default to <input_dir>/sites.csv.
        assert req.sites_csv == str(input_dir / "sites.csv")