Rohan03 committed on
Commit
44f8122
·
verified ·
1 Parent(s): 6dd7984

V2 merge: purpose_agent/llm_compiler.py

Browse files
Files changed (1) hide show
  1. purpose_agent/llm_compiler.py +252 -0
purpose_agent/llm_compiler.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ llm_compiler.py — Parallel function calling via DAG planning.
3
+
4
+ From LLMCompiler (arxiv:2312.04511):
5
+ Instead of sequential ReAct (think → act → observe → think → act → ...),
6
+ plan ALL needed function calls upfront as a DAG, then execute
7
+ independent calls in parallel. Up to 3.7x latency speedup.
8
+
9
+ Components:
10
+ 1. Planner: LLM decomposes task into a dependency graph of tool calls
11
+ 2. TaskFetcher: Identifies which tasks are ready (all dependencies met)
12
+ 3. Executor: Runs ready tasks in parallel via ThreadPoolExecutor
13
+
14
+ Adaptation for Purpose Agent:
15
+ The LLMCompiler sits between the Actor and the Environment.
16
+ When the Actor decides an action that requires multiple tool calls,
17
+ the Compiler plans and executes them in parallel, then returns
18
+ the combined result as a single state transition.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import logging
24
+ import time
25
+ from concurrent.futures import ThreadPoolExecutor, as_completed
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Callable
28
+
29
+ from purpose_agent.llm_backend import LLMBackend, ChatMessage
30
+ from purpose_agent.tools import Tool, ToolResult, ToolRegistry
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
@dataclass
class TaskNode:
    """A single task in the execution DAG."""
    id: str  # unique identifier within the plan (e.g. "t1")
    tool_name: str  # name of the tool to invoke via the tool registry
    args: dict[str, Any]  # keyword arguments for the tool call; "$<id>" values reference prior results
    dependencies: list[str] = field(default_factory=list)  # ids of tasks that must complete first
    result: ToolResult | None = None  # filled in by the executor once the task finishes
    status: str = "pending"  # pending, running, done, failed
44
+
45
+
46
@dataclass
class ExecutionPlan:
    """A DAG of tool calls with dependencies."""
    tasks: list[TaskNode] = field(default_factory=list)
    join_instruction: str = ""  # How to combine results

    @property
    def task_map(self) -> dict[str, TaskNode]:
        """Index the plan's tasks by their id."""
        mapping: dict[str, TaskNode] = {}
        for node in self.tasks:
            mapping[node.id] = node
        return mapping

    def get_ready(self) -> list[TaskNode]:
        """Get tasks whose dependencies are all satisfied."""
        # A task is ready when it is still pending and every dependency
        # has already reached the "done" state.
        completed = set()
        for node in self.tasks:
            if node.status == "done":
                completed.add(node.id)

        ready = []
        for node in self.tasks:
            if node.status != "pending":
                continue
            if all(dep in completed for dep in node.dependencies):
                ready.append(node)
        return ready
63
+
64
+
65
# System prompt for the planner LLM. Formatted with {tools_desc}; the doubled
# braces produce literal JSON braces in the example output.
PLANNER_PROMPT = """\
You are a TASK PLANNER. Given a complex task and available tools, decompose it
into a DAG (directed acyclic graph) of tool calls that can be executed in parallel.

Available tools:
{tools_desc}

Rules:
- Each task has an id, tool_name, args, and dependencies (list of task ids that must complete first)
- Tasks with no dependencies can run in parallel
- Minimize the number of sequential steps (maximize parallelism)

Respond with JSON:
{{
  "tasks": [
    {{"id": "t1", "tool_name": "...", "args": {{...}}, "dependencies": []}},
    {{"id": "t2", "tool_name": "...", "args": {{...}}, "dependencies": ["t1"]}}
  ],
  "join_instruction": "How to combine the results into a final answer"
}}
"""

