mekosotto Claude Sonnet 4.6 committed on
Commit
2091a1b
·
1 Parent(s): d3e290f

feat(agents): orchestrator loop (function-calling + tool trace + max-steps gate)

Browse files
src/agents/orchestrator.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Orchestrator agent: function-calling loop over a list of Tools.
2
+
3
+ No agent framework — uses the openai SDK's chat-completions function-calling
4
+ interface directly. This is the same SDK already used by src/llm/explainer.py,
5
+ keeping the dependency surface minimal.
6
+
7
+ Public entry: `Orchestrator(llm_client, tools, system_prompt, model).run(user_input)`.
8
+ Returns an `AgentResult` with synthesized text + full tool-call trace.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ from typing import Any
14
+
15
+ from src.agents.schemas import AgentResult, ToolTraceItem
16
+ from src.agents.tools import Tool
17
+ from src.core.logger import get_logger
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
class Orchestrator:
    """Single-agent function-calling loop. Stops on (a) text response, (b) max steps."""

    def __init__(
        self,
        llm_client: Any,
        tools: list[Tool],
        system_prompt: str,
        model: str,
        max_steps: int = 5,
        temperature: float = 0.0,
    ) -> None:
        """Wire up the loop.

        Args:
            llm_client: openai-SDK-compatible client exposing
                ``chat.completions.create``.
            tools: Tools the model may call; names are assumed unique
                (later duplicates silently win in the lookup table).
            system_prompt: System message prepended to every run.
            model: Model identifier passed through to the API.
            max_steps: Hard ceiling on LLM round-trips per ``run``.
            temperature: Sampling temperature forwarded to the API.
        """
        self._client = llm_client
        self._tools_by_name = {t.name: t for t in tools}
        self._tool_schemas = [t.openai_schema() for t in tools]
        self._system_prompt = system_prompt
        self._model = model
        self._max_steps = max_steps
        self._temperature = temperature

    def run(self, user_input: str) -> AgentResult:
        """Drive the function-calling loop for ``user_input``.

        Returns:
            AgentResult with ``finish_reason="complete"`` when the model
            produces a plain-text message, or ``finish_reason="max_steps"``
            when the step budget is exhausted without one. ``trace`` holds
            every attempted tool call (successes and failures).
        """
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": user_input},
        ]
        trace: list[ToolTraceItem] = []

        for _step in range(self._max_steps):
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                tools=self._tool_schemas,
                tool_choice="auto",
                temperature=self._temperature,
            )
            msg = response.choices[0].message

            # A message without tool calls is the model's final answer.
            if not getattr(msg, "tool_calls", None):
                return AgentResult(
                    text=(msg.content or "").strip(),
                    trace=trace,
                    model=self._model,
                    finish_reason="complete",
                )

            # Echo the assistant turn back so the follow-up request carries
            # the tool_call ids the tool messages below reference.
            messages.append({
                "role": "assistant",
                "content": msg.content,
                "tool_calls": [tc.model_dump() for tc in msg.tool_calls],
            })

            for tc in msg.tool_calls:
                self._execute_tool_call(tc, messages, trace)

        return AgentResult(
            text="Max steps reached without a final answer.",
            trace=trace,
            model=self._model,
            finish_reason="max_steps",
        )

    def _execute_tool_call(
        self,
        tc: Any,
        messages: list[dict[str, Any]],
        trace: list[ToolTraceItem],
    ) -> None:
        """Dispatch one tool call, appending to both ``trace`` and ``messages``.

        Errors are reported back to the model as a tool message instead of
        raising, so the loop can recover on the next step.
        """
        name = tc.function.name
        tool = self._tools_by_name.get(name)
        if tool is None:
            logger.warning("model requested unknown tool %r", name)
            self._record_failure(tc, name, {}, f"unknown tool: {name}", trace, messages)
            return

        # Parse arguments separately from execution so a malformed-JSON error
        # is distinguishable from a tool failure, and so a tool failure can
        # still record the args the tool was actually invoked with.
        try:
            args = json.loads(tc.function.arguments or "{}")
        except json.JSONDecodeError as e:
            logger.warning("malformed arguments for tool %r: %s", name, e)
            self._record_failure(tc, name, {}, str(e), trace, messages)
            return

        try:
            result = tool.invoke(args)
        except Exception as e:  # tool errors are fed back to the model, not raised
            logger.exception("tool %r failed", name)
            # Record the parsed args so the failing input is preserved in the
            # trace (previously {} was recorded here, losing it).
            self._record_failure(tc, name, args, str(e), trace, messages)
            return

        trace.append(ToolTraceItem(name=name, args=args, result=result))
        messages.append({
            "role": "tool",
            "tool_call_id": tc.id,
            # default=str lets non-JSON-native result values degrade to text.
            "content": json.dumps({"result": result}, default=str),
        })

    @staticmethod
    def _record_failure(
        tc: Any,
        name: str,
        args: dict[str, Any],
        err: str,
        trace: list[ToolTraceItem],
        messages: list[dict[str, Any]],
    ) -> None:
        """Append a failed tool call to the trace and as a tool message."""
        trace.append(ToolTraceItem(name=name, args=args, error=err))
        messages.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": json.dumps({"error": err}),
        })
