# purpose_agent/llm_compiler.py
# (provenance: V2 merge, commit 44f8122, verified; uploaded by Rohan03)
"""
llm_compiler.py — Parallel function calling via DAG planning.
From LLMCompiler (arxiv:2312.04511):
Instead of sequential ReAct (think → act → observe → think → act → ...),
plan ALL needed function calls upfront as a DAG, then execute
independent calls in parallel. Up to 3.7x latency speedup.
Components:
1. Planner: LLM decomposes task into a dependency graph of tool calls
2. TaskFetcher: Identifies which tasks are ready (all dependencies met)
3. Executor: Runs ready tasks in parallel via ThreadPoolExecutor
Adaptation for Purpose Agent:
The LLMCompiler sits between the Actor and the Environment.
When the Actor decides an action that requires multiple tool calls,
the Compiler plans and executes them in parallel, then returns
the combined result as a single state transition.
"""
from __future__ import annotations
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable
from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.tools import Tool, ToolResult, ToolRegistry
logger = logging.getLogger(__name__)
@dataclass
class TaskNode:
    """A single task in the execution DAG."""
    id: str  # unique task id within the plan (e.g. "t1")
    tool_name: str  # resolved against the ToolRegistry when the task executes
    args: dict[str, Any]  # tool kwargs; a string value "$<task_id>" is replaced with that dependency's output
    dependencies: list[str] = field(default_factory=list)  # ids of tasks that must be "done" before this one runs
    result: ToolResult | None = None  # filled in by the executor once the task finishes (success or failure)
    status: str = "pending"  # pending, running, done, failed
@dataclass
class ExecutionPlan:
    """A DAG of tool calls with dependencies."""
    tasks: list[TaskNode] = field(default_factory=list)
    join_instruction: str = ""  # How to combine results

    @property
    def task_map(self) -> dict[str, TaskNode]:
        """Index the plan's tasks by their id."""
        mapping: dict[str, TaskNode] = {}
        for node in self.tasks:
            mapping[node.id] = node
        return mapping

    def get_ready(self) -> list[TaskNode]:
        """Get tasks whose dependencies are all satisfied."""
        completed: set[str] = set()
        for node in self.tasks:
            if node.status == "done":
                completed.add(node.id)
        ready: list[TaskNode] = []
        for node in self.tasks:
            if node.status != "pending":
                continue
            if all(dep in completed for dep in node.dependencies):
                ready.append(node)
        return ready