# JSON Schema constraining the planner's structured output; mirrors the shape
# shown in PLANNER_PROMPT. "dependencies" and "join_instruction" are optional
# here — plan() fills in defaults for missing fields via dict.get().
PLAN_SCHEMA = {
    "type": "object",
    "properties": {
        "tasks": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "tool_name": {"type": "string"},
                    "args": {"type": "object"},
                    "dependencies": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["id", "tool_name", "args"],
            },
        },
        "join_instruction": {"type": "string"},
    },
    "required": ["tasks"],
}
107
+
108
+
109
class LLMCompiler:
    """
    Parallel function calling via DAG planning.

    Usage:
        compiler = LLMCompiler(
            planner_llm=model,
            tool_registry=registry,
            max_workers=4,
        )

        # Plan and execute a complex task
        result = compiler.compile_and_execute(
            task="Search for X AND calculate Y AND read file Z, then combine results"
        )
        # → Plans 3 parallel tool calls, executes them concurrently, joins results
    """

    def __init__(
        self,
        planner_llm: LLMBackend,
        tool_registry: ToolRegistry,
        max_workers: int = 4,
    ):
        """
        Args:
            planner_llm: LLM backend used to generate execution plans.
            tool_registry: Registry used to resolve and execute tool calls.
            max_workers: Upper bound on concurrently running tasks.
        """
        self.planner = planner_llm
        self.tools = tool_registry
        self.max_workers = max_workers

    def plan(self, task: str) -> ExecutionPlan:
        """Use the LLM to decompose a task into a parallel execution plan.

        Falls back to a trivial single-task plan if structured generation
        raises, so the caller always receives a usable ExecutionPlan.
        """
        tools_desc = self.tools.format_for_prompt(compact=True)

        messages = [
            ChatMessage(role="system", content=PLANNER_PROMPT.format(tools_desc=tools_desc)),
            ChatMessage(role="user", content=f"Task: {task}"),
        ]

        try:
            result = self.planner.generate_structured(messages, schema=PLAN_SCHEMA)
        except Exception as e:
            # NOTE(review): the fallback node has tool_name="" — executing it
            # relies on the registry tolerating an unknown tool; confirm
            # against ToolRegistry.execute.
            logger.warning("LLMCompiler: Planning failed (%s), creating single-task plan", e)
            return ExecutionPlan(
                tasks=[TaskNode(id="t1", tool_name="", args={"task": task})],
                join_instruction="Return the result directly",
            )

        plan = ExecutionPlan(join_instruction=result.get("join_instruction", ""))
        for t in result.get("tasks", []):
            plan.tasks.append(TaskNode(
                # Synthesize an id when the model omitted one.
                id=t.get("id", f"t{len(plan.tasks)+1}"),
                tool_name=t.get("tool_name", ""),
                args=t.get("args", {}),
                dependencies=t.get("dependencies", []),
            ))

        logger.info(
            "LLMCompiler: Planned %d tasks, max parallel=%d",
            len(plan.tasks), len(plan.get_ready()),
        )
        return plan

    def execute(self, plan: ExecutionPlan) -> dict[str, ToolResult]:
        """Execute a plan, running independent tasks in parallel.

        Proceeds in waves: each wave is the set of tasks whose dependencies
        are all done; the wave runs concurrently on a shared thread pool.
        Tasks that can never become ready (a dependency failed or does not
        exist) are marked failed.

        Returns:
            Mapping of task id -> ToolResult for every task that reached a
            terminal state (done or failed).
        """
        results: dict[str, ToolResult] = {}
        start = time.time()

        # One pool for the whole run. The previous version created and tore
        # down a fresh ThreadPoolExecutor on every wave, paying thread
        # start-up/shutdown costs per iteration for no benefit.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while True:
                ready = plan.get_ready()
                if not ready:
                    # Check if we're done or stuck
                    pending = [t for t in plan.tasks if t.status == "pending"]
                    if not pending:
                        break
                    # Stuck: dependencies can't be resolved
                    logger.warning(
                        "LLMCompiler: %d tasks stuck with unresolved dependencies",
                        len(pending),
                    )
                    for t in pending:
                        t.status = "failed"
                        t.result = ToolResult(output="", success=False, error="Unresolved dependencies")
                        # Fix: record stuck tasks in the result map too —
                        # previously they were silently absent, so callers
                        # (e.g. compile_and_execute) never reported them.
                        results[t.id] = t.result
                    break

                # Dispatch this wave in parallel.
                future_map = {}
                for task in ready:
                    task.status = "running"
                    future_map[executor.submit(self._execute_task, task, results)] = task

                for future in as_completed(future_map):
                    task = future_map[future]
                    try:
                        result = future.result()
                        task.result = result
                        task.status = "done"
                        results[task.id] = result
                    except Exception as e:
                        # A raising tool fails only this task; its dependents
                        # are handled by the "stuck" branch above.
                        task.status = "failed"
                        task.result = ToolResult(output="", success=False, error=str(e))
                        results[task.id] = task.result

        elapsed = time.time() - start
        success_count = sum(1 for r in results.values() if r.success)
        logger.info(
            "LLMCompiler: Executed %d tasks in %.2fs (%d succeeded)",
            len(results), elapsed, success_count,
        )
        return results

    def compile_and_execute(self, task: str) -> str:
        """Plan + execute + join results into a single string."""
        plan = self.plan(task)
        results = self.execute(plan)

        # Join results in plan order, annotating failures.
        parts = []
        for t in plan.tasks:
            r = results.get(t.id)
            if r and r.success:
                parts.append(f"[{t.tool_name}] {r.output}")
            elif r:
                parts.append(f"[{t.tool_name}] ERROR: {r.error}")

        if plan.join_instruction:
            parts.append(f"\nJoin: {plan.join_instruction}")

        return "\n\n".join(parts)

    def _execute_task(
        self, task: TaskNode, prior_results: dict[str, ToolResult],
    ) -> ToolResult:
        """Execute a single task, substituting dependency results into args.

        A string arg of the form "$<task_id>" is replaced with the output of
        that prior task when it succeeded. If the referenced task failed or
        is unknown, the literal "$<task_id>" placeholder is passed through
        unchanged — NOTE(review): confirm downstream tools tolerate the raw
        placeholder in that case.
        """
        resolved_args = {}
        for key, value in task.args.items():
            if isinstance(value, str) and value.startswith("$"):
                dep_id = value[1:]
                if dep_id in prior_results and prior_results[dep_id].success:
                    resolved_args[key] = prior_results[dep_id].output
                else:
                    resolved_args[key] = value
            else:
                resolved_args[key] = value

        return self.tools.execute(task.tool_name, **resolved_args)