src/agents/prompts.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """System prompts for the orchestrator agent.
2
+
3
+ Kept in a dedicated module so prompt edits are diff-readable and reviewable
4
+ in isolation from the orchestrator loop.
5
+ """
6
+ from __future__ import annotations
7
+
8
+
9
+ ORCHESTRATOR_SYSTEM_PROMPT = """\
10
+ You are the NeuroBridge clinical-ML orchestrator. You have four tools:
11
+
12
+ - run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
13
+ - run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
14
+ - run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
15
+ - retrieve_context(query, k=4) → for grounding chunks from the knowledge base
16
+
17
+ Workflow — follow exactly:
18
+
19
+ 1. Look at the user input. Decide which ONE pipeline tool fits:
20
+ - SMILES (short, all-letters/digits, no slashes, no .ext) → run_bbb_pipeline
21
+ - Path ending in .fif or .edf → run_eeg_pipeline
22
+ - Path that is a directory (no file extension at the tail) → run_mri_pipeline
23
+ If ambiguous, prefer SMILES if it parses; otherwise return:
24
+ "Cannot identify modality. Provide a SMILES, .fif/.edf path, or NIfTI directory."
25
+
26
+ 2. Call the chosen pipeline tool exactly once with the user input.
27
+
28
+ 3. After the pipeline returns, formulate ONE focused retrieval query that
29
+ captures the scientific concept behind the prediction (NOT the raw input).
30
+ Examples of good queries:
31
+ - "BBB permeability of small lipophilic molecules" (after BBB predict)
32
+ - "ICA artifact removal in multi-channel EEG" (after EEG run)
33
+ - "ComBat scanner site harmonization in multi-center MRI" (after MRI run)
34
+ Then call retrieve_context with that query.
35
+
36
+ 4. Synthesize a final response in 3-5 sentences:
37
+ - State the concrete pipeline result (label, confidence, key numbers).
38
+ - Cite at least one specific fact from the retrieved chunks (mention the
39
+ source file in parentheses, e.g. "(lipinski_rule_of_five.md)").
40
+ - Match the user's question language: Turkish in → Turkish out, etc.
41
+ - If retrieve_context returned 0 chunks, say so explicitly and answer
42
+ using only the pipeline result.
43
+
44
+ Hard constraints:
45
+ - Call exactly ONE pipeline tool, then exactly ONE retrieve_context, then stop.
46
+ - Do NOT invent facts. Only use numbers from the pipeline tool output and
47
+ text from the retrieved chunks.
48
+ - No preamble, no apologies, no meta-commentary about being an AI.
49
+ """
tests/agents/test_orchestrator.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.agents.orchestrator — agent loop with stubbed LLM client.
2
+
3
+ We do NOT hit OpenRouter here. We construct a fake client that returns
4
+ scripted tool-call responses, then verify the orchestrator dispatches
5
+ tools and assembles the trace correctly.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from typing import Any
11
+ from unittest.mock import MagicMock
12
+
13
+ import pytest
14
+ from pydantic import BaseModel
15
+
16
+ from src.agents.orchestrator import Orchestrator
17
+ from src.agents.tools import Tool
18
+
19
+
20
+ # --- Helpers ----------------------------------------------------------------
21
+
22
+
23
+ def _fake_choice_with_tool_call(name: str, args: dict[str, Any], call_id: str = "c1") -> Any:
24
+ msg = MagicMock()
25
+ msg.content = None
26
+ tc = MagicMock()
27
+ tc.id = call_id
28
+ tc.function.name = name
29
+ tc.function.arguments = json.dumps(args)
30
+ tc.model_dump = MagicMock(return_value={"id": call_id, "type": "function",
31
+ "function": {"name": name,
32
+ "arguments": json.dumps(args)}})
33
+ msg.tool_calls = [tc]
34
+ choice = MagicMock()
35
+ choice.message = msg
36
+ response = MagicMock()
37
+ response.choices = [choice]
38
+ return response
39
+
40
+
41
+ def _fake_choice_with_text(text: str) -> Any:
42
+ msg = MagicMock()
43
+ msg.content = text
44
+ msg.tool_calls = None
45
+ choice = MagicMock()
46
+ choice.message = msg
47
+ response = MagicMock()
48
+ response.choices = [choice]
49
+ return response
50
+
51
+
52
class _PingInput(BaseModel):
    # Argument schema for the test "ping" tool.
    msg: str  # the string to echo back
