feat(agents): orchestrator loop (function-calling + tool trace + max-steps gate)
- src/agents/orchestrator.py +108 -0
- src/agents/prompts.py +49 -0
- tests/agents/test_orchestrator.py +161 -0
src/agents/orchestrator.py
ADDED
@@ -0,0 +1,108 @@
"""Orchestrator agent: function-calling loop over a list of Tools.

No agent framework — uses the openai SDK's chat-completions function-calling
interface directly. This is the same SDK already used by src/llm/explainer.py,
keeping the dependency surface minimal.

Public entry: `Orchestrator(llm_client, tools, system_prompt, model).run(user_input)`.
Returns an `AgentResult` with synthesized text + full tool-call trace.
"""
from __future__ import annotations

import json
from typing import Any

from src.agents.schemas import AgentResult, ToolTraceItem
from src.agents.tools import Tool
from src.core.logger import get_logger

logger = get_logger(__name__)


class Orchestrator:
    """Single-agent function-calling loop. Stops on (a) text response, (b) max steps."""

    def __init__(
        self,
        llm_client: Any,
        tools: list[Tool],
        system_prompt: str,
        model: str,
        max_steps: int = 5,
        temperature: float = 0.0,
    ) -> None:
        self._client = llm_client
        self._tools_by_name = {t.name: t for t in tools}
        self._tool_schemas = [t.openai_schema() for t in tools]
        self._system_prompt = system_prompt
        self._model = model
        self._max_steps = max_steps
        self._temperature = temperature

    def run(self, user_input: str) -> AgentResult:
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": user_input},
        ]
        trace: list[ToolTraceItem] = []

        for _step in range(self._max_steps):
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                tools=self._tool_schemas,
                tool_choice="auto",
                temperature=self._temperature,
            )
            msg = response.choices[0].message

            # No tool calls → the model has produced its final text answer.
            if not getattr(msg, "tool_calls", None):
                return AgentResult(
                    text=(msg.content or "").strip(),
                    trace=trace,
                    model=self._model,
                    finish_reason="complete",
                )

            # Echo the assistant turn (including its tool calls) into history.
            messages.append({
                "role": "assistant",
                "content": msg.content,
                "tool_calls": [tc.model_dump() for tc in msg.tool_calls],
            })

            # Execute each requested tool; feed results back as tool messages.
            for tc in msg.tool_calls:
                name = tc.function.name
                tool = self._tools_by_name.get(name)
                if tool is None:
                    err = f"unknown tool: {name}"
                    trace.append(ToolTraceItem(name=name, args={}, error=err))
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": json.dumps({"error": err}),
                    })
                    continue
                try:
                    args = json.loads(tc.function.arguments or "{}")
                    result = tool.invoke(args)
                    trace.append(ToolTraceItem(name=name, args=args, result=result))
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": json.dumps({"result": result}, default=str),
                    })
                except Exception as e:
                    err = str(e)
                    trace.append(ToolTraceItem(name=name, args={}, error=err))
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": json.dumps({"error": err}),
                    })

        return AgentResult(
            text="Max steps reached without a final answer.",
            trace=trace,
            model=self._model,
            finish_reason="max_steps",
        )
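Usage sketch (not part of the diff): the orchestrator only needs an object exposing chat.completions.create(...), so any OpenAI-compatible client works. The endpoint, key, model id, and echo tool below are illustrative stand-ins — the real pipeline tools are built elsewhere in the repo.

from openai import OpenAI
from pydantic import BaseModel

from src.agents.orchestrator import Orchestrator
from src.agents.prompts import ORCHESTRATOR_SYSTEM_PROMPT
from src.agents.tools import Tool


class EchoIn(BaseModel):
    msg: str


class EchoOut(BaseModel):
    echo: str


# Placeholder endpoint/key — any OpenAI-compatible server will do.
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key="sk-...")

orch = Orchestrator(
    llm_client=client,
    tools=[Tool(
        name="echo",
        description="Echo a string back.",
        input_model=EchoIn,
        output_model=EchoOut,
        execute=lambda inp: EchoOut(echo=inp.msg),
    )],
    system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
    model="stub-model",  # substitute a function-calling-capable model id
)

result = orch.run("Echo back the word hello.")
print(result.finish_reason, result.text)
for item in result.trace:  # full audit trail of tool calls
    print(item.name, item.args, item.result or item.error)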
src/agents/prompts.py
ADDED
@@ -0,0 +1,49 @@
"""System prompts for the orchestrator agent.

Kept in a dedicated module so prompt edits are diff-readable and reviewable
in isolation from the orchestrator loop.
"""
from __future__ import annotations


ORCHESTRATOR_SYSTEM_PROMPT = """\
You are the NeuroBridge clinical-ML orchestrator. You have four tools:

- run_bbb_pipeline(smiles, top_k=5) → for a SMILES molecular string
- run_eeg_pipeline(input_path) → for a .fif or .edf EEG file path
- run_mri_pipeline(input_dir, sites_csv) → for a directory of NIfTI MRI files
- retrieve_context(query, k=4) → for grounding chunks from the knowledge base

Workflow — follow exactly:

1. Look at the user input. Decide which ONE pipeline tool fits:
   - SMILES (short, all-letters/digits, no slashes, no .ext) → run_bbb_pipeline
   - Path ending in .fif or .edf → run_eeg_pipeline
   - Path that is a directory (no file extension at the tail) → run_mri_pipeline
   If ambiguous, prefer SMILES if it parses; otherwise return:
   "Cannot identify modality. Provide a SMILES, .fif/.edf path, or NIfTI directory."

2. Call the chosen pipeline tool exactly once with the user input.

3. After the pipeline returns, formulate ONE focused retrieval query that
   captures the scientific concept behind the prediction (NOT the raw input).
   Examples of good queries:
   - "BBB permeability of small lipophilic molecules" (after BBB predict)
   - "ICA artifact removal in multi-channel EEG" (after EEG run)
   - "ComBat scanner site harmonization in multi-center MRI" (after MRI run)
   Then call retrieve_context with that query.

4. Synthesize a final response in 3-5 sentences:
   - State the concrete pipeline result (label, confidence, key numbers).
   - Cite at least one specific fact from the retrieved chunks (mention the
     source file in parentheses, e.g. "(lipinski_rule_of_five.md)").
   - Match the user's question language: Turkish in → Turkish out, etc.
   - If retrieve_context returned 0 chunks, say so explicitly and answer
     using only the pipeline result.

Hard constraints:
- Call exactly ONE pipeline tool, then exactly ONE retrieve_context, then stop.
- Do NOT invent facts. Only use numbers from the pipeline tool output and
  text from the retrieved chunks.
- No preamble, no apologies, no meta-commentary about being an AI.
"""
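Reviewer note: the two-call budget ("exactly ONE pipeline tool, then exactly ONE retrieve_context") is enforced only by prompt text, not by the loop itself. A hypothetical post-run check on the returned trace, using the tool names above and the trace fields exercised in the tests below:

# Hypothetical helper — not part of this diff.
PIPELINE_TOOLS = {"run_bbb_pipeline", "run_eeg_pipeline", "run_mri_pipeline"}


def trace_follows_prompt(result) -> bool:
    # Expect exactly: one pipeline call, then one retrieval, in that order.
    names = [item.name for item in result.trace]
    return (
        len(names) == 2
        and names[0] in PIPELINE_TOOLS
        and names[1] == "retrieve_context"
    )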
tests/agents/test_orchestrator.py
ADDED
@@ -0,0 +1,161 @@
"""Tests for src.agents.orchestrator — agent loop with stubbed LLM client.

We do NOT hit OpenRouter here. We construct a fake client that returns
scripted tool-call responses, then verify the orchestrator dispatches
tools and assembles the trace correctly.
"""
from __future__ import annotations

import json
from typing import Any
from unittest.mock import MagicMock

import pytest
from pydantic import BaseModel

from src.agents.orchestrator import Orchestrator
from src.agents.tools import Tool


# --- Helpers ----------------------------------------------------------------


def _fake_choice_with_tool_call(name: str, args: dict[str, Any], call_id: str = "c1") -> Any:
    msg = MagicMock()
    msg.content = None
    tc = MagicMock()
    tc.id = call_id
    tc.function.name = name
    tc.function.arguments = json.dumps(args)
    tc.model_dump = MagicMock(return_value={"id": call_id, "type": "function",
                                            "function": {"name": name,
                                                         "arguments": json.dumps(args)}})
    msg.tool_calls = [tc]
    choice = MagicMock()
    choice.message = msg
    response = MagicMock()
    response.choices = [choice]
    return response


def _fake_choice_with_text(text: str) -> Any:
    msg = MagicMock()
    msg.content = text
    msg.tool_calls = None
    choice = MagicMock()
    choice.message = msg
    response = MagicMock()
    response.choices = [choice]
    return response


class _PingInput(BaseModel):
    msg: str


class _PingOutput(BaseModel):
    echo: str


def _make_ping_tool() -> Tool:
    return Tool(
        name="ping",
        description="Echo a string back.",
        input_model=_PingInput,
        output_model=_PingOutput,
        execute=lambda inp: _PingOutput(echo=f"pong:{inp.msg}"),
    )


# --- Tests ------------------------------------------------------------------


class TestOrchestrator:
    def test_single_tool_then_text_response(self) -> None:
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": "hello"}),
            _fake_choice_with_text("All done."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test input")
        assert result.text == "All done."
        assert result.finish_reason == "complete"
        assert len(result.trace) == 1
        assert result.trace[0].name == "ping"
        assert result.trace[0].args == {"msg": "hello"}
        assert result.trace[0].result == {"echo": "pong:hello"}

    def test_unknown_tool_recorded_as_error(self) -> None:
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("nonexistent_tool", {"x": 1}),
            _fake_choice_with_text("Done."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test")
        assert result.trace[0].error is not None
        assert "unknown tool" in result.trace[0].error
        assert result.text == "Done."

    def test_invalid_tool_args_recorded_as_error(self) -> None:
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"wrong_field": "x"}),
            _fake_choice_with_text("Recovered."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test")
        assert result.trace[0].error is not None
        assert result.text == "Recovered."

    def test_max_steps_exhausted_returns_finish_reason(self) -> None:
        client = MagicMock()
        # Always return another tool call — never terminates with text
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": f"{i}"}, call_id=f"c{i}")
            for i in range(10)
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=3,
        )
        result = orch.run("test")
        assert result.finish_reason == "max_steps"
        assert len(result.trace) == 3

    def test_first_response_is_text_no_tools(self) -> None:
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_text("Direct answer."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
        )
        result = orch.run("trivial input")
        assert result.text == "Direct answer."
        assert result.trace == []
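The suite is fully offline (the LLM client is a MagicMock), so it can be driven straight from Python as well as from the CLI; a minimal runner sketch, assuming the repo root as the working directory:

import pytest

# Runs only this module's tests; exit code 0 means all passed.
raise SystemExit(pytest.main(["tests/agents/test_orchestrator.py", "-q"]))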