Rohan03 committed on
Commit
8da4ecb
·
verified ·
1 Parent(s): d64d322

Sprint 2: runtime/state.py — typed RunState for durable execution

Browse files
Files changed (1) hide show
  1. purpose_agent/runtime/state.py +185 -0
purpose_agent/runtime/state.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ state.py — Typed run/session/node state for durable execution.
3
+
4
+ RunState captures everything needed to resume a run after interruption:
5
+ - Current node/step position
6
+ - Accumulated state data
7
+ - Pending actions
8
+ - Memory snapshot references
9
+ - Tool call idempotency keys
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import time
14
+ import uuid
15
+ from dataclasses import dataclass, field
16
+ from enum import Enum
17
+ from typing import Any
18
+
19
+
20
class RunStatus(str, Enum):
    """Lifecycle status of a run.

    The enum mixes in ``str``, so every member is itself a string equal to
    its lowercase value (e.g. ``RunStatus.PENDING == "pending"``) and a
    member can be looked up from that value with ``RunStatus("pending")``.
    """

    PENDING = "pending"
    RUNNING = "running"
    # Paused for human-in-the-loop (HITL) input or at a checkpoint.
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
28
+
29
+
30
def _new_idempotency_key() -> str:
    """Return the first 16 hex characters of a freshly generated UUID4."""
    return uuid.uuid4().hex[:16]


@dataclass
class NodeState:
    """Per-node execution record for a single node in a Flow/Graph.

    Only ``node_id`` is required; every other field starts from a neutral
    default and is filled in as the node runs.
    """

    node_id: str
    # One of: pending, running, completed, failed, skipped.
    status: str = "pending"

    # Payloads flowing into and out of the node (fresh dict per instance).
    input_data: dict[str, Any] = field(default_factory=dict)
    output_data: dict[str, Any] = field(default_factory=dict)

    # Timestamps; 0.0 means "not set yet".
    started_at: float = 0.0
    finished_at: float = 0.0

    # Error text for a failed node, or None when there is none.
    error: str | None = None
    # Number of attempts made so far.
    attempts: int = 0
    # Random per-instance key generated when the NodeState is created.
    idempotency_key: str = field(default_factory=_new_idempotency_key)
