Rohan03 committed on
Commit
b0e5b32
·
verified ·
1 Parent(s): 8da4ecb

Sprint 2: checkpoint.py — Checkpointer protocol + InMemory/JSONL/SQLite implementations

Browse files
Files changed (1) hide show
  1. purpose_agent/runtime/checkpoint.py +266 -0
purpose_agent/runtime/checkpoint.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ checkpoint.py — Durable execution via event sourcing and state snapshots.
3
+
4
+ Provides:
5
+ - Checkpointer protocol (save events, save snapshots, load, list)
6
+ - InMemoryCheckpointer (testing)
7
+ - JSONLCheckpointer (file-based, portable)
8
+ - SQLiteCheckpointer (production, concurrent-safe)
9
+
10
+ Usage:
11
+ checkpointer = JSONLCheckpointer("./checkpoints")
12
+
13
+ # During execution
14
+ checkpointer.save_event(event)
15
+ checkpointer.save_snapshot(run_id, state)
16
+
17
+ # On crash recovery
18
+ state = checkpointer.load_latest(run_id)
19
+ events = checkpointer.list_events(run_id)
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import json
24
+ import logging
25
+ import os
26
+ import sqlite3
27
+ import time
28
+ from pathlib import Path
29
+ from typing import Any, Protocol
30
+
31
+ from purpose_agent.runtime.events import PAEvent
32
+ from purpose_agent.runtime.state import RunState
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class Checkpointer(Protocol):
    """
    Structural interface that every checkpoint backend implements.

    A backend must be able to:
    - append events to a log (event sourcing),
    - store a full state snapshot,
    - return the latest snapshot of a run,
    - list the events of a run so they can be replayed.
    """

    def save_event(self, event: PAEvent) -> None:
        """Append ``event`` to the durable log."""

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Store ``state`` as the snapshot of ``run_id``, replacing any previous one."""

    def load_latest(self, run_id: str) -> RunState | None:
        """Return the most recent snapshot of ``run_id``, or ``None`` if there is none."""

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """Return the events of ``run_id``, optionally only those since ``since_seq``."""
63
+
64
+
65
+ # ═══════════════════════════════════════════════════════════════════
66
+ # InMemoryCheckpointer — for testing
67
+ # ═══════════════════════════════════════════════════════════════════
68
+
69
class InMemoryCheckpointer:
    """
    In-memory checkpointer for testing. Not durable across restarts.

    Events are kept per run in the order they were saved, and one snapshot
    is kept per run. Event and snapshot objects are stored by reference
    (they are not copied or serialized).
    """

    def __init__(self):
        # run_id -> events in insertion order
        self._events: dict[str, list[PAEvent]] = {}
        # run_id -> most recently saved snapshot
        self._snapshots: dict[str, RunState] = {}

    def save_event(self, event: PAEvent) -> None:
        """Append ``event`` to the log of ``event.run_id``."""
        self._events.setdefault(event.run_id, []).append(event)

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Store ``state`` as the snapshot for ``run_id``, replacing any previous one."""
        self._snapshots[run_id] = state

    def load_latest(self, run_id: str) -> RunState | None:
        """Return the stored snapshot for ``run_id``, or ``None`` if none was saved."""
        return self._snapshots.get(run_id)

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Return the events of ``run_id`` in insertion order.

        When ``since_seq`` is greater than 0, only events whose ``seq`` is
        strictly greater than ``since_seq`` are returned; otherwise every
        stored event is returned. An unknown ``run_id`` yields an empty list.

        The result is always a new list, so callers may mutate it without
        affecting the stored log.
        """
        events = self._events.get(run_id, [])
        if since_seq > 0:
            return [e for e in events if e.seq > since_seq]
        # Copy so the caller never receives the internal list itself.
        return list(events)

    @property
    def event_count(self) -> int:
        """Total number of stored events across all runs."""
        return sum(len(v) for v in self._events.values())

    @property
    def snapshot_count(self) -> int:
        """Number of runs that have a stored snapshot."""
        return len(self._snapshots)
98
+
99
+
100
+ # ═══════════════════════════════════════════════════════════════════
101
+ # JSONLCheckpointer — file-based, portable
102
+ # ═══════════════════════════════════════════════════════════════════
103
+
104
class JSONLCheckpointer:
    """
    File-based checkpointer using JSONL for events and JSON for snapshots.

    Directory structure:
        base_dir/
            {run_id}/
                events.jsonl           one JSON object per line, append-only
                snapshot.json          latest snapshot
                snapshot.backup.json   previous snapshot (kept on overwrite)

    Run directories are created only when something is written; read
    methods never create directories.
    """

    _EVENTS_NAME = "events.jsonl"
    _SNAPSHOT_NAME = "snapshot.json"
    _BACKUP_NAME = "snapshot.backup.json"

    def __init__(self, base_dir: str):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _run_dir(self, run_id: str) -> Path:
        """Return the directory for ``run_id``, creating it if necessary."""
        d = self.base_dir / run_id
        d.mkdir(parents=True, exist_ok=True)
        return d

    def _read_snapshot(self, path: Path, run_id: str, label: str) -> RunState | None:
        """Load a snapshot file; return ``None`` if it is missing or corrupt."""
        try:
            with open(path, encoding="utf-8") as f:
                return RunState.from_dict(json.load(f))
        except FileNotFoundError:
            return None
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Corrupt {label} for {run_id}: {e}")
            return None

    def save_event(self, event: PAEvent) -> None:
        """Append ``event.to_dict()`` as one JSON line to the run's ``events.jsonl``."""
        path = self._run_dir(event.run_id) / self._EVENTS_NAME
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(event.to_dict(), default=str) + "\n")

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """
        Write ``state.to_dict()`` as the snapshot for ``run_id``.

        The new snapshot is written to a temporary file and then moved into
        place, so an interrupted write cannot truncate the existing
        snapshot. The snapshot being replaced is kept as
        ``snapshot.backup.json``, which ``load_latest`` falls back to.
        """
        run_dir = self._run_dir(run_id)
        path = run_dir / self._SNAPSHOT_NAME
        tmp_path = run_dir / (self._SNAPSHOT_NAME + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(state.to_dict(), f, indent=2, default=str)
            # Push the data to disk before the rename makes it visible.
            f.flush()
            os.fsync(f.fileno())
        if path.exists():
            os.replace(path, run_dir / self._BACKUP_NAME)
        os.replace(tmp_path, path)

    def load_latest(self, run_id: str) -> RunState | None:
        """
        Load the most recent snapshot for ``run_id``.

        Tries ``snapshot.json`` first; if it is missing or corrupt, falls
        back to ``snapshot.backup.json``. Returns ``None`` when neither can
        be loaded. Corrupt files are logged as warnings.
        """
        run_dir = self.base_dir / run_id  # no mkdir: reads must not create dirs
        state = self._read_snapshot(run_dir / self._SNAPSHOT_NAME, run_id, "snapshot")
        if state is not None:
            return state
        return self._read_snapshot(run_dir / self._BACKUP_NAME, run_id, "backup snapshot")

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Return the events of ``run_id`` in file order.

        Only events whose ``seq`` is strictly greater than ``since_seq`` are
        returned. Blank lines are ignored; lines that fail to parse are
        skipped with a warning. A run without an event file yields ``[]``.
        """
        path = self.base_dir / run_id / self._EVENTS_NAME
        if not path.exists():
            return []
        events = []
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    event = PAEvent.from_dict(json.loads(line))
                except (json.JSONDecodeError, KeyError) as e:
                    logger.warning(
                        "Skipping corrupt event at %s line %d: %s", path, lineno, e
                    )
                    continue
                if event.seq > since_seq:
                    events.append(event)
        return events
169
+
170
+
171
+ # ═══════════════════════════════════════════════════════════════════
172
+ # SQLiteCheckpointer — production, concurrent-safe
173
+ # ═══════════════════════════════════════════════════════════════════
174
+
175
class SQLiteCheckpointer:
    """
    SQLite-based checkpointer. Durable, supports multiple runs.

    Uses stdlib sqlite3 — no extra dependencies. Every operation opens its
    own connection and closes it when done.
    """

    def __init__(self, db_path: str = "purpose_agent_checkpoints.db"):
        self.db_path = db_path
        self._init_db()

    def _connect(self):
        """
        Return a context manager yielding a connection that is closed on exit.

        ``sqlite3.Connection`` used directly as a context manager only
        commits or rolls back; it does not close the connection. Wrapping
        it in ``closing`` guarantees the handle is released.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _init_db(self) -> None:
        """Create the ``events`` and ``snapshots`` tables and the event index if missing."""
        # Outer manager closes the connection; inner ``conn`` commits/rolls back.
        with self._connect() as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS events (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id TEXT NOT NULL,
                    seq INTEGER NOT NULL,
                    ts REAL NOT NULL,
                    kind TEXT NOT NULL,
                    lane_id TEXT DEFAULT 'main',
                    data TEXT NOT NULL,
                    created_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    run_id TEXT PRIMARY KEY,
                    state TEXT NOT NULL,
                    updated_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_events_run ON events(run_id, seq)")

    def save_event(self, event: PAEvent) -> None:
        """Insert ``event`` as a row; the full ``event.to_dict()`` is stored as JSON in ``data``."""
        with self._connect() as conn, conn:
            conn.execute(
                "INSERT INTO events (run_id, seq, ts, kind, lane_id, data) VALUES (?, ?, ?, ?, ?, ?)",
                (event.run_id, event.seq, event.ts, event.kind.value, event.lane_id,
                 json.dumps(event.to_dict(), default=str)),
            )

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Store ``state.to_dict()`` as JSON for ``run_id``, replacing any previous snapshot."""
        data = json.dumps(state.to_dict(), default=str)
        with self._connect() as conn, conn:
            conn.execute(
                "INSERT OR REPLACE INTO snapshots (run_id, state, updated_at) VALUES (?, ?, ?)",
                (run_id, data, time.time()),
            )

    def load_latest(self, run_id: str) -> RunState | None:
        """
        Load the snapshot stored for ``run_id``.

        Returns ``None`` when no snapshot exists or when the stored JSON
        cannot be decoded into a ``RunState`` (logged as a warning).
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT state FROM snapshots WHERE run_id = ?", (run_id,)
            ).fetchone()
        if not row:
            return None
        try:
            return RunState.from_dict(json.loads(row[0]))
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Corrupt snapshot in SQLite for {run_id}: {e}")
            return None

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Return the events of ``run_id`` ordered by ``seq``.

        Only events whose ``seq`` is strictly greater than ``since_seq`` are
        returned. Events sharing a ``seq`` are ordered by insertion (row id).
        Rows that fail to decode are skipped with a warning.
        """
        with self._connect() as conn:
            rows = conn.execute(
                # ``id`` breaks ties so equal-seq events come back in insertion order.
                "SELECT data FROM events WHERE run_id = ? AND seq > ? ORDER BY seq, id",
                (run_id, since_seq),
            ).fetchall()
        events = []
        for row in rows:
            try:
                events.append(PAEvent.from_dict(json.loads(row[0])))
            except (json.JSONDecodeError, KeyError) as e:
                logger.warning("Skipping corrupt event row for %s: %s", run_id, e)
                continue
        return events

    def delete_run(self, run_id: str) -> None:
        """Remove all data for a run (cleanup)."""
        # Both deletes run in one transaction: either both apply or neither.
        with self._connect() as conn, conn:
            conn.execute("DELETE FROM events WHERE run_id = ?", (run_id,))
            conn.execute("DELETE FROM snapshots WHERE run_id = ?", (run_id,))

    def list_runs(self) -> list[str]:
        """List all run IDs with snapshots, most recently updated first."""
        with self._connect() as conn:
            rows = conn.execute("SELECT run_id FROM snapshots ORDER BY updated_at DESC").fetchall()
        return [r[0] for r in rows]