54
+
55
+
56
class _PingOutput(BaseModel):
    # Result schema for the test "ping" tool.
    echo: str  # "pong:<msg>" produced by the tool's execute function
58
+
59
+
60
def _make_ping_tool() -> Tool:
    """Construct a trivial echo tool for driving the orchestrator in tests."""

    def _echo(inp: _PingInput) -> _PingOutput:
        # Deterministic transform so assertions can pin the exact result.
        return _PingOutput(echo=f"pong:{inp.msg}")

    return Tool(
        name="ping",
        description="Echo a string back.",
        input_model=_PingInput,
        output_model=_PingOutput,
        execute=_echo,
    )
68
+
69
+
70
+ # --- Tests ------------------------------------------------------------------
71
+
72
+
73
class TestOrchestrator:
    """Behavioural tests for the agent loop, driven by a scripted fake client."""

    @staticmethod
    def _agent(client: Any, **overrides: Any) -> Orchestrator:
        # Every test shares the same tool set / prompt / model; only the
        # scripted client (and occasionally max_steps) varies.
        params: dict[str, Any] = {
            "llm_client": client,
            "tools": [_make_ping_tool()],
            "system_prompt": "sys",
            "model": "stub-model",
        }
        params.update(overrides)
        return Orchestrator(**params)

    def test_single_tool_then_text_response(self) -> None:
        fake = MagicMock()
        fake.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": "hello"}),
            _fake_choice_with_text("All done."),
        ]
        result = self._agent(fake, max_steps=4).run("test input")
        assert result.text == "All done."
        assert result.finish_reason == "complete"
        assert len(result.trace) == 1
        item = result.trace[0]
        assert item.name == "ping"
        assert item.args == {"msg": "hello"}
        assert item.result == {"echo": "pong:hello"}

    def test_unknown_tool_recorded_as_error(self) -> None:
        fake = MagicMock()
        fake.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("nonexistent_tool", {"x": 1}),
            _fake_choice_with_text("Done."),
        ]
        result = self._agent(fake, max_steps=4).run("test")
        assert result.trace[0].error is not None
        assert "unknown tool" in result.trace[0].error
        assert result.text == "Done."

    def test_invalid_tool_args_recorded_as_error(self) -> None:
        fake = MagicMock()
        fake.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"wrong_field": "x"}),
            _fake_choice_with_text("Recovered."),
        ]
        result = self._agent(fake, max_steps=4).run("test")
        assert result.trace[0].error is not None
        assert result.text == "Recovered."

    def test_max_steps_exhausted_returns_finish_reason(self) -> None:
        fake = MagicMock()
        # The scripted client keeps issuing tool calls and never emits text,
        # so the loop must stop at its step ceiling.
        fake.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": str(i)}, call_id=f"c{i}")
            for i in range(10)
        ]
        result = self._agent(fake, max_steps=3).run("test")
        assert result.finish_reason == "max_steps"
        assert len(result.trace) == 3

    def test_first_response_is_text_no_tools(self) -> None:
        fake = MagicMock()
        fake.chat.completions.create.side_effect = [
            _fake_choice_with_text("Direct answer."),
        ]
        result = self._agent(fake).run("trivial input")
        assert result.text == "Direct answer."
        assert result.trace == []