42
+
43
+
44
@dataclass
class RunState:
    """
    Complete state of an execution run — serializable for checkpointing.

    Contains everything needed to resume from any point: the current
    position, accumulated data, per-node state, history, timing, a cache of
    tool results keyed by idempotency key, and memory counters recorded at
    checkpoint time.

    Use ``to_dict()`` to produce a plain-dict checkpoint and
    ``RunState.from_dict()`` to rebuild a state from one.
    """

    run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    session_id: str = ""
    status: RunStatus = RunStatus.PENDING
    purpose: str = ""

    # Position
    current_node: str = ""
    current_step: int = 0
    max_steps: int = 20

    # State data
    data: dict[str, Any] = field(default_factory=dict)
    node_states: dict[str, NodeState] = field(default_factory=dict)

    # History
    completed_nodes: list[str] = field(default_factory=list)
    action_history: list[dict[str, Any]] = field(default_factory=list)

    # Timing
    started_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    finished_at: float = 0.0

    # Idempotency tracking (tool_call_id → result)
    tool_results_cache: dict[str, Any] = field(default_factory=dict)

    # Memory references (for post-resume sync)
    heuristic_count_at_checkpoint: int = 0
    experience_count_at_checkpoint: int = 0

    # Metadata
    metadata: dict[str, Any] = field(default_factory=dict)
    version: int = 1  # For state schema migration

    def mark_node_started(self, node_id: str) -> None:
        """Record that ``node_id`` has begun (or re-begun) executing.

        Creates the node's ``NodeState`` on first use, bumps its attempt
        counter, and makes it the run's current node.
        """
        now = time.time()
        node = self.node_states.get(node_id)
        if node is None:
            node = self.node_states[node_id] = NodeState(node_id=node_id)
        node.status = "running"
        node.started_at = now
        # A retry starts from a clean slate: without this reset a node that
        # failed once and then succeeded would still carry the old error and
        # the old finish time.
        node.finished_at = 0.0
        node.error = None
        node.attempts += 1
        self.current_node = node_id
        self.updated_at = now

    def mark_node_completed(self, node_id: str, output: dict[str, Any] | None = None) -> None:
        """Record that ``node_id`` finished successfully.

        Args:
            node_id: Node that completed. It is appended to
                ``completed_nodes`` even if it has no ``NodeState`` entry.
            output: Output to store on the node. ``None`` leaves the node's
                existing ``output_data`` untouched; any dict (including an
                empty one) replaces it.
        """
        now = time.time()
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "completed"
            node.finished_at = now
            # ``is not None`` (not truthiness) so an explicit empty output
            # overwrites stale output from an earlier attempt.
            if output is not None:
                node.output_data = output
        self.completed_nodes.append(node_id)
        self.updated_at = now

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Record that ``node_id`` failed with ``error``.

        Unknown node ids only refresh ``updated_at``; no ``NodeState`` is
        created for them.
        """
        now = time.time()
        node = self.node_states.get(node_id)
        if node is not None:
            node.status = "failed"
            node.error = error
            node.finished_at = now
        self.updated_at = now

    def get_idempotency_key(self, tool_name: str, args_hash: str) -> str:
        """Generate a stable idempotency key for a tool call.

        The key combines the run id, the current step, the tool name and the
        caller-supplied hash of the arguments.
        """
        return f"{self.run_id}:{self.current_step}:{tool_name}:{args_hash}"

    def has_cached_result(self, key: str) -> bool:
        """Return True if a tool result is cached under ``key``."""
        return key in self.tool_results_cache

    def get_cached_result(self, key: str) -> Any:
        """Return the cached tool result for ``key``, or None if absent."""
        return self.tool_results_cache.get(key)

    def cache_result(self, key: str, result: Any) -> None:
        """Store ``result`` in the tool result cache under ``key``."""
        self.tool_results_cache[key] = result

    @staticmethod
    def _node_to_dict(node: NodeState) -> dict[str, Any]:
        """Serialize one NodeState, including its timestamps."""
        return {
            "node_id": node.node_id,
            "status": node.status,
            "input_data": dict(node.input_data),
            "output_data": dict(node.output_data),
            "started_at": node.started_at,
            "finished_at": node.finished_at,
            "attempts": node.attempts,
            "error": node.error,
            "idempotency_key": node.idempotency_key,
        }

    @staticmethod
    def _node_from_dict(key: str, raw: dict[str, Any]) -> NodeState:
        """Rebuild one NodeState; missing fields fall back to defaults."""
        return NodeState(
            node_id=raw.get("node_id", key),
            status=raw.get("status", "pending"),
            input_data=dict(raw.get("input_data") or {}),
            output_data=dict(raw.get("output_data") or {}),
            # Checkpoints written before timestamps were persisted have no
            # such keys; 0.0 matches NodeState's own defaults.
            started_at=raw.get("started_at", 0.0),
            finished_at=raw.get("finished_at", 0.0),
            attempts=raw.get("attempts", 0),
            error=raw.get("error"),
            idempotency_key=raw.get("idempotency_key", ""),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize for checkpointing.

        Returns a plain dict snapshot. Top-level containers are shallow
        copies, so adding or removing entries on the live state afterwards
        does not change a checkpoint that was already taken (nested values
        are still shared). Only the last 50 ``action_history`` entries are
        kept.
        """
        return {
            "run_id": self.run_id,
            "session_id": self.session_id,
            "status": self.status.value,
            "purpose": self.purpose,
            "current_node": self.current_node,
            "current_step": self.current_step,
            "max_steps": self.max_steps,
            "data": dict(self.data),
            "node_states": {
                key: self._node_to_dict(node)
                for key, node in self.node_states.items()
            },
            "completed_nodes": list(self.completed_nodes),
            "action_history": self.action_history[-50:],  # Keep last 50
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "finished_at": self.finished_at,
            "tool_results_cache": dict(self.tool_results_cache),
            "heuristic_count_at_checkpoint": self.heuristic_count_at_checkpoint,
            "experience_count_at_checkpoint": self.experience_count_at_checkpoint,
            "metadata": dict(self.metadata),
            "version": self.version,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "RunState":
        """Deserialize from checkpoint.

        Args:
            d: Dict in the shape produced by ``to_dict()``. Missing keys fall
                back to defaults; containers are shallow-copied so the
                restored state does not alias ``d``.

        Returns:
            The reconstructed ``RunState``.

        Raises:
            ValueError: If ``d["status"]`` is not a valid ``RunStatus`` value.
        """
        state = cls(
            run_id=d.get("run_id", ""),
            session_id=d.get("session_id", ""),
            status=RunStatus(d.get("status", "pending")),
            purpose=d.get("purpose", ""),
            current_node=d.get("current_node", ""),
            current_step=d.get("current_step", 0),
            max_steps=d.get("max_steps", 20),
            data=dict(d.get("data") or {}),
            completed_nodes=list(d.get("completed_nodes") or []),
            action_history=list(d.get("action_history") or []),
            started_at=d.get("started_at", 0),
            updated_at=d.get("updated_at", 0),
            finished_at=d.get("finished_at", 0),
            tool_results_cache=dict(d.get("tool_results_cache") or {}),
            heuristic_count_at_checkpoint=d.get("heuristic_count_at_checkpoint", 0),
            experience_count_at_checkpoint=d.get("experience_count_at_checkpoint", 0),
            metadata=dict(d.get("metadata") or {}),
            version=d.get("version", 1),
        )
        for key, raw in (d.get("node_states") or {}).items():
            state.node_states[key] = cls._node_from_dict(key, raw)
        return state