"""
Multi-Step Tool Calling Orchestrator
=====================================

A production-ready orchestrator for multi-step LLM tool calling workflows.
Handles step isolation, Pydantic validation, retry logic, and error feedback.

This is the architecture pattern that makes complex tool calling reliable:
- Each step has its own isolated set of tools
- LLM responses are validated against Pydantic schemas
- Failed validations are fed back to the LLM with structured error messages
- Validation tracks whether tools PASSED, not just whether they were CALLED

Usage:
    python multi_step_orchestrator.py --url http://localhost:8000 --model NousResearch/Hermes-3-Llama-3.1-70B-FP8

Example workflow (generic):
    Step 1: Discover what components are needed (search, list, get_info tools)
    Step 2: Configure each component (get_details, validate tools)
"""
|
|
| import argparse |
| import json |
| import requests |
| from typing import Any, Callable, Dict, List, Optional |
| from pydantic import BaseModel, ValidationError |
|
|
| from robust_json_extraction import extract_json, extract_tool_calls |
| from pydantic_tool_schemas import ( |
| FunctionCall, ToolCall, StepResponse, ValidationTracker, make_tool_schema |
| ) |
|
|
|
|
| |
| |
| |
|
|
class VLLMClient:
    """Minimal client for a VLLM server's OpenAI-compatible chat endpoint."""

    def __init__(self, base_url: str, model: str):
        # Normalize the base URL so path joins never produce a double slash.
        self.base_url = base_url.rstrip('/')
        self.model = model

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None,
             temperature: float = 0.1, max_tokens: int = 4096) -> str:
        """POST one chat completion request and return the assistant's content.

        When `tools` is provided, they are attached with tool_choice="auto" so
        the model decides whether to call them. Raises requests.HTTPError on a
        non-2xx response.
        """
        body: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            body["tools"] = tools
            body["tool_choice"] = "auto"

        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=body,
            timeout=120,
        )
        resp.raise_for_status()
        # "content" may be absent on pure tool-call replies; default to "".
        return resp.json()["choices"][0]["message"].get("content", "")
|
|
|
|
| |
| |
| |
|
|
class ToolRegistry:
    """Registry of callable tools, each visible only in its configured steps."""

    def __init__(self):
        # name -> {"schema": OpenAI tool schema, "function": callable,
        #          "steps": list of step numbers, or None for every step}
        self._tools: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, description: str, parameters: dict,
                 function: Callable, steps: Optional[List[int]] = None):
        """Add a tool; `steps=None` makes it available in every step."""
        entry = {
            "schema": make_tool_schema(name, description, parameters),
            "function": function,
            "steps": steps,
        }
        self._tools[name] = entry

    def get_schemas(self, step: Optional[int] = None) -> List[Dict]:
        """Return OpenAI-format schemas for the tools visible in `step`."""
        def _visible(tool: Dict[str, Any]) -> bool:
            allowed = tool["steps"]
            return allowed is None or (step is not None and step in allowed)

        return [tool["schema"] for tool in self._tools.values() if _visible(tool)]

    def execute(self, name: str, arguments: dict) -> dict:
        """Run a registered tool; failures come back as {"error": ...} dicts."""
        entry = self._tools.get(name)
        if entry is None:
            return {"error": f"Unknown tool: {name}"}
        try:
            return entry["function"](**arguments)
        except Exception as e:
            # Any tool exception is converted into an error payload so the
            # orchestrator can feed it back to the LLM instead of crashing.
            return {"error": f"Error executing {name}: {str(e)}"}
