"""
Streaming & Async Engine — Real-time token streaming and concurrent execution.

Adds streaming support to all modules:
  - Actor streams its thought process as it reasons
  - Purpose Function streams its evaluation
  - Orchestrator streams step-by-step progress

Async support via asyncio:
  - All core operations have async variants
  - Concurrent tool execution
  - Background experience replay updates

Pattern: sync methods remain the default. Async wrappers use asyncio.to_thread
for backends that don't support native async (per smolagents pattern).
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Any, AsyncIterator, Iterator

from purpose_agent.types import State, Trajectory, TrajectoryStep
from purpose_agent.llm_backend import ChatMessage, LLMBackend

logger = logging.getLogger(__name__)
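
# Example (sketch): the "concurrent tool execution" mentioned in the module
# docstring is just asyncio.gather over asyncio.to_thread. `tool_calls` is a
# hypothetical list of (callable, kwargs) pairs from your own tool layer, not
# something this module defines:
#
#     async def run_tools_concurrently(tool_calls):
#         return await asyncio.gather(
#             *(asyncio.to_thread(fn, **kwargs) for fn, kwargs in tool_calls)
#         )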


# ---------------------------------------------------------------------------
# Streaming Mixin — adds generate_stream to any LLMBackend
# ---------------------------------------------------------------------------

class StreamingMixin:
    """
    Mixin that adds streaming to any LLMBackend that doesn't natively support it.
    
    Falls back to returning the full response as a single chunk.
    Override generate_stream() for native streaming.
    """

    def generate_stream(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> Iterator[str]:
        """
        Stream tokens. Default: generate full response, yield as one chunk.
        Override in subclasses for real token-level streaming.
        """
        full = self.generate(messages, temperature=temperature, max_tokens=max_tokens)
        yield full

    async def agenerate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Async wrapper around sync generate."""
        return await asyncio.to_thread(
            self.generate, messages, temperature, max_tokens, stop
        )

    async def agenerate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Async wrapper around sync generate_structured."""
        return await asyncio.to_thread(
            self.generate_structured, messages, schema, temperature, max_tokens
        )

    async def agenerate_stream(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncIterator[str]:
        """Async streaming. Default: wrap sync stream in async iterator."""
        loop = asyncio.get_event_loop()
        # Run sync generator in thread, yield results
        gen = self.generate_stream(messages, temperature, max_tokens)
        while True:
            try:
                token = await asyncio.to_thread(next, gen)
                yield token
            except StopIteration:
                break
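
# Example (sketch): overriding generate_stream for real token-level streaming.
# `MyStreamingBackend` and `client.stream_chat` are hypothetical placeholders,
# not part of purpose_agent; adapt to whatever streaming API your backend
# actually exposes:
#
#     class MyStreamingBackend(StreamingMixin, LLMBackend):
#         def generate_stream(self, messages, temperature=0.7, max_tokens=2048):
#             for chunk in self.client.stream_chat(
#                 messages, temperature=temperature, max_tokens=max_tokens
#             ):
#                 yield chunk.text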


# ---------------------------------------------------------------------------
# Event types for streaming orchestration
# ---------------------------------------------------------------------------

class StreamEvent:
    """An event emitted during streaming orchestration."""

    def __init__(
        self,
        event_type: str,
        data: dict[str, Any] | None = None,
        step: int = 0,
        token: str = "",
    ):
        self.event_type = event_type  # "step_start", "token", "score", "step_end", "task_end", etc.
        self.data = data or {}
        self.step = step
        self.token = token
        self.timestamp = time.time()

    def __repr__(self) -> str:
        if self.token:
            return f"StreamEvent({self.event_type}, token='{self.token[:20]}')"
        return f"StreamEvent({self.event_type}, step={self.step})"


# ---------------------------------------------------------------------------
# Async Orchestrator — streams events during task execution
# ---------------------------------------------------------------------------

class AsyncOrchestrator:
    """
    Async wrapper around the synchronous Orchestrator that streams events.
    
    Usage:
        async for event in async_orch.run_task_stream(purpose="...", ...):
            if event.event_type == "token":
                print(event.token, end="", flush=True)
            elif event.event_type == "score":
                print(f"\\nΦ: {event.data['phi_before']:.1f} → {event.data['phi_after']:.1f}")
    """

    def __init__(self, orchestrator):
        self.orch = orchestrator

    async def run_task_stream(
        self,
        purpose: str,
        initial_state: State | None = None,
        max_steps: int = 20,
        early_stop_phi: float = 9.0,
    ) -> AsyncIterator[StreamEvent]:
        """Run a task and stream events as they happen."""

        current_state = initial_state or self.orch.environment.reset()
        self.orch.purpose_fn.reset_trajectory_stats()

        trajectory = Trajectory(task_description=purpose, purpose=purpose)
        history: list[dict[str, Any]] = []

        yield StreamEvent("task_start", {"purpose": purpose, "max_steps": max_steps})

        for step_idx in range(max_steps):
            yield StreamEvent("step_start", {"step": step_idx + 1}, step=step_idx + 1)

            # Actor decides (run in thread to not block)
            action = await asyncio.to_thread(
                self.orch.actor.decide, purpose, current_state, history
            )

            yield StreamEvent("action", {
                "name": action.name,
                "thought": action.thought,
                "expected_delta": action.expected_delta,
            }, step=step_idx + 1)

            if action.name.upper() == "DONE":
                yield StreamEvent("done", {}, step=step_idx + 1)
                break

            # Environment executes
            try:
                new_state = await asyncio.to_thread(
                    self.orch.environment.execute, action, current_state
                )
            except Exception as e:
                new_state = State(data={**current_state.data, "_error": str(e)})
                yield StreamEvent("error", {"error": str(e)}, step=step_idx + 1)

            # Purpose Function scores
            score = await asyncio.to_thread(
                self.orch.purpose_fn.evaluate, current_state, action, new_state, purpose
            )

            yield StreamEvent("score", {
                "phi_before": score.phi_before,
                "phi_after": score.phi_after,
                "delta": score.delta,
                "confidence": score.confidence,
                "improved": score.improved,
                "evidence": score.evidence,
            }, step=step_idx + 1)

            # Record step
            step = TrajectoryStep(
                state_before=current_state, action=action, state_after=new_state,
                score=score, step_index=step_idx + 1,
            )
            trajectory.steps.append(step)
            history.append({
                "action": f"{action.name}({json.dumps(action.params, default=str)})",
                "result": new_state.describe()[:200],
                "score": f"Δ={score.delta:+.2f}",
            })

            yield StreamEvent("step_end", {
                "state_summary": new_state.describe()[:200],
            }, step=step_idx + 1)

            if score.phi_after >= early_stop_phi:
                yield StreamEvent("early_stop", {"phi": score.phi_after}, step=step_idx + 1)
                break

            if self.orch.environment.is_terminal(new_state):
                yield StreamEvent("terminal", {}, step=step_idx + 1)
                break

            current_state = new_state

        # Post-task (run in background)
        await asyncio.to_thread(self.orch.post_task, trajectory, [])

        yield StreamEvent("task_end", {
            "total_steps": len(trajectory.steps),
            "cumulative_reward": trajectory.cumulative_reward,
            "success_rate": trajectory.success_rate,
            "final_phi": trajectory.final_phi,
        })

    async def run_task(self, **kwargs):
        """Non-streaming async task execution; returns the orchestrator's TaskResult."""
        return await asyncio.to_thread(self.orch.run_task, **kwargs)
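
# Example (sketch): driving the stream end to end. `build_orchestrator` is a
# hypothetical factory for the sync Orchestrator; it is not defined here:
#
#     async def main() -> None:
#         async_orch = AsyncOrchestrator(build_orchestrator())
#         async for event in async_orch.run_task_stream(purpose="triage the bug"):
#             if event.event_type == "token":
#                 print(event.token, end="", flush=True)
#             elif event.event_type == "task_end":
#                 print(f"\nfinal Φ: {event.data['final_phi']:.1f}")
#
#     asyncio.run(main())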