""" state.py — Typed run/session/node state for durable execution. RunState captures everything needed to resume a run after interruption: - Current node/step position - Accumulated state data - Pending actions - Memory snapshot references - Tool call idempotency keys """ from __future__ import annotations import time import uuid from dataclasses import dataclass, field from enum import Enum from typing import Any class RunStatus(str, Enum): """Status of a run.""" PENDING = "pending" RUNNING = "running" PAUSED = "paused" # HITL or checkpoint pause COMPLETED = "completed" FAILED = "failed" CANCELLED = "cancelled" @dataclass class NodeState: """State of a single node in a Flow/Graph.""" node_id: str status: str = "pending" # pending, running, completed, failed, skipped input_data: dict[str, Any] = field(default_factory=dict) output_data: dict[str, Any] = field(default_factory=dict) started_at: float = 0.0 finished_at: float = 0.0 error: str | None = None attempts: int = 0 idempotency_key: str = field(default_factory=lambda: uuid.uuid4().hex[:16]) @dataclass class RunState: """ Complete state of an execution run — serializable for checkpointing. Contains everything needed to resume from any point. """ run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) session_id: str = "" status: RunStatus = RunStatus.PENDING purpose: str = "" # Position current_node: str = "" current_step: int = 0 max_steps: int = 20 # State data data: dict[str, Any] = field(default_factory=dict) node_states: dict[str, NodeState] = field(default_factory=dict) # History completed_nodes: list[str] = field(default_factory=list) action_history: list[dict[str, Any]] = field(default_factory=list) # Timing started_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) finished_at: float = 0.0 # Idempotency tracking (tool_call_id → result) tool_results_cache: dict[str, Any] = field(default_factory=dict) # Memory references (for post-resume sync) heuristic_count_at_checkpoint: int = 0 experience_count_at_checkpoint: int = 0 # Metadata metadata: dict[str, Any] = field(default_factory=dict) version: int = 1 # For state schema migration def mark_node_started(self, node_id: str) -> None: if node_id not in self.node_states: self.node_states[node_id] = NodeState(node_id=node_id) self.node_states[node_id].status = "running" self.node_states[node_id].started_at = time.time() self.node_states[node_id].attempts += 1 self.current_node = node_id self.updated_at = time.time() def mark_node_completed(self, node_id: str, output: dict[str, Any] | None = None) -> None: if node_id in self.node_states: self.node_states[node_id].status = "completed" self.node_states[node_id].finished_at = time.time() if output: self.node_states[node_id].output_data = output self.completed_nodes.append(node_id) self.updated_at = time.time() def mark_node_failed(self, node_id: str, error: str) -> None: if node_id in self.node_states: self.node_states[node_id].status = "failed" self.node_states[node_id].error = error self.node_states[node_id].finished_at = time.time() self.updated_at = time.time() def get_idempotency_key(self, tool_name: str, args_hash: str) -> str: """Generate a stable idempotency key for a tool call.""" return f"{self.run_id}:{self.current_step}:{tool_name}:{args_hash}" def has_cached_result(self, key: str) -> bool: return key in self.tool_results_cache def get_cached_result(self, key: str) -> Any: return self.tool_results_cache.get(key) def cache_result(self, key: str, result: Any) -> None: self.tool_results_cache[key] = result def to_dict(self) -> dict[str, Any]: """Serialize for checkpointing.""" return { "run_id": self.run_id, "session_id": self.session_id, "status": self.status.value, "purpose": self.purpose, "current_node": self.current_node, "current_step": self.current_step, "max_steps": self.max_steps, "data": self.data, "node_states": {k: { "node_id": v.node_id, "status": v.status, "input_data": v.input_data, "output_data": v.output_data, "attempts": v.attempts, "error": v.error, "idempotency_key": v.idempotency_key, } for k, v in self.node_states.items()}, "completed_nodes": self.completed_nodes, "action_history": self.action_history[-50:], # Keep last 50 "started_at": self.started_at, "updated_at": self.updated_at, "finished_at": self.finished_at, "tool_results_cache": self.tool_results_cache, "heuristic_count_at_checkpoint": self.heuristic_count_at_checkpoint, "experience_count_at_checkpoint": self.experience_count_at_checkpoint, "metadata": self.metadata, "version": self.version, } @classmethod def from_dict(cls, d: dict[str, Any]) -> "RunState": """Deserialize from checkpoint.""" state = cls( run_id=d.get("run_id", ""), session_id=d.get("session_id", ""), status=RunStatus(d.get("status", "pending")), purpose=d.get("purpose", ""), current_node=d.get("current_node", ""), current_step=d.get("current_step", 0), max_steps=d.get("max_steps", 20), data=d.get("data", {}), completed_nodes=d.get("completed_nodes", []), action_history=d.get("action_history", []), started_at=d.get("started_at", 0), updated_at=d.get("updated_at", 0), finished_at=d.get("finished_at", 0), tool_results_cache=d.get("tool_results_cache", {}), heuristic_count_at_checkpoint=d.get("heuristic_count_at_checkpoint", 0), experience_count_at_checkpoint=d.get("experience_count_at_checkpoint", 0), metadata=d.get("metadata", {}), version=d.get("version", 1), ) for k, v in d.get("node_states", {}).items(): state.node_states[k] = NodeState( node_id=v.get("node_id", k), status=v.get("status", "pending"), input_data=v.get("input_data", {}), output_data=v.get("output_data", {}), attempts=v.get("attempts", 0), error=v.get("error"), idempotency_key=v.get("idempotency_key", ""), ) return state