# hackathon/tests/agents/test_orchestrator.py
# Author: mekosotto
# Commit 8418e43 — test(agents/orchestrator): pin caplog handler to INFO
# (hardens against pytest variance)
"""Tests for src.agents.orchestrator — agent loop with stubbed LLM client.
We do NOT hit OpenRouter here. We construct a fake client that returns
scripted tool-call responses, then verify the orchestrator dispatches
tools and assembles the trace correctly.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from src.agents.orchestrator import Orchestrator
from src.agents.tools import Tool
# --- Helpers ----------------------------------------------------------------
def _fake_choice_with_tool_call(name: str, args: dict[str, Any], call_id: str = "c1") -> Any:
msg = MagicMock()
msg.content = None
tc = MagicMock()
tc.id = call_id
tc.function.name = name
tc.function.arguments = json.dumps(args)
tc.model_dump = MagicMock(return_value={"id": call_id, "type": "function",
"function": {"name": name,
"arguments": json.dumps(args)}})
msg.tool_calls = [tc]
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
return response
def _fake_choice_with_text(text: str) -> Any:
msg = MagicMock()
msg.content = text
msg.tool_calls = None
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
return response
class _PingInput(BaseModel):
    """Input schema for the test-only ``ping`` tool."""
    msg: str
class _PingOutput(BaseModel):
    """Output schema for the test-only ``ping`` tool."""
    echo: str
def _make_ping_tool() -> Tool:
    """Construct the echo tool used throughout the orchestrator tests."""

    def _echo(inp: _PingInput) -> _PingOutput:
        return _PingOutput(echo=f"pong:{inp.msg}")

    return Tool(
        name="ping",
        description="Echo a string back.",
        input_model=_PingInput,
        output_model=_PingOutput,
        execute=_echo,
    )
class _BBBInput(BaseModel):
    """Input schema for the stub BBB-pipeline tool."""
    smiles: str
class _BBBOutput(BaseModel):
    """Output schema for the stub BBB-pipeline tool."""
    label_text: str
    confidence: float
class _RetrieveInput(BaseModel):
    """Input schema for the stub retrieval tool."""
    query: str
    k: int = 4
class _RetrieveOutput(BaseModel):
    """Output schema for the stub retrieval tool."""
    chunks: list[dict[str, Any]]
def _make_workflow_tools() -> list[Tool]:
    """Construct the pipeline + retrieval tool pair used by workflow tests."""
    bbb_tool = Tool(
        name="run_bbb_pipeline",
        description="Run BBB.",
        input_model=_BBBInput,
        output_model=_BBBOutput,
        execute=lambda inp: _BBBOutput(label_text="permeable", confidence=0.82),
    )
    retrieval_tool = Tool(
        name="retrieve_context",
        description="Retrieve context.",
        input_model=_RetrieveInput,
        output_model=_RetrieveOutput,
        execute=lambda inp: _RetrieveOutput(
            chunks=[{"source": "lipinski.md", "text": "BBB context"}]
        ),
    )
    return [bbb_tool, retrieval_tool]
# --- Tests ------------------------------------------------------------------
class TestOrchestrator:
    """Behavioral tests for ``Orchestrator.run`` driven by a scripted fake client.

    Each test loads ``client.chat.completions.create.side_effect`` with a fixed
    sequence of canned responses, so tool dispatch, trace assembly, and
    workflow enforcement are exercised without any network access.
    """
    def test_single_tool_then_text_response(self) -> None:
        """One tool call then a text reply completes with a single trace entry."""
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": "hello"}),
            _fake_choice_with_text("All done."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test input")
        assert result.text == "All done."
        assert result.finish_reason == "complete"
        assert len(result.trace) == 1
        assert result.trace[0].name == "ping"
        assert result.trace[0].args == {"msg": "hello"}
        assert result.trace[0].result == {"echo": "pong:hello"}
    def test_unknown_tool_recorded_as_error(self) -> None:
        """A call to an unregistered tool is recorded as a trace error, not a crash."""
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("nonexistent_tool", {"x": 1}),
            _fake_choice_with_text("Done."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test")
        assert result.trace[0].error is not None
        assert "unknown tool" in result.trace[0].error
        assert result.text == "Done."
    def test_invalid_tool_args_recorded_as_error(self) -> None:
        """Arguments failing the tool's input-model validation surface as a trace error."""
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"wrong_field": "x"}),
            _fake_choice_with_text("Recovered."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=4,
        )
        result = orch.run("test")
        assert result.trace[0].error is not None
        assert result.text == "Recovered."
    def test_max_steps_exhausted_returns_finish_reason(self) -> None:
        """A model that never emits text stops after max_steps with reason 'max_steps'."""
        client = MagicMock()
        # Always return another tool call — never terminates with text
        client.chat.completions.create.side_effect = [
            _fake_choice_with_tool_call("ping", {"msg": f"{i}"}, call_id=f"c{i}")
            for i in range(10)
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
            max_steps=3,
        )
        result = orch.run("test")
        assert result.finish_reason == "max_steps"
        assert len(result.trace) == 3
    def test_first_response_is_text_no_tools(self) -> None:
        """A direct text answer on the first turn yields an empty trace."""
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_text("Direct answer."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=[_make_ping_tool()],
            system_prompt="sys",
            model="stub-model",
        )
        result = orch.run("trivial input")
        assert result.text == "Direct answer."
        assert result.trace == []
    def test_enforced_workflow_falls_back_when_model_skips_tool_calls(self) -> None:
        """With enforce_workflow, skipped tool calls are injected by the guard.

        The scripted model never requests a tool; the workflow router and
        query builder must still force the pipeline and retrieval steps, in
        that order, before the final text is accepted.
        """
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            _fake_choice_with_text("I will answer directly."),
            _fake_choice_with_text("Still no retrieval."),
            _fake_choice_with_text("Grounded final answer."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=_make_workflow_tools(),
            system_prompt="sys",
            model="stub-model",
            max_steps=5,
            enforce_workflow=True,
            workflow_pipeline_tools={"run_bbb_pipeline"},
            workflow_retrieval_tool="retrieve_context",
            workflow_router=lambda user_input, context: (
                "run_bbb_pipeline",
                {"smiles": user_input},
            ),
            workflow_query_builder=lambda user_input, pipeline_trace, context: (
                "BBB permeability of small lipophilic molecules"
            ),
        )
        result = orch.run("CCO")
        assert result.finish_reason == "complete"
        assert result.text == "Grounded final answer."
        assert [t.name for t in result.trace] == ["run_bbb_pipeline", "retrieve_context"]
        assert result.trace[0].result == {"label_text": "permeable", "confidence": 0.82}
        assert result.trace[1].args["query"] == "BBB permeability of small lipophilic molecules"
    def test_workflow_drops_out_of_stage_tool_call_with_log(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A tool call made in the wrong workflow stage is dropped and logged."""
        client = MagicMock()
        client.chat.completions.create.side_effect = [
            # During the pipeline stage the model wrongly calls retrieve_context
            _fake_choice_with_tool_call("retrieve_context", {"query": "x", "k": 4}),
            # After the workflow guard runs the BBB pipeline, model produces text
            _fake_choice_with_text("Skipping retrieval."),
            # Then the guard runs retrieve_context, model finalizes
            _fake_choice_with_text("Final answer."),
        ]
        orch = Orchestrator(
            llm_client=client,
            tools=_make_workflow_tools(),
            system_prompt="sys",
            model="stub-model",
            max_steps=5,
            enforce_workflow=True,
            workflow_pipeline_tools={"run_bbb_pipeline"},
            workflow_retrieval_tool="retrieve_context",
            workflow_router=lambda user_input, context: (
                "run_bbb_pipeline",
                {"smiles": user_input},
            ),
            workflow_query_builder=lambda user_input, pipeline_trace, context: "q",
        )
        from src.agents import orchestrator as orch_module
        # Attach caplog's handler directly to the module logger and pin it to
        # INFO — presumably the orchestrator logger does not reliably
        # propagate to the root logger under every pytest configuration
        # (TODO confirm), so caplog's default wiring would be fragile here.
        caplog.handler.setLevel(logging.INFO)
        orch_module.logger.addHandler(caplog.handler)
        try:
            result = orch.run("CCO")
        finally:
            orch_module.logger.removeHandler(caplog.handler)
        assert result.finish_reason == "complete"
        assert any(
            "dropped out-of-stage tool call" in rec.message
            and "retrieve_context" in rec.message
            and "stage=pipeline" in rec.message
            for rec in caplog.records
        ), [rec.message for rec in caplog.records]