"""Tests for src.agents.orchestrator — agent loop with stubbed LLM client.

We do NOT hit OpenRouter here. We construct a fake client that returns
scripted tool-call responses, then verify the orchestrator dispatches
tools and assembles the trace correctly.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from src.agents.orchestrator import Orchestrator
from src.agents.tools import Tool
# --- Helpers ----------------------------------------------------------------
def _fake_choice_with_tool_call(name: str, args: dict[str, Any], call_id: str = "c1") -> Any:
msg = MagicMock()
msg.content = None
tc = MagicMock()
tc.id = call_id
tc.function.name = name
tc.function.arguments = json.dumps(args)
tc.model_dump = MagicMock(return_value={"id": call_id, "type": "function",
"function": {"name": name,
"arguments": json.dumps(args)}})
msg.tool_calls = [tc]
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
return response
def _fake_choice_with_text(text: str) -> Any:
msg = MagicMock()
msg.content = text
msg.tool_calls = None
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
return response
# Input schema for the ``ping`` test tool: the single string to echo back.
# (Documented with comments rather than a class docstring so the model's
# generated JSON schema stays unchanged.)
class _PingInput(BaseModel):
    msg: str
# Output schema for the ``ping`` test tool; ``_make_ping_tool`` sets ``echo``
# to ``"pong:<msg>"``.
class _PingOutput(BaseModel):
    echo: str
def _make_ping_tool() -> Tool:
    """Return a ``ping`` tool whose executor echoes ``msg`` back as ``"pong:<msg>"``."""

    def _echo(inp: _PingInput) -> _PingOutput:
        return _PingOutput(echo=f"pong:{inp.msg}")

    return Tool(
        name="ping",
        description="Echo a string back.",
        input_model=_PingInput,
        output_model=_PingOutput,
        execute=_echo,
    )
# Input schema for the stub ``run_bbb_pipeline`` tool. In the workflow tests
# the router passes the raw user input (e.g. "CCO") as ``smiles``.
# (Comments instead of class docstrings keep the generated JSON schemas unchanged.)
class _BBBInput(BaseModel):
    smiles: str
# Output schema for the stub ``run_bbb_pipeline`` tool; the stub always returns
# ``label_text="permeable"`` with ``confidence=0.82``.
class _BBBOutput(BaseModel):
    label_text: str
    confidence: float
# Input schema for the stub ``retrieve_context`` tool.
class _RetrieveInput(BaseModel):
    query: str
    k: int = 4  # presumably the number of chunks to return; the stub ignores it
# Output schema for the stub ``retrieve_context`` tool: a list of free-form
# chunk dicts (the stub returns one with ``source`` and ``text`` keys).
class _RetrieveOutput(BaseModel):
    chunks: list[dict[str, Any]]
def _make_workflow_tools() -> list[Tool]:
    """Return stub ``run_bbb_pipeline`` and ``retrieve_context`` tools, in that order.

    Both executors ignore their input. The pipeline tool always reports
    ``permeable`` at confidence 0.82, and the retrieval tool always returns a
    single ``lipinski.md`` chunk.
    """

    def _run_bbb(inp: _BBBInput) -> _BBBOutput:
        return _BBBOutput(label_text="permeable", confidence=0.82)

    def _retrieve(inp: _RetrieveInput) -> _RetrieveOutput:
        return _RetrieveOutput(
            chunks=[{"source": "lipinski.md", "text": "BBB context"}]
        )

    pipeline_tool = Tool(
        name="run_bbb_pipeline",
        description="Run BBB.",
        input_model=_BBBInput,
        output_model=_BBBOutput,
        execute=_run_bbb,
    )
    retrieval_tool = Tool(
        name="retrieve_context",
        description="Retrieve context.",
        input_model=_RetrieveInput,
        output_model=_RetrieveOutput,
        execute=_retrieve,
    )
    return [pipeline_tool, retrieval_tool]
# --- Tests ------------------------------------------------------------------
class TestOrchestrator:
    """Behavioural tests for ``Orchestrator`` driven by a scripted fake LLM client.

    Each test queues fake chat-completion responses on a ``MagicMock`` client.
    They are consumed in order via ``side_effect``. Each test then inspects the
    ``run()`` result: the final ``text``, the ``finish_reason`` and the
    per-tool-call ``trace``.
    """

    # --- Helpers (not collected by pytest: no ``test_`` prefix) ------------

    @staticmethod
    def _scripted_client(*responses: Any) -> MagicMock:
        """Return a fake LLM client whose ``chat.completions.create`` yields *responses* in order."""
        client = MagicMock()
        client.chat.completions.create.side_effect = list(responses)
        return client

    @staticmethod
    def _build(client: MagicMock, tools: list[Tool], **kwargs: Any) -> Orchestrator:
        """Construct an ``Orchestrator`` with the stub system prompt and model name.

        Extra keyword arguments (``max_steps``, workflow options, ...) are passed
        through unchanged. Omitted ones fall back to the ``Orchestrator`` defaults.
        """
        return Orchestrator(
            llm_client=client,
            tools=tools,
            system_prompt="sys",
            model="stub-model",
            **kwargs,
        )

    @classmethod
    def _build_workflow(cls, client: MagicMock, query: str) -> Orchestrator:
        """Construct an ``Orchestrator`` that enforces the pipeline -> retrieval workflow.

        The router always sends the raw user input to ``run_bbb_pipeline`` as
        ``smiles``. The query builder always returns *query*.
        """
        return cls._build(
            client,
            _make_workflow_tools(),
            max_steps=5,
            enforce_workflow=True,
            workflow_pipeline_tools={"run_bbb_pipeline"},
            workflow_retrieval_tool="retrieve_context",
            workflow_router=lambda user_input, context: (
                "run_bbb_pipeline",
                {"smiles": user_input},
            ),
            workflow_query_builder=lambda user_input, pipeline_trace, context: query,
        )

    # --- Tests ---------------------------------------------------------------

    def test_single_tool_then_text_response(self) -> None:
        client = self._scripted_client(
            _fake_choice_with_tool_call("ping", {"msg": "hello"}),
            _fake_choice_with_text("All done."),
        )
        orch = self._build(client, [_make_ping_tool()], max_steps=4)

        result = orch.run("test input")

        assert result.text == "All done."
        assert result.finish_reason == "complete"
        assert len(result.trace) == 1
        assert result.trace[0].name == "ping"
        assert result.trace[0].args == {"msg": "hello"}
        assert result.trace[0].result == {"echo": "pong:hello"}

    def test_unknown_tool_recorded_as_error(self) -> None:
        client = self._scripted_client(
            _fake_choice_with_tool_call("nonexistent_tool", {"x": 1}),
            _fake_choice_with_text("Done."),
        )
        orch = self._build(client, [_make_ping_tool()], max_steps=4)

        result = orch.run("test")

        assert result.trace[0].error is not None
        assert "unknown tool" in result.trace[0].error
        assert result.text == "Done."

    def test_invalid_tool_args_recorded_as_error(self) -> None:
        # ``ping`` requires ``msg``; sending only ``wrong_field`` must fail validation.
        client = self._scripted_client(
            _fake_choice_with_tool_call("ping", {"wrong_field": "x"}),
            _fake_choice_with_text("Recovered."),
        )
        orch = self._build(client, [_make_ping_tool()], max_steps=4)

        result = orch.run("test")

        assert result.trace[0].error is not None
        assert result.text == "Recovered."

    def test_max_steps_exhausted_returns_finish_reason(self) -> None:
        # Always return another tool call — the model never terminates with text.
        client = self._scripted_client(
            *(
                _fake_choice_with_tool_call("ping", {"msg": str(i)}, call_id=f"c{i}")
                for i in range(10)
            )
        )
        orch = self._build(client, [_make_ping_tool()], max_steps=3)

        result = orch.run("test")

        assert result.finish_reason == "max_steps"
        assert len(result.trace) == 3

    def test_first_response_is_text_no_tools(self) -> None:
        client = self._scripted_client(_fake_choice_with_text("Direct answer."))
        # ``max_steps`` deliberately omitted to exercise the Orchestrator default.
        orch = self._build(client, [_make_ping_tool()])

        result = orch.run("trivial input")

        assert result.text == "Direct answer."
        assert result.trace == []

    def test_enforced_workflow_falls_back_when_model_skips_tool_calls(self) -> None:
        # The model never requests a tool, so both workflow tools must appear in the
        # trace without having been asked for by the model.
        client = self._scripted_client(
            _fake_choice_with_text("I will answer directly."),
            _fake_choice_with_text("Still no retrieval."),
            _fake_choice_with_text("Grounded final answer."),
        )
        orch = self._build_workflow(
            client, query="BBB permeability of small lipophilic molecules"
        )

        result = orch.run("CCO")

        assert result.finish_reason == "complete"
        assert result.text == "Grounded final answer."
        assert [t.name for t in result.trace] == ["run_bbb_pipeline", "retrieve_context"]
        assert result.trace[0].result == {"label_text": "permeable", "confidence": 0.82}
        assert result.trace[1].args["query"] == "BBB permeability of small lipophilic molecules"

    def test_workflow_drops_out_of_stage_tool_call_with_log(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        from src.agents import orchestrator as orch_module

        client = self._scripted_client(
            # During the pipeline stage the model wrongly calls retrieve_context.
            _fake_choice_with_tool_call("retrieve_context", {"query": "x", "k": 4}),
            # After the workflow guard runs the BBB pipeline, the model produces text.
            _fake_choice_with_text("Skipping retrieval."),
            # Then the guard runs retrieve_context and the model finalizes.
            _fake_choice_with_text("Final answer."),
        )
        orch = self._build_workflow(client, query="q")

        # Lower BOTH the module logger and the capture handler to INFO. caplog
        # restores both at teardown. Setting only the handler level (as before)
        # left INFO records subject to the logger's effective level, i.e. to
        # whatever global logging config happened to be active.
        caplog.set_level(logging.INFO, logger=orch_module.logger.name)
        # Also attach the handler directly, in case the module logger does not
        # propagate to root. If it does propagate, a record may be captured
        # twice, which the any() check below tolerates.
        orch_module.logger.addHandler(caplog.handler)
        try:
            result = orch.run("CCO")
        finally:
            orch_module.logger.removeHandler(caplog.handler)

        # getMessage() does not depend on a formatter having populated ``.message``.
        messages = [rec.getMessage() for rec in caplog.records]
        assert result.finish_reason == "complete"
        assert any(
            "dropped out-of-stage tool call" in msg
            and "retrieve_context" in msg
            and "stage=pipeline" in msg
            for msg in messages
        ), messages
|