# purpose_agent/streaming.py
"""
Streaming & Async Engine — Real-time token streaming and concurrent execution.
Adds streaming support to all modules:
- Actor streams its thought process as it reasons
- Purpose Function streams its evaluation
- Orchestrator streams step-by-step progress
Async support via asyncio:
- All core operations have async variants
- Concurrent tool execution
- Background experience replay updates
Pattern: sync methods remain the default. Async wrappers use asyncio.to_thread
for backends that don't support native async (per smolagents pattern).
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, AsyncIterator, Callable, Iterator
from purpose_agent.types import (
Action, PurposeScore, State, Trajectory, TrajectoryStep,
)
from purpose_agent.llm_backend import ChatMessage, LLMBackend
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Streaming Mixin — adds generate_stream to any LLMBackend
# ---------------------------------------------------------------------------
class StreamingMixin:
"""
Mixin that adds streaming to any LLMBackend that doesn't natively support it.
Falls back to returning the full response as a single chunk.
Override generate_stream() for native streaming.
"""
def generate_stream(
self,
messages: list[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2048,
) -> Iterator[str]:
"""
Stream tokens. Default: generate full response, yield as one chunk.
Override in subclasses for real token-level streaming.
"""
full = self.generate(messages, temperature=temperature, max_tokens=max_tokens)
yield full
async def agenerate(
self,
messages: list[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2048,
stop: list[str] | None = None,
) -> str:
"""Async wrapper around sync generate."""
        return await asyncio.to_thread(
            self.generate,
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )
async def agenerate_structured(
self,
messages: list[ChatMessage],
schema: dict[str, Any],
temperature: float = 0.3,
max_tokens: int = 1024,
) -> dict[str, Any]:
"""Async wrapper around sync generate_structured."""
        return await asyncio.to_thread(
            self.generate_structured,
            messages,
            schema,
            temperature=temperature,
            max_tokens=max_tokens,
        )
async def agenerate_stream(
self,
messages: list[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2048,
) -> AsyncIterator[str]:
"""Async streaming. Default: wrap sync stream in async iterator."""
loop = asyncio.get_event_loop()
# Run sync generator in thread, yield results
gen = self.generate_stream(messages, temperature, max_tokens)
while True:
try:
token = await asyncio.to_thread(next, gen)
yield token
except StopIteration:
break
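# Example (illustrative sketch, not part of the package): a backend that mixes
# in StreamingMixin and overrides generate_stream() for real token-level
# streaming. `WordStreamBackend` and `SomeExistingBackend` are hypothetical
# names, and the word-level chunking is only a stand-in for a provider's
# native streaming API.
#
#     class WordStreamBackend(StreamingMixin, SomeExistingBackend):
#         def generate_stream(self, messages, temperature=0.7, max_tokens=2048):
#             full = self.generate(messages, temperature=temperature, max_tokens=max_tokens)
#             for word in full.split():
#                 yield word + " "
#
# The async wrappers above also allow concurrent generation, e.g.:
#
#     drafts = await asyncio.gather(
#         backend.agenerate(messages_a),
#         backend.agenerate(messages_b),
#     )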
# ---------------------------------------------------------------------------
# Event types for streaming orchestration
# ---------------------------------------------------------------------------
class StreamEvent:
"""An event emitted during streaming orchestration."""
def __init__(
self,
event_type: str,
data: dict[str, Any] | None = None,
step: int = 0,
token: str = "",
):
self.event_type = event_type # "step_start", "token", "score", "step_end", "task_end", etc.
self.data = data or {}
self.step = step
self.token = token
self.timestamp = time.time()
def __repr__(self) -> str:
if self.token:
return f"StreamEvent({self.event_type}, token='{self.token[:20]}')"
return f"StreamEvent({self.event_type}, step={self.step})"
# ---------------------------------------------------------------------------
# Async Orchestrator — streams events during task execution
# ---------------------------------------------------------------------------
class AsyncOrchestrator:
"""
Async wrapper around the synchronous Orchestrator that streams events.
Usage:
async for event in async_orch.run_task_stream(purpose="...", ...):
if event.event_type == "token":
print(event.token, end="", flush=True)
elif event.event_type == "score":
print(f"\\nΦ: {event.data['phi_before']:.1f} → {event.data['phi_after']:.1f}")
"""
def __init__(self, orchestrator):
self.orch = orchestrator
async def run_task_stream(
self,
purpose: str,
initial_state: State | None = None,
max_steps: int = 20,
early_stop_phi: float = 9.0,
) -> AsyncIterator[StreamEvent]:
"""Run a task and stream events as they happen."""
current_state = initial_state or self.orch.environment.reset()
self.orch.purpose_fn.reset_trajectory_stats()
trajectory = Trajectory(task_description=purpose, purpose=purpose)
history: list[dict[str, Any]] = []
yield StreamEvent("task_start", {"purpose": purpose, "max_steps": max_steps})
for step_idx in range(max_steps):
yield StreamEvent("step_start", {"step": step_idx + 1}, step=step_idx + 1)
# Actor decides (run in thread to not block)
action = await asyncio.to_thread(
self.orch.actor.decide, purpose, current_state, history
)
yield StreamEvent("action", {
"name": action.name,
"thought": action.thought,
"expected_delta": action.expected_delta,
}, step=step_idx + 1)
if action.name.upper() == "DONE":
yield StreamEvent("done", {}, step=step_idx + 1)
break
# Environment executes
try:
new_state = await asyncio.to_thread(
self.orch.environment.execute, action, current_state
)
except Exception as e:
new_state = State(data={**current_state.data, "_error": str(e)})
yield StreamEvent("error", {"error": str(e)}, step=step_idx + 1)
# Purpose Function scores
score = await asyncio.to_thread(
self.orch.purpose_fn.evaluate, current_state, action, new_state, purpose
)
yield StreamEvent("score", {
"phi_before": score.phi_before,
"phi_after": score.phi_after,
"delta": score.delta,
"confidence": score.confidence,
"improved": score.improved,
"evidence": score.evidence,
}, step=step_idx + 1)
# Record step
step = TrajectoryStep(
state_before=current_state, action=action, state_after=new_state,
score=score, step_index=step_idx + 1,
)
trajectory.steps.append(step)
history.append({
"action": f"{action.name}({json.dumps(action.params, default=str)})",
"result": new_state.describe()[:200],
"score": f"Δ={score.delta:+.2f}",
})
yield StreamEvent("step_end", {
"state_summary": new_state.describe()[:200],
}, step=step_idx + 1)
if score.phi_after >= early_stop_phi:
yield StreamEvent("early_stop", {"phi": score.phi_after}, step=step_idx + 1)
break
if self.orch.environment.is_terminal(new_state):
yield StreamEvent("terminal", {}, step=step_idx + 1)
break
current_state = new_state
        # Post-task processing (run off the event loop in a worker thread)
await asyncio.to_thread(self.orch.post_task, trajectory, [])
yield StreamEvent("task_end", {
"total_steps": len(trajectory.steps),
"cumulative_reward": trajectory.cumulative_reward,
"success_rate": trajectory.success_rate,
"final_phi": trajectory.final_phi,
})
    async def run_task(self, **kwargs):
        """Non-streaming async execution; returns the sync orchestrator's TaskResult."""
        return await asyncio.to_thread(self.orch.run_task, **kwargs)
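# Example (illustrative): driving the streaming loop from a synchronous entry
# point. Constructing the underlying Orchestrator is out of scope here, so
# `orchestrator` is assumed to be an already-built instance.
#
#     async def main(orchestrator) -> None:
#         async_orch = AsyncOrchestrator(orchestrator)
#         async for event in async_orch.run_task_stream(purpose="Summarise the dataset"):
#             if event.event_type == "score":
#                 print(f"step {event.step}: Δ={event.data['delta']:+.2f}")
#             elif event.event_type == "task_end":
#                 print(f"final Φ: {event.data['final_phi']}")
#
#     asyncio.run(main(orchestrator))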