|
|
|
|
| |
| |
| |
|
|
def run_step(
    client: VLLMClient,
    registry: ToolRegistry,
    step_num: int,
    system_prompt: str,
    initial_context: str,
    schema_class: type = StepResponse,
    max_iterations: int = 10,
    validation_tools: Optional[List[str]] = None,
) -> Optional[Dict]:
    """
    Run a single step of a multi-step workflow.

    Loops for up to `max_iterations` LLM turns. Each turn either executes
    the tool calls the model requested (feeding results back into the chat)
    or parses the model's final JSON answer and validates it against
    `schema_class`, returning it only once all required validation tools
    have passed.

    Args:
        client: VLLM client
        registry: Tool registry
        step_num: Step number (for tool filtering)
        system_prompt: System prompt for this step
        initial_context: User message / context from previous step
        schema_class: Pydantic schema for validating responses
        max_iterations: Max LLM turns before giving up
        validation_tools: Names of tools that must be called AND pass

    Returns:
        Parsed response dict, or None if step failed.
    """
    # Step isolation: only expose the tools registered for this step.
    tool_schemas = registry.get_schemas(step=step_num)

    # Tracks whether each required validation tool has been called AND passed.
    tracker = None
    if validation_tools:
        tracker = ValidationTracker(validation_tools)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_context},
    ]

    for iteration in range(1, max_iterations + 1):
        print(f"\n [Step {step_num}] Iteration {iteration}/{max_iterations}")

        completion = client.chat(messages, tools=tool_schemas if tool_schemas else None)

        # Empty completion: record it and explicitly ask the model to retry.
        if not completion:
            print(f" [Step {step_num}] Empty response")
            messages.append({"role": "assistant", "content": ""})
            messages.append({
                "role": "user",
                "content": "Your response was empty. Please provide a valid JSON response."
            })
            continue

        # Path 1: the completion contains tool calls (as detected by the
        # project's robust_json_extraction helper). Execute each one and
        # feed the results back to the model.
        tool_call_list = extract_tool_calls(completion)
        if tool_call_list:
            print(f" [Step {step_num}] Found {len(tool_call_list)} tool call(s)")
            messages.append({"role": "assistant", "content": completion})

            for tc in tool_call_list:
                print(f" -> {tc.get('name')}({json.dumps(tc.get('arguments', {}))})")
                result = registry.execute(tc["name"], tc.get("arguments", {}))

                # Record pass/fail for tools the step requires to validate.
                if tracker and tc["name"] in (validation_tools or []):
                    tracker.record_call(tc["name"], result)

                # registry.execute signals failure via an "error" key; report
                # each result back in a structured <tool_response> block.
                status = "ERROR" if "error" in result else "SUCCESS"
                response_text = (
                    f"<tool_response>\n"
                    f"<tool_name>{tc['name']}</tool_name>\n"
                    f"<status>{status}</status>\n"
                    f"<result>{json.dumps(result, indent=2)}</result>\n"
                    f"</tool_response>"
                )
                if "error" in result:
                    response_text += (
                        "\nIMPORTANT: This tool call failed. "
                        "Read the error, understand the issue, fix your parameters, and retry."
                    )

                messages.append({"role": "user", "content": response_text})
            continue

        # Path 2: no tool-call markup found -- treat the completion as a
        # final JSON response and validate it against the step schema.
        try:
            json_data = extract_json(completion)
            schema_class(**json_data)  # raises ValidationError on mismatch
            result_data = json_data

            # The schema itself may carry a "tool_calls" list (JSON-style
            # tool calling); execute those and keep iterating.
            if result_data.get("tool_calls"):
                messages.append({"role": "assistant", "content": completion})
                for tc in result_data["tool_calls"]:
                    result = registry.execute(tc["name"], tc.get("arguments", {}))
                    if tracker and tc["name"] in (validation_tools or []):
                        tracker.record_call(tc["name"], result)
                    messages.append({
                        "role": "user",
                        "content": f"Tool result for {tc['name']}: {json.dumps(result)}"
                    })
                continue

            # Final answer received, but required validations have not all
            # passed: push the tracker's error summary back and loop again.
            if tracker and not tracker.all_passed():
                error_feedback = tracker.format_errors()
                enforcement_msg = (
                    f"You returned a final response but validations have not all passed.\n"
                    f"{error_feedback}\n"
                    f"Please fix the errors and call the validation tools again."
                )
                messages.append({"role": "assistant", "content": completion})
                messages.append({"role": "user", "content": enforcement_msg})
                continue

            print(f" [Step {step_num}] Final response received and validated")
            return result_data

        except (json.JSONDecodeError, ValidationError) as e:
            # Unparseable or schema-invalid JSON: feed the structured error
            # back so the model can correct its format on the next turn.
            print(f" [Step {step_num}] Parse/validation error: {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({
                "role": "user",
                "content": (
                    f"Your response could not be parsed as valid JSON. Error: {str(e)}\n"
                    f"Please respond with ONLY valid JSON matching the required schema. "
                    f"No markdown, no explanatory text."
                )
            })

    print(f" [Step {step_num}] Exhausted {max_iterations} iterations")
    return None
