"""Orchestrator agent: function-calling loop over a list of Tools. No agent framework — uses the openai SDK's chat-completions function-calling interface directly. This is the same SDK already used by src/llm/explainer.py, keeping the dependency surface minimal. Public entry: `Orchestrator(llm_client, tools, system_prompt, model).run(user_input)`. Returns an `AgentResult` with synthesized text + full tool-call trace. """ from __future__ import annotations import json from collections.abc import Callable from typing import Any from src.agents.schemas import AgentResult, ToolTraceItem from src.agents.tools import Tool from src.core.logger import get_logger logger = get_logger(__name__) WorkflowRouter = Callable[[str, dict[str, Any] | None], tuple[str, dict[str, Any]] | None] WorkflowQueryBuilder = Callable[[str, ToolTraceItem, dict[str, Any] | None], str] class Orchestrator: """Single-agent function-calling loop. Stops on (a) text response, (b) max steps.""" def __init__( self, llm_client: Any, tools: list[Tool], system_prompt: str, model: str, max_steps: int = 5, temperature: float = 0.0, enforce_workflow: bool = False, workflow_pipeline_tools: set[str] | None = None, workflow_retrieval_tool: str | None = None, workflow_router: WorkflowRouter | None = None, workflow_query_builder: WorkflowQueryBuilder | None = None, ) -> None: self._client = llm_client self._tools_by_name = {t.name: t for t in tools} self._tool_schemas = [t.openai_schema() for t in tools] self._tool_schemas_by_name = { t.name: t.openai_schema() for t in tools } self._system_prompt = system_prompt self._model = model self._max_steps = max_steps self._temperature = temperature self._enforce_workflow = enforce_workflow self._workflow_pipeline_tools = workflow_pipeline_tools or set() self._workflow_retrieval_tool = workflow_retrieval_tool self._workflow_router = workflow_router self._workflow_query_builder = workflow_query_builder def run( self, user_input: str, context: dict[str, Any] | None = None, ) -> AgentResult: messages: list[dict[str, Any]] = [ {"role": "system", "content": self._system_prompt}, {"role": "user", "content": user_input}, ] trace: list[ToolTraceItem] = [] for _step in range(self._max_steps): stage = self._workflow_stage(trace) request_kwargs = self._completion_kwargs(messages, stage) response = self._client.chat.completions.create(**request_kwargs) msg = response.choices[0].message if not getattr(msg, "tool_calls", None): if self._enforce_workflow and stage == "pipeline": if self._invoke_routed_pipeline(user_input, context, trace, messages): continue return AgentResult( text=( "Cannot identify modality. Provide a SMILES, .fif/.edf " "path, or NIfTI directory." ), trace=trace, model=self._model, finish_reason="error", ) if self._enforce_workflow and stage == "retrieve": if self._invoke_fallback_retrieval(user_input, context, trace, messages): continue return AgentResult( text="Pipeline completed, but retrieval could not be executed.", trace=trace, model=self._model, finish_reason="error", ) return AgentResult( text=(msg.content or "").strip(), trace=trace, model=self._model, finish_reason="complete", ) selected_tool_calls = self._select_tool_calls(msg.tool_calls, stage) if self._enforce_workflow and not selected_tool_calls: if stage == "pipeline": if self._invoke_routed_pipeline(user_input, context, trace, messages): continue return AgentResult( text=( "Cannot identify modality. Provide a SMILES, .fif/.edf " "path, or NIfTI directory." ), trace=trace, model=self._model, finish_reason="error", ) if stage == "retrieve": if self._invoke_fallback_retrieval(user_input, context, trace, messages): continue return AgentResult( text="Pipeline completed, but retrieval could not be executed.", trace=trace, model=self._model, finish_reason="error", ) messages.append({ "role": "assistant", "content": msg.content, "tool_calls": [tc.model_dump() for tc in selected_tool_calls], }) for tc in selected_tool_calls: name = tc.function.name tool = self._tools_by_name.get(name) if tool is None: err = f"unknown tool: {name}" trace.append(ToolTraceItem(name=name, args={}, error=err)) messages.append({ "role": "tool", "tool_call_id": tc.id, "content": json.dumps({"error": err}), }) continue try: args = json.loads(tc.function.arguments or "{}") result = tool.invoke(args) trace.append(ToolTraceItem(name=name, args=args, result=result)) messages.append({ "role": "tool", "tool_call_id": tc.id, "content": json.dumps({"result": result}, default=str), }) except Exception as e: err = str(e) trace.append(ToolTraceItem(name=name, args={}, error=err)) messages.append({ "role": "tool", "tool_call_id": tc.id, "content": json.dumps({"error": err}), }) return AgentResult( text="Max steps reached without a final answer.", trace=trace, model=self._model, finish_reason="max_steps", ) def _completion_kwargs( self, messages: list[dict[str, Any]], stage: str, ) -> dict[str, Any]: kwargs: dict[str, Any] = { "model": self._model, "messages": messages, "temperature": self._temperature, } if not self._enforce_workflow: kwargs["tools"] = self._tool_schemas kwargs["tool_choice"] = "auto" return kwargs schemas = self._schemas_for_stage(stage) if schemas: kwargs["tools"] = schemas kwargs["tool_choice"] = "auto" return kwargs def _schemas_for_stage(self, stage: str) -> list[dict[str, Any]]: if stage == "pipeline": return [ self._tool_schemas_by_name[name] for name in sorted(self._workflow_pipeline_tools) if name in self._tool_schemas_by_name ] if stage == "retrieve" and self._workflow_retrieval_tool: schema = self._tool_schemas_by_name.get(self._workflow_retrieval_tool) return [schema] if schema else [] return [] def _workflow_stage(self, trace: list[ToolTraceItem]) -> str: if not self._enforce_workflow: return "open" has_pipeline = any( t.name in self._workflow_pipeline_tools and t.result is not None and t.error is None for t in trace ) if not has_pipeline: return "pipeline" has_retrieval = any( t.name == self._workflow_retrieval_tool and t.result is not None and t.error is None for t in trace ) return "final" if has_retrieval else "retrieve" def _select_tool_calls(self, tool_calls: list[Any], stage: str) -> list[Any]: if not self._enforce_workflow: return list(tool_calls) if stage == "pipeline": for tc in tool_calls: if tc.function.name in self._workflow_pipeline_tools: return [tc] for tc in tool_calls: logger.info( "dropped out-of-stage tool call: name=%s stage=%s", tc.function.name, stage, ) return [] if stage == "retrieve": for tc in tool_calls: if tc.function.name == self._workflow_retrieval_tool: return [tc] for tc in tool_calls: logger.info( "dropped out-of-stage tool call: name=%s stage=%s", tc.function.name, stage, ) return [] for tc in tool_calls: logger.info( "dropped out-of-stage tool call: name=%s stage=%s", tc.function.name, stage, ) return [] def _invoke_routed_pipeline( self, user_input: str, context: dict[str, Any] | None, trace: list[ToolTraceItem], messages: list[dict[str, Any]], ) -> bool: if self._workflow_router is None: return False routed = self._workflow_router(user_input, context) if routed is None: return False name, args = routed tool = self._tools_by_name.get(name) if tool is None: trace.append(ToolTraceItem(name=name, args=args, error=f"unknown tool: {name}")) return False try: result = tool.invoke(args) trace.append(ToolTraceItem(name=name, args=args, result=result)) messages.append({ "role": "user", "content": ( "Workflow guard executed the required pipeline tool. " f"Tool: {name}. Result: {json.dumps(result, default=str)}. " "Now call retrieve_context with a focused scientific query." ), }) return True except Exception as e: trace.append(ToolTraceItem(name=name, args=args, error=str(e))) return False def _invoke_fallback_retrieval( self, user_input: str, context: dict[str, Any] | None, trace: list[ToolTraceItem], messages: list[dict[str, Any]], ) -> bool: if self._workflow_retrieval_tool is None or self._workflow_query_builder is None: return False pipeline_trace = next( ( t for t in trace if t.name in self._workflow_pipeline_tools and t.result is not None and t.error is None ), None, ) if pipeline_trace is None: return False tool = self._tools_by_name.get(self._workflow_retrieval_tool) if tool is None: return False query = self._workflow_query_builder(user_input, pipeline_trace, context) args = {"query": query, "k": 4} try: result = tool.invoke(args) trace.append(ToolTraceItem( name=self._workflow_retrieval_tool, args=args, result=result, )) messages.append({ "role": "user", "content": ( "Workflow guard executed retrieve_context. " f"Result: {json.dumps(result, default=str)}. " "Now synthesize the final answer in the user's language." ), }) return True except Exception as e: trace.append(ToolTraceItem( name=self._workflow_retrieval_tool, args=args, error=str(e), )) return False