fix(agents/orchestrator): log dropped out-of-stage tool calls (was silent)
Browse files- src/agents/orchestrator.py +18 -0
- tests/agents/test_orchestrator.py +41 -0
src/agents/orchestrator.py
CHANGED
|
@@ -226,12 +226,30 @@ class Orchestrator:
|
|
| 226 |
for tc in tool_calls:
|
| 227 |
if tc.function.name in self._workflow_pipeline_tools:
|
| 228 |
return [tc]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
return []
|
| 230 |
if stage == "retrieve":
|
| 231 |
for tc in tool_calls:
|
| 232 |
if tc.function.name == self._workflow_retrieval_tool:
|
| 233 |
return [tc]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
return []
|
| 236 |
|
| 237 |
def _invoke_routed_pipeline(
|
|
|
|
| 226 |
for tc in tool_calls:
|
| 227 |
if tc.function.name in self._workflow_pipeline_tools:
|
| 228 |
return [tc]
|
| 229 |
+
for tc in tool_calls:
|
| 230 |
+
logger.info(
|
| 231 |
+
"dropped out-of-stage tool call: name=%s stage=%s",
|
| 232 |
+
tc.function.name,
|
| 233 |
+
stage,
|
| 234 |
+
)
|
| 235 |
return []
|
| 236 |
if stage == "retrieve":
|
| 237 |
for tc in tool_calls:
|
| 238 |
if tc.function.name == self._workflow_retrieval_tool:
|
| 239 |
return [tc]
|
| 240 |
+
for tc in tool_calls:
|
| 241 |
+
logger.info(
|
| 242 |
+
"dropped out-of-stage tool call: name=%s stage=%s",
|
| 243 |
+
tc.function.name,
|
| 244 |
+
stage,
|
| 245 |
+
)
|
| 246 |
return []
|
| 247 |
+
for tc in tool_calls:
|
| 248 |
+
logger.info(
|
| 249 |
+
"dropped out-of-stage tool call: name=%s stage=%s",
|
| 250 |
+
tc.function.name,
|
| 251 |
+
stage,
|
| 252 |
+
)
|
| 253 |
return []
|
| 254 |
|
| 255 |
def _invoke_routed_pipeline(
|
tests/agents/test_orchestrator.py
CHANGED
|
@@ -229,3 +229,44 @@ class TestOrchestrator:
|
|
| 229 |
assert [t.name for t in result.trace] == ["run_bbb_pipeline", "retrieve_context"]
|
| 230 |
assert result.trace[0].result == {"label_text": "permeable", "confidence": 0.82}
|
| 231 |
assert result.trace[1].args["query"] == "BBB permeability of small lipophilic molecules"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
assert [t.name for t in result.trace] == ["run_bbb_pipeline", "retrieve_context"]
|
| 230 |
assert result.trace[0].result == {"label_text": "permeable", "confidence": 0.82}
|
| 231 |
assert result.trace[1].args["query"] == "BBB permeability of small lipophilic molecules"
|
| 232 |
+
|
| 233 |
+
def test_workflow_drops_out_of_stage_tool_call_with_log(
|
| 234 |
+
self, caplog: pytest.LogCaptureFixture
|
| 235 |
+
) -> None:
|
| 236 |
+
client = MagicMock()
|
| 237 |
+
client.chat.completions.create.side_effect = [
|
| 238 |
+
# During the pipeline stage the model wrongly calls retrieve_context
|
| 239 |
+
_fake_choice_with_tool_call("retrieve_context", {"query": "x", "k": 4}),
|
| 240 |
+
# After the workflow guard runs the BBB pipeline, model produces text
|
| 241 |
+
_fake_choice_with_text("Skipping retrieval."),
|
| 242 |
+
# Then the guard runs retrieve_context, model finalizes
|
| 243 |
+
_fake_choice_with_text("Final answer."),
|
| 244 |
+
]
|
| 245 |
+
orch = Orchestrator(
|
| 246 |
+
llm_client=client,
|
| 247 |
+
tools=_make_workflow_tools(),
|
| 248 |
+
system_prompt="sys",
|
| 249 |
+
model="stub-model",
|
| 250 |
+
max_steps=5,
|
| 251 |
+
enforce_workflow=True,
|
| 252 |
+
workflow_pipeline_tools={"run_bbb_pipeline"},
|
| 253 |
+
workflow_retrieval_tool="retrieve_context",
|
| 254 |
+
workflow_router=lambda user_input, context: (
|
| 255 |
+
"run_bbb_pipeline",
|
| 256 |
+
{"smiles": user_input},
|
| 257 |
+
),
|
| 258 |
+
workflow_query_builder=lambda user_input, pipeline_trace, context: "q",
|
| 259 |
+
)
|
| 260 |
+
from src.agents import orchestrator as orch_module
|
| 261 |
+
orch_module.logger.addHandler(caplog.handler)
|
| 262 |
+
try:
|
| 263 |
+
result = orch.run("CCO")
|
| 264 |
+
finally:
|
| 265 |
+
orch_module.logger.removeHandler(caplog.handler)
|
| 266 |
+
assert result.finish_reason == "complete"
|
| 267 |
+
assert any(
|
| 268 |
+
"dropped out-of-stage tool call" in rec.message
|
| 269 |
+
and "retrieve_context" in rec.message
|
| 270 |
+
and "stage=pipeline" in rec.message
|
| 271 |
+
for rec in caplog.records
|
| 272 |
+
), [rec.message for rec in caplog.records]
|