# System prompt for the planner LLM. {tools_desc} is substituted in plan()
# with the registry's compact tool descriptions; the doubled braces emit
# literal JSON braces through str.format.
PLANNER_PROMPT = """\
You are a TASK PLANNER. Given a complex task and available tools, decompose it
into a DAG (directed acyclic graph) of tool calls that can be executed in parallel.
Available tools:
{tools_desc}
Rules:
- Each task has an id, tool_name, args, and dependencies (list of task ids that must complete first)
- Tasks with no dependencies can run in parallel
- Minimize the number of sequential steps (maximize parallelism)
Respond with JSON:
{{
"tasks": [
{{"id": "t1", "tool_name": "...", "args": {{...}}, "dependencies": []}},
{{"id": "t2", "tool_name": "...", "args": {{...}}, "dependencies": ["t1"]}}
],
"join_instruction": "How to combine the results into a final answer"
}}
"""
# JSON schema handed to generate_structured() to constrain the planner's
# reply to the shape plan() parses: a required "tasks" array (each entry
# needs id/tool_name/args; dependencies optional) plus an optional
# "join_instruction" string.
PLAN_SCHEMA = {
    "type": "object",
    "properties": {
        "tasks": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "tool_name": {"type": "string"},
                    "args": {"type": "object"},
                    "dependencies": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["id", "tool_name", "args"],
            },
        },
        "join_instruction": {"type": "string"},
    },
    "required": ["tasks"],
}
class LLMCompiler:
    """
    Parallel function calling via DAG planning.
    Usage:
        compiler = LLMCompiler(
            planner_llm=model,
            tool_registry=registry,
            max_workers=4,
        )
        # Plan and execute a complex task
        result = compiler.compile_and_execute(
            task="Search for X AND calculate Y AND read file Z, then combine results"
        )
        # → Plans 3 parallel tool calls, executes them concurrently, joins results
    """

    def __init__(
        self,
        planner_llm: LLMBackend,
        tool_registry: ToolRegistry,
        max_workers: int = 4,
    ):
        """
        Args:
            planner_llm: Backend used to generate the structured execution plan.
            tool_registry: Registry that resolves and runs tool calls.
            max_workers: Upper bound on concurrently executing tool calls.
        """
        self.planner = planner_llm
        self.tools = tool_registry
        self.max_workers = max_workers

    def plan(self, task: str) -> ExecutionPlan:
        """Use the LLM to decompose a task into a parallel execution plan.

        Never raises: if structured generation fails, degrades to a
        single-node plan carrying the raw task text.
        """
        tools_desc = self.tools.format_for_prompt(compact=True)
        messages = [
            ChatMessage(role="system", content=PLANNER_PROMPT.format(tools_desc=tools_desc)),
            ChatMessage(role="user", content=f"Task: {task}"),
        ]
        try:
            result = self.planner.generate_structured(messages, schema=PLAN_SCHEMA)
        except Exception as e:
            # Best-effort degradation rather than propagating planner failures.
            logger.warning("LLMCompiler: Planning failed (%s), creating single-task plan", e)
            return ExecutionPlan(
                tasks=[TaskNode(id="t1", tool_name="", args={"task": task})],
                join_instruction="Return the result directly",
            )
        plan = ExecutionPlan(join_instruction=result.get("join_instruction", ""))
        for t in result.get("tasks", []):
            plan.tasks.append(TaskNode(
                # Synthesize a positional id if the model omitted one.
                id=t.get("id", f"t{len(plan.tasks)+1}"),
                tool_name=t.get("tool_name", ""),
                args=t.get("args", {}),
                dependencies=t.get("dependencies", []),
            ))
        logger.info(
            "LLMCompiler: Planned %d tasks, max parallel=%d",
            len(plan.tasks), len(plan.get_ready()),
        )
        return plan

    def execute(self, plan: ExecutionPlan) -> dict[str, ToolResult]:
        """Execute a plan, running independent tasks in parallel.

        Runs the DAG in "waves": each wave is the set of tasks whose
        dependencies are already done. Returns a mapping of task id →
        ToolResult for every task that reached a terminal state, including
        failed/stuck tasks (success=False).
        """
        results: dict[str, ToolResult] = {}
        start = time.time()
        # One pool reused across all waves; the previous implementation
        # created and tore down a ThreadPoolExecutor on every wave.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while True:
                ready = plan.get_ready()
                if not ready:
                    # Check if we're done or stuck
                    pending = [t for t in plan.tasks if t.status == "pending"]
                    if not pending:
                        break  # every task reached a terminal state
                    # Stuck: remaining tasks depend on failed/unknown ids that
                    # will never become "done" — fail them explicitly.
                    logger.warning(
                        "LLMCompiler: %d tasks stuck with unresolved dependencies",
                        len(pending),
                    )
                    for t in pending:
                        t.status = "failed"
                        t.result = ToolResult(output="", success=False, error="Unresolved dependencies")
                        # Record the failure so compile_and_execute() can
                        # surface it (previously stuck tasks vanished silently).
                        results[t.id] = t.result
                    break
                # Execute the current wave concurrently.
                future_map = {}
                for task in ready:
                    task.status = "running"
                    future = executor.submit(self._execute_task, task, results)
                    future_map[future] = task
                for future in as_completed(future_map):
                    task = future_map[future]
                    try:
                        result = future.result()
                        task.result = result
                        task.status = "done"
                        results[task.id] = result
                    except Exception as e:
                        # A crashing tool fails its task, not the whole plan.
                        task.status = "failed"
                        task.result = ToolResult(output="", success=False, error=str(e))
                        results[task.id] = task.result
        elapsed = time.time() - start
        success_count = sum(1 for r in results.values() if r.success)
        logger.info(
            "LLMCompiler: Executed %d tasks in %.2fs (%d succeeded)",
            len(results), elapsed, success_count,
        )
        return results

    def compile_and_execute(self, task: str) -> str:
        """Plan + execute + join results into a single string."""
        plan = self.plan(task)
        results = self.execute(plan)
        # Join results in plan order, labeling each with its tool name.
        parts = []
        for t in plan.tasks:
            r = results.get(t.id)
            if r and r.success:
                parts.append(f"[{t.tool_name}] {r.output}")
            elif r:
                parts.append(f"[{t.tool_name}] ERROR: {r.error}")
        if plan.join_instruction:
            parts.append(f"\nJoin: {plan.join_instruction}")
        return "\n\n".join(parts)

    def _execute_task(
        self, task: TaskNode, prior_results: dict[str, ToolResult],
    ) -> ToolResult:
        """Execute a single task, substituting dependency results into args.

        A string arg of the form "$<task_id>" is replaced with that task's
        output when the referenced task succeeded; otherwise the literal
        string is passed through unchanged.
        """
        # Resolve dependency references in args
        resolved_args = {}
        for key, value in task.args.items():
            if isinstance(value, str) and value.startswith("$"):
                dep_id = value[1:]
                dep = prior_results.get(dep_id)
                if dep is not None and dep.success:
                    resolved_args[key] = dep.output
                else:
                    resolved_args[key] = value
            else:
                resolved_args[key] = value
        return self.tools.execute(task.tool_name, **resolved_args)