"""
llm_compiler.py — Parallel function calling via DAG planning.
From LLMCompiler (arXiv:2312.04511):
Instead of sequential ReAct (think → act → observe → think → act → ...),
plan ALL needed function calls upfront as a DAG, then execute
independent calls in parallel. The paper reports up to a 3.7x latency speedup.
Components:
1. Planner: LLM decomposes task into a dependency graph of tool calls
2. TaskFetcher: Identifies which tasks are ready (all dependencies met)
3. Executor: Runs ready tasks in parallel via ThreadPoolExecutor
Adaptation for Purpose Agent:
The LLMCompiler sits between the Actor and the Environment.
When the Actor decides on an action that requires multiple tool calls,
the Compiler plans and executes them in parallel, then returns
the combined result as a single state transition.
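
Illustration (hypothetical tools): for "get the weather in SF and in NYC,
then compare", ReAct issues weather(SF), observe, weather(NYC), observe,
compare(...) strictly in sequence; the compiler instead plans weather(SF)
and weather(NYC) as one parallel wave and compare("$t1", "$t2") as a
dependent second wave.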
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any
from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.tools import Tool, ToolResult, ToolRegistry
logger = logging.getLogger(__name__)
@dataclass
class TaskNode:
"""A single task in the execution DAG."""
id: str
tool_name: str
args: dict[str, Any]
dependencies: list[str] = field(default_factory=list)
result: ToolResult | None = None
status: str = "pending" # pending, running, done, failed
@dataclass
class ExecutionPlan:
"""A DAG of tool calls with dependencies."""
tasks: list[TaskNode] = field(default_factory=list)
join_instruction: str = "" # How to combine results
@property
def task_map(self) -> dict[str, TaskNode]:
return {t.id: t for t in self.tasks}
def get_ready(self) -> list[TaskNode]:
"""Get tasks whose dependencies are all satisfied."""
done_ids = {t.id for t in self.tasks if t.status == "done"}
return [
t for t in self.tasks
if t.status == "pending" and all(d in done_ids for d in t.dependencies)
]
PLANNER_PROMPT = """\
You are a TASK PLANNER. Given a complex task and available tools, decompose it
into a DAG (directed acyclic graph) of tool calls that can be executed in parallel.
Available tools:
{tools_desc}
Rules:
- Each task has an id, tool_name, args, and dependencies (list of task ids that must complete first)
- To pass one task's output as an argument to another, use "$<task_id>" as the arg value (e.g. "$t1")
- Tasks with no dependencies can run in parallel
- Minimize the number of sequential steps (maximize parallelism)
Respond with JSON:
{{
"tasks": [
{{"id": "t1", "tool_name": "...", "args": {{...}}, "dependencies": []}},
{{"id": "t2", "tool_name": "...", "args": {{...}}, "dependencies": ["t1"]}}
],
"join_instruction": "How to combine the results into a final answer"
}}
"""
PLAN_SCHEMA = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"tool_name": {"type": "string"},
"args": {"type": "object"},
"dependencies": {"type": "array", "items": {"type": "string"}},
},
"required": ["id", "tool_name", "args"],
},
},
"join_instruction": {"type": "string"},
},
"required": ["tasks"],
}
class LLMCompiler:
"""
Parallel function calling via DAG planning.
Usage:
compiler = LLMCompiler(
planner_llm=model,
tool_registry=registry,
max_workers=4,
)
# Plan and execute a complex task
result = compiler.compile_and_execute(
task="Search for X AND calculate Y AND read file Z, then combine results"
)
# → Plans 3 parallel tool calls, executes them concurrently, joins results
"""
def __init__(
self,
planner_llm: LLMBackend,
tool_registry: ToolRegistry,
max_workers: int = 4,
):
self.planner = planner_llm
self.tools = tool_registry
self.max_workers = max_workers
def plan(self, task: str) -> ExecutionPlan:
"""Use the LLM to decompose a task into a parallel execution plan."""
tools_desc = self.tools.format_for_prompt(compact=True)
messages = [
ChatMessage(role="system", content=PLANNER_PROMPT.format(tools_desc=tools_desc)),
ChatMessage(role="user", content=f"Task: {task}"),
]
try:
result = self.planner.generate_structured(messages, schema=PLAN_SCHEMA)
        except Exception as e:
            logger.warning(f"LLMCompiler: Planning failed ({e}), creating single-task plan")
            # Degraded mode: the empty tool_name matches no registered tool, so
            # this placeholder surfaces as a failed task at execution time
            # rather than crashing the planner.
            return ExecutionPlan(
                tasks=[TaskNode(id="t1", tool_name="", args={"task": task})],
                join_instruction="Return the result directly",
            )
plan = ExecutionPlan(join_instruction=result.get("join_instruction", ""))
for t in result.get("tasks", []):
plan.tasks.append(TaskNode(
id=t.get("id", f"t{len(plan.tasks)+1}"),
tool_name=t.get("tool_name", ""),
args=t.get("args", {}),
dependencies=t.get("dependencies", []),
))
        logger.info(
            f"LLMCompiler: Planned {len(plan.tasks)} tasks, "
            f"first parallel wave={len(plan.get_ready())}"
        )
return plan
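
    # Note: planned tool names are not validated against the registry here; an
    # unknown tool_name simply fails inside execute() when the task runs.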
def execute(self, plan: ExecutionPlan) -> dict[str, ToolResult]:
"""Execute a plan, running independent tasks in parallel."""
results: dict[str, ToolResult] = {}
start = time.time()
while True:
ready = plan.get_ready()
if not ready:
# Check if we're done or stuck
pending = [t for t in plan.tasks if t.status == "pending"]
if not pending:
break
# Stuck: dependencies can't be resolved
logger.warning(f"LLMCompiler: {len(pending)} tasks stuck with unresolved dependencies")
for t in pending:
t.status = "failed"
t.result = ToolResult(output="", success=False, error="Unresolved dependencies")
break
            # Execute this wave of ready tasks in parallel; a fresh pool per
            # wave keeps the scheduling logic simple.
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_map = {}
for task in ready:
task.status = "running"
future = executor.submit(self._execute_task, task, results)
future_map[future] = task
for future in as_completed(future_map):
task = future_map[future]
try:
result = future.result()
task.result = result
task.status = "done"
results[task.id] = result
except Exception as e:
task.status = "failed"
task.result = ToolResult(output="", success=False, error=str(e))
results[task.id] = task.result
elapsed = time.time() - start
success_count = sum(1 for r in results.values() if r.success)
logger.info(
f"LLMCompiler: Executed {len(results)} tasks in {elapsed:.2f}s "
f"({success_count} succeeded)"
)
return results
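
    # Note: the DAG runs in synchronous waves: each wave joins fully before the
    # next get_ready() call. The paper's Task Fetching Unit instead dispatches
    # a task the moment its own dependencies resolve, which can be strictly
    # faster on unbalanced DAGs.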
def compile_and_execute(self, task: str) -> str:
"""Plan + execute + join results into a single string."""
plan = self.plan(task)
results = self.execute(plan)
        # Join results by simple concatenation; the paper's Joiner is an LLM
        # call (see the _join_with_llm sketch below).
parts = []
for t in plan.tasks:
r = results.get(t.id)
if r and r.success:
parts.append(f"[{t.tool_name}] {r.output}")
elif r:
parts.append(f"[{t.tool_name}] ERROR: {r.error}")
if plan.join_instruction:
parts.append(f"\nJoin: {plan.join_instruction}")
return "\n\n".join(parts)
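
    def _join_with_llm(self, task: str, joined: str) -> str:
        """Optional LLM join step; a minimal sketch, not wired in by default.

        The LLMCompiler paper feeds tool outputs to an LLM "joiner" that writes
        the final answer, while compile_and_execute above only concatenates.
        This sketch reuses generate_structured() since that is the only backend
        method this module relies on; swap in a free-text call if the backend
        has one.
        """
        messages = [
            ChatMessage(role="system", content="Combine the tool results into a final answer for the task."),
            ChatMessage(role="user", content=f"Task: {task}\n\nResults:\n{joined}"),
        ]
        try:
            result = self.planner.generate_structured(
                messages,
                schema={
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                },
            )
            return result.get("answer", joined)
        except Exception as e:
            logger.warning(f"LLMCompiler: Join failed ({e}), returning raw results")
            return joined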
def _execute_task(
self, task: TaskNode, prior_results: dict[str, ToolResult],
) -> ToolResult:
"""Execute a single task, substituting dependency results into args."""
# Resolve dependency references in args
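        # Example: args={"text": "$t1"} becomes args={"text": <t1's output>}
        # once t1 has succeeded; on failure the literal "$t1" string passes
        # through unchanged. Only top-level string values are substituted.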
resolved_args = {}
for key, value in task.args.items():
if isinstance(value, str) and value.startswith("$"):
dep_id = value[1:]
if dep_id in prior_results and prior_results[dep_id].success:
resolved_args[key] = prior_results[dep_id].output
else:
resolved_args[key] = value
else:
resolved_args[key] = value
return self.tools.execute(task.tool_name, **resolved_args)
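

if __name__ == "__main__":
    # Smoke test for the DAG scheduling logic only: no LLM, no real tools.
    # Hand-build a fan-in plan (hypothetical tool names) and print the parallel
    # waves that execute() would dispatch.
    demo = ExecutionPlan(tasks=[
        TaskNode(id="t1", tool_name="search", args={"query": "X"}),
        TaskNode(id="t2", tool_name="search", args={"query": "Y"}),
        TaskNode(id="t3", tool_name="combine",
                 args={"a": "$t1", "b": "$t2"}, dependencies=["t1", "t2"]),
    ])
    wave = 0
    while (ready := demo.get_ready()):
        wave += 1
        print(f"wave {wave}: {[t.id for t in ready]}")
        for t in ready:
            t.status = "done"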