File size: 8,927 Bytes
44f8122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
llm_compiler.py — Parallel function calling via DAG planning.

From LLMCompiler (arxiv:2312.04511):
  Instead of sequential ReAct (think → act → observe → think → act → ...),
  plan ALL needed function calls upfront as a DAG, then execute
  independent calls in parallel. Up to 3.7x latency speedup.

Components:
  1. Planner: LLM decomposes task into a dependency graph of tool calls
  2. TaskFetcher: Identifies which tasks are ready (all dependencies met)
  3. Executor: Runs ready tasks in parallel via ThreadPoolExecutor

Adaptation for Purpose Agent:
  The LLMCompiler sits between the Actor and the Environment.
  When the Actor decides an action that requires multiple tool calls,
  the Compiler plans and executes them in parallel, then returns
  the combined result as a single state transition.
"""
from __future__ import annotations

import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable

from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.tools import Tool, ToolResult, ToolRegistry

logger = logging.getLogger(__name__)


@dataclass
class TaskNode:
    """A single task in the execution DAG."""
    # Unique identifier; other tasks reference it in their `dependencies`.
    id: str
    # Name of the registered tool to invoke (looked up in the ToolRegistry).
    tool_name: str
    # Keyword arguments for the tool. String values of the form "$<task_id>"
    # are substituted with that dependency's output at execution time.
    args: dict[str, Any]
    # Ids of tasks that must reach status "done" before this one is ready.
    dependencies: list[str] = field(default_factory=list)
    # Filled in by the executor once the task finishes (success or failure).
    result: ToolResult | None = None
    status: str = "pending"  # pending, running, done, failed


@dataclass
class ExecutionPlan:
    """A DAG of tool calls with dependencies."""
    tasks: list[TaskNode] = field(default_factory=list)
    join_instruction: str = ""  # How to combine results

    @property
    def task_map(self) -> dict[str, TaskNode]:
        """Index the plan's tasks by id for O(1) lookup."""
        mapping: dict[str, TaskNode] = {}
        for node in self.tasks:
            mapping[node.id] = node
        return mapping

    def get_ready(self) -> list[TaskNode]:
        """Return every pending task whose dependencies have all completed."""
        completed = set()
        for node in self.tasks:
            if node.status == "done":
                completed.add(node.id)
        return [
            node for node in self.tasks
            if node.status == "pending" and set(node.dependencies) <= completed
        ]


# System prompt for the planner LLM. `{tools_desc}` is filled with the tool
# registry's description in LLMCompiler.plan(); the doubled braces `{{`/`}}`
# are literal JSON braces escaped for str.format().
PLANNER_PROMPT = """\
You are a TASK PLANNER. Given a complex task and available tools, decompose it
into a DAG (directed acyclic graph) of tool calls that can be executed in parallel.

Available tools:
{tools_desc}

Rules:
- Each task has an id, tool_name, args, and dependencies (list of task ids that must complete first)
- Tasks with no dependencies can run in parallel
- Minimize the number of sequential steps (maximize parallelism)

Respond with JSON:
{{
  "tasks": [
    {{"id": "t1", "tool_name": "...", "args": {{...}}, "dependencies": []}},
    {{"id": "t2", "tool_name": "...", "args": {{...}}, "dependencies": ["t1"]}}
  ],
  "join_instruction": "How to combine the results into a final answer"
}}
"""

# JSON schema handed to LLMBackend.generate_structured() to constrain the
# planner's output to the plan shape parsed in LLMCompiler.plan(). Note that
# "dependencies" and "join_instruction" are optional; plan() defaults them.
PLAN_SCHEMA = {
    "type": "object",
    "properties": {
        "tasks": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "tool_name": {"type": "string"},
                    "args": {"type": "object"},
                    "dependencies": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["id", "tool_name", "args"],
            },
        },
        "join_instruction": {"type": "string"},
    },
    "required": ["tasks"],
}