|
|
|
|
| |
| |
| |
|
|
def run_workflow(
    client: VLLMClient,
    registry: ToolRegistry,
    steps: List[Dict],
    initial_query: str,
    max_step_retries: int = 3,
) -> Optional[Dict]:
    """
    Drive a sequence of workflow steps, chaining each step's output into
    the next step's context.

    Args:
        client: VLLM client
        registry: Tool registry
        steps: List of step configs, each with:
            - step_num: int
            - system_prompt: str
            - max_iterations: int
            - validation_tools: Optional[List[str]]
        initial_query: User's original request
        max_step_retries: How many times to retry each step

    Returns:
        Final result dict, or None if workflow failed.
    """
    previous_result = initial_query

    for step_config in steps:
        step_num = step_config["step_num"]
        print(f"\n{'='*60}")
        print(f"STEP {step_num}: {step_config.get('name', 'Unnamed')}")
        print(f"{'='*60}")

        step_result = None
        for attempt in range(max_step_retries):
            if attempt:
                print(f"\n Retry {attempt}/{max_step_retries}")

            # Dict results from a prior step are serialized into the prompt.
            if isinstance(previous_result, str):
                context = previous_result
            else:
                context = json.dumps(previous_result)

            step_result = run_step(
                client=client,
                registry=registry,
                step_num=step_num,
                system_prompt=step_config["system_prompt"],
                initial_context=context,
                max_iterations=step_config.get("max_iterations", 10),
                validation_tools=step_config.get("validation_tools"),
            )
            if step_result:
                break

        if not step_result:
            # Every attempt at this step failed; abort the whole workflow.
            print(f"\n Step {step_num} failed after {max_step_retries} retries")
            return None
        previous_result = step_result

    return previous_result
|
|
|
|
| |
| |
| |
|
|
def example_search(query: str) -> dict:
    """Example search tool (replace with real implementation)."""
    match = {
        "name": f"result_for_{query}",
        "type": "example",
        "description": f"Found match for '{query}'",
    }
    return {"results": [match]}
|
|
|
|
def example_get_details(name: str) -> dict:
    """Example detail-fetching tool (replace with real implementation)."""
    details: dict = {"name": name}
    details["required_fields"] = ["field_a", "field_b"]
    details["version"] = 2
    details["examples"] = [{"field_a": "value1", "field_b": "value2"}]
    return details
|
|
|
|
def example_validate(name: str, config: dict) -> dict:
    """Example validation tool (replace with real implementation)."""
    # Each absent required field yields one structured error entry.
    missing = [f for f in ("field_a", "field_b") if f not in config]
    errors = [{"property": f, "message": "Required field missing"} for f in missing]
    return {"valid": len(errors) == 0, "errors": errors}
|
|
|
|
if __name__ == "__main__":
    # CLI demo: wire a VLLM client and the example tools into a two-step
    # discover-then-configure workflow.
    parser = argparse.ArgumentParser(description="Multi-step tool calling orchestrator")
    parser.add_argument("--url", default="http://localhost:8000", help="VLLM server URL")
    parser.add_argument("--model", default="NousResearch/Hermes-3-Llama-3.1-70B-FP8")
    parser.add_argument("--query", default="Find and configure components for a data processing pipeline")
    parser.add_argument("--max-iterations", type=int, default=10)
    parser.add_argument("--max-retries", type=int, default=3)
    args = parser.parse_args()

    client = VLLMClient(args.url, args.model)

    # Register the example tools; 'steps' restricts which workflow step can
    # see and call each tool (step isolation).
    registry = ToolRegistry()
    registry.register(
        name="search",
        description="Search for components by keyword",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
        function=example_search,
        steps=[1],
    )
    registry.register(
        name="get_details",
        description="Get configuration details for a component",
        parameters={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
        function=example_get_details,
        steps=[1, 2],
    )
    registry.register(
        name="validate",
        description="Validate a component configuration",
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "config": {"type": "object"}
            },
            "required": ["name", "config"]
        },
        function=example_validate,
        steps=[2],
    )

    # Step configs. Step 2 lists 'validate' in validation_tools, so run_step
    # refuses a final answer until 'validate' has been called AND passed.
    steps = [
        {
            "step_num": 1,
            "name": "Component Discovery",
            "system_prompt": (
                "You are a component selection expert. Use the available tools to find "
                "the right components for the user's request.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"components\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
        },
        {
            "step_num": 2,
            "name": "Component Configuration",
            "system_prompt": (
                "You are a configuration expert. For each component from the previous step, "
                "get its details, configure all required fields, and validate the configuration.\n\n"
                "You MUST call 'validate' for each component before returning.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"configured\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
            "validation_tools": ["validate"],
        },
    ]

    print(f"\nQuery: {args.query}")
    result = run_workflow(client, registry, steps, args.query, args.max_retries)

    if result:
        print(f"\n{'='*60}")
        print("WORKFLOW COMPLETE")
        print(f"{'='*60}")
        print(json.dumps(result, indent=2))
    else:
        print(f"\n{'='*60}")
        print("WORKFLOW FAILED")
        print(f"{'='*60}")
|
|