class LLMCompiler:
    """
    Parallel function calling via DAG planning (LLMCompiler, arxiv:2312.04511).

    Usage:
        compiler = LLMCompiler(
            planner_llm=model,
            tool_registry=registry,
            max_workers=4,
        )

        # Plan and execute a complex task
        result = compiler.compile_and_execute(
            task="Search for X AND calculate Y AND read file Z, then combine results"
        )
        # → Plans 3 parallel tool calls, executes them concurrently, joins results
    """

    def __init__(
        self,
        planner_llm: LLMBackend,
        tool_registry: ToolRegistry,
        max_workers: int = 4,
    ):
        """
        Args:
            planner_llm: Backend used to generate the execution plan.
            tool_registry: Registry the compiler executes tool calls against.
            max_workers: Upper bound on concurrently running tool calls.
        """
        self.planner = planner_llm
        self.tools = tool_registry
        self.max_workers = max_workers

    def plan(self, task: str) -> ExecutionPlan:
        """Use the LLM to decompose a task into a parallel execution plan.

        Falls back to a degenerate single-task plan when structured generation
        fails, so callers always receive a usable ExecutionPlan.
        """
        tools_desc = self.tools.format_for_prompt(compact=True)

        messages = [
            ChatMessage(role="system", content=PLANNER_PROMPT.format(tools_desc=tools_desc)),
            ChatMessage(role="user", content=f"Task: {task}"),
        ]

        try:
            result = self.planner.generate_structured(messages, schema=PLAN_SCHEMA)
        except Exception as e:
            logger.warning("LLMCompiler: Planning failed (%s), creating single-task plan", e)
            # Degenerate fallback: no tool bound, so execution surfaces the
            # failure through the normal ToolResult path instead of raising.
            return ExecutionPlan(
                tasks=[TaskNode(id="t1", tool_name="", args={"task": task})],
                join_instruction="Return the result directly",
            )

        plan = ExecutionPlan(join_instruction=result.get("join_instruction", ""))
        for t in result.get("tasks", []):
            plan.tasks.append(TaskNode(
                id=t.get("id", f"t{len(plan.tasks)+1}"),  # synthesize an id if the LLM omitted one
                tool_name=t.get("tool_name", ""),
                args=t.get("args", {}),
                dependencies=t.get("dependencies", []),
            ))

        logger.info(
            "LLMCompiler: Planned %d tasks, max parallel=%d",
            len(plan.tasks), len(plan.get_ready()),
        )
        return plan

    def execute(self, plan: ExecutionPlan) -> dict[str, ToolResult]:
        """Execute a plan, running independent tasks in parallel.

        Tasks run in "waves": each wave is the set of tasks whose dependencies
        are all done, submitted concurrently to a shared thread pool. Tasks
        whose dependencies can never complete (a dependency failed or
        references an unknown id) are marked failed rather than looping
        forever.

        Returns:
            Mapping of task id -> ToolResult for every task the plan contains,
            including stuck/failed ones.
        """
        results: dict[str, ToolResult] = {}
        start = time.time()

        # One pool for the whole plan: avoids tearing down and re-creating
        # worker threads for every dependency wave.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while True:
                ready = plan.get_ready()
                if not ready:
                    # Either everything finished, or the remainder is stuck.
                    pending = [t for t in plan.tasks if t.status == "pending"]
                    if not pending:
                        break
                    logger.warning(
                        "LLMCompiler: %d tasks stuck with unresolved dependencies",
                        len(pending),
                    )
                    for t in pending:
                        t.status = "failed"
                        t.result = ToolResult(output="", success=False, error="Unresolved dependencies")
                        # Record the failure so compile_and_execute() reports it,
                        # consistent with the exception path below.
                        results[t.id] = t.result
                    break

                # Fan out the ready wave; collect results as they complete.
                future_map = {}
                for task in ready:
                    task.status = "running"
                    future = executor.submit(self._execute_task, task, results)
                    future_map[future] = task

                for future in as_completed(future_map):
                    task = future_map[future]
                    try:
                        result = future.result()
                        task.result = result
                        task.status = "done"
                        results[task.id] = result
                    except Exception as e:
                        task.status = "failed"
                        task.result = ToolResult(output="", success=False, error=str(e))
                        results[task.id] = task.result

        elapsed = time.time() - start
        success_count = sum(1 for r in results.values() if r.success)
        logger.info(
            "LLMCompiler: Executed %d tasks in %.2fs (%d succeeded)",
            len(results), elapsed, success_count,
        )
        return results

    def compile_and_execute(self, task: str) -> str:
        """Plan + execute + join results into a single string.

        Each task contributes one "[tool] output" (or "[tool] ERROR: ...")
        section; the plan's join_instruction, if any, is appended at the end.
        """
        plan = self.plan(task)
        results = self.execute(plan)

        parts = []
        for t in plan.tasks:
            r = results.get(t.id)
            if r and r.success:
                parts.append(f"[{t.tool_name}] {r.output}")
            elif r:
                parts.append(f"[{t.tool_name}] ERROR: {r.error}")

        if plan.join_instruction:
            parts.append(f"\nJoin: {plan.join_instruction}")

        return "\n\n".join(parts)

    def _execute_task(
        self, task: TaskNode, prior_results: dict[str, ToolResult],
    ) -> ToolResult:
        """Execute a single task, substituting dependency results into args.

        String arg values of the form "$<task_id>" are replaced with the
        referenced task's output when that task succeeded; otherwise the
        literal value is passed through unchanged.
        """
        resolved_args = {}
        for key, value in task.args.items():
            if isinstance(value, str) and value.startswith("$"):
                dep_id = value[1:]
                if dep_id in prior_results and prior_results[dep_id].success:
                    resolved_args[key] = prior_results[dep_id].output
                else:
                    resolved_args[key] = value
            else:
                resolved_args[key] = value

        return self.tools.execute(task.tool_name, **resolved_args)