# checkpoint.py — Checkpointer protocol + InMemory/JSONL/SQLite implementations
"""
checkpoint.py β€” Durable execution via event sourcing and state snapshots.
Provides:
- Checkpointer protocol (save events, save snapshots, load, list)
- InMemoryCheckpointer (testing)
- JSONLCheckpointer (file-based, portable)
- SQLiteCheckpointer (production, concurrent-safe)
Usage:
checkpointer = JSONLCheckpointer("./checkpoints")
# During execution
checkpointer.save_event(event)
checkpointer.save_snapshot(run_id, state)
# On crash recovery
state = checkpointer.load_latest(run_id)
events = checkpointer.list_events(run_id)
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import time
from pathlib import Path
from typing import Any, Protocol
from purpose_agent.runtime.events import PAEvent
from purpose_agent.runtime.state import RunState
logger = logging.getLogger(__name__)
class Checkpointer(Protocol):
    """
    Protocol for durable state persistence.
    Implementations must support:
    - Appending events (event sourcing)
    - Saving full state snapshots
    - Loading the latest snapshot for a run
    - Listing all events for replay

    This is a structural (duck-typed) protocol: any object providing these
    four methods satisfies it — no inheritance required.
    """
    def save_event(self, event: PAEvent) -> None:
        """Append an event to the durable log."""
        ...
    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Save a full state snapshot (overwrites previous for same run_id)."""
        ...
    def load_latest(self, run_id: str) -> RunState | None:
        """Load the most recent snapshot for a run. Returns None if not found."""
        ...
    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """List all events for a run, optionally since a sequence number."""
        ...
# ═══════════════════════════════════════════════════════════════════
# InMemoryCheckpointer β€” for testing
# ═══════════════════════════════════════════════════════════════════
class InMemoryCheckpointer:
    """In-memory checkpointer for testing. Not durable across restarts.

    Events are kept per run_id in append order; snapshots keep only the
    most recent state per run_id.
    """

    def __init__(self) -> None:
        # run_id -> ordered event log (event sourcing).
        self._events: dict[str, list[PAEvent]] = {}
        # run_id -> latest snapshot only; older snapshots are overwritten.
        self._snapshots: dict[str, RunState] = {}

    def save_event(self, event: PAEvent) -> None:
        """Append *event* to the in-memory log for its run."""
        self._events.setdefault(event.run_id, []).append(event)

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Save the latest snapshot for *run_id*, replacing any previous one."""
        self._snapshots[run_id] = state

    def load_latest(self, run_id: str) -> RunState | None:
        """Return the most recent snapshot for *run_id*, or None if absent."""
        return self._snapshots.get(run_id)

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """Return events for *run_id* with seq > *since_seq* (0 means all).

        Always returns a fresh list: the original returned the internal
        list by reference when since_seq == 0, letting callers mutate
        checkpointer state by accident.
        """
        events = self._events.get(run_id, [])
        if since_seq > 0:
            return [e for e in events if e.seq > since_seq]
        return list(events)  # defensive copy — never leak internal state

    @property
    def event_count(self) -> int:
        """Total number of stored events across all runs."""
        return sum(len(v) for v in self._events.values())

    @property
    def snapshot_count(self) -> int:
        """Number of runs that currently have a snapshot."""
        return len(self._snapshots)
# ═══════════════════════════════════════════════════════════════════
# JSONLCheckpointer β€” file-based, portable
# ═══════════════════════════════════════════════════════════════════
class JSONLCheckpointer:
    """
    File-based checkpointer using JSONL for events and JSON for snapshots.
    Directory structure:
    base_dir/
        {run_id}/
            events.jsonl
            snapshot.json
            snapshot.backup.json   # previous good snapshot, for recovery

    Snapshots are written atomically (temp file + os.replace) and the
    previous snapshot is rotated to snapshot.backup.json, so load_latest()
    has a real fallback if the current snapshot is corrupt. (Previously the
    backup path was read but never written, and a crash mid-write could
    corrupt the only copy.)
    """

    def __init__(self, base_dir: str):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _run_dir(self, run_id: str) -> Path:
        """Return (creating if needed) the per-run directory."""
        d = self.base_dir / run_id
        d.mkdir(parents=True, exist_ok=True)
        return d

    def save_event(self, event: PAEvent) -> None:
        """Append *event* as one JSON line to the run's events.jsonl."""
        path = self._run_dir(event.run_id) / "events.jsonl"
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(event.to_dict(), default=str) + "\n")

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Atomically save a snapshot, keeping the previous one as backup.

        Order matters: write the temp file fully first, then rotate the
        current snapshot to backup, then rename the temp file into place.
        A crash at any point leaves at least one readable snapshot.
        """
        run_dir = self._run_dir(run_id)
        path = run_dir / "snapshot.json"
        tmp = run_dir / "snapshot.json.tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(state.to_dict(), f, indent=2, default=str)
        if path.exists():
            # Keep the last good snapshot so load_latest() can fall back.
            os.replace(path, run_dir / "snapshot.backup.json")
        os.replace(tmp, path)  # atomic on POSIX and Windows

    def load_latest(self, run_id: str) -> RunState | None:
        """Load the most recent snapshot, falling back to the backup.

        Returns None if neither snapshot nor backup is readable. The backup
        read is also guarded now — a corrupt backup no longer raises.
        """
        run_dir = self._run_dir(run_id)
        for name in ("snapshot.json", "snapshot.backup.json"):
            path = run_dir / name
            if not path.exists():
                continue
            try:
                with open(path, encoding="utf-8") as f:
                    return RunState.from_dict(json.load(f))
            except (json.JSONDecodeError, KeyError) as e:
                logger.warning(f"Corrupt snapshot for {run_id}: {e}")
        return None

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """Replay events with seq > *since_seq*; corrupt lines are skipped
        (a crash mid-append can leave a torn trailing line)."""
        path = self._run_dir(run_id) / "events.jsonl"
        if not path.exists():
            return []
        events: list[PAEvent] = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    event = PAEvent.from_dict(json.loads(line))
                except (json.JSONDecodeError, KeyError):
                    continue
                if event.seq > since_seq:
                    events.append(event)
        return events
# ═══════════════════════════════════════════════════════════════════
# SQLiteCheckpointer β€” production, concurrent-safe
# ═══════════════════════════════════════════════════════════════════
class SQLiteCheckpointer:
    """
    SQLite-based checkpointer. Durable, concurrent-safe, supports multiple runs.
    Uses stdlib sqlite3 — no extra dependencies.

    Each operation opens a short-lived connection and ALWAYS closes it:
    sqlite3's connection context manager only commits/rolls back the
    transaction, it does not close the connection, so the original code
    leaked a file handle on every call. A busy timeout makes concurrent
    writers retry instead of raising 'database is locked' immediately.
    """

    def __init__(self, db_path: str = "purpose_agent_checkpoints.db"):
        self.db_path = db_path
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a connection with a 30s busy timeout (caller must close)."""
        return sqlite3.connect(self.db_path, timeout=30.0)

    def _init_db(self) -> None:
        """Create tables and indexes if missing. Idempotent."""
        conn = self._connect()
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS events (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id TEXT NOT NULL,
                    seq INTEGER NOT NULL,
                    ts REAL NOT NULL,
                    kind TEXT NOT NULL,
                    lane_id TEXT DEFAULT 'main',
                    data TEXT NOT NULL,
                    created_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    run_id TEXT PRIMARY KEY,
                    state TEXT NOT NULL,
                    updated_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_events_run ON events(run_id, seq)")
            conn.commit()
        finally:
            conn.close()

    def save_event(self, event: PAEvent) -> None:
        """Append *event* to the durable log (one row per event)."""
        conn = self._connect()
        try:
            with conn:  # transaction: commit on success, rollback on error
                conn.execute(
                    "INSERT INTO events (run_id, seq, ts, kind, lane_id, data) VALUES (?, ?, ?, ?, ?, ?)",
                    (event.run_id, event.seq, event.ts, event.kind.value, event.lane_id,
                     json.dumps(event.to_dict(), default=str)),
                )
        finally:
            conn.close()

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Upsert the snapshot for *run_id* (one row per run)."""
        data = json.dumps(state.to_dict(), default=str)
        conn = self._connect()
        try:
            with conn:
                conn.execute(
                    "INSERT OR REPLACE INTO snapshots (run_id, state, updated_at) VALUES (?, ?, ?)",
                    (run_id, data, time.time()),
                )
        finally:
            conn.close()

    def load_latest(self, run_id: str) -> RunState | None:
        """Load the snapshot for *run_id*; None if missing or corrupt."""
        conn = self._connect()
        try:
            row = conn.execute(
                "SELECT state FROM snapshots WHERE run_id = ?", (run_id,)
            ).fetchone()
        finally:
            conn.close()
        if not row:
            return None
        try:
            return RunState.from_dict(json.loads(row[0]))
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Corrupt snapshot in SQLite for {run_id}: {e}")
            return None

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """Replay events with seq > *since_seq* in seq order; corrupt rows
        are skipped rather than aborting the whole replay."""
        conn = self._connect()
        try:
            rows = conn.execute(
                "SELECT data FROM events WHERE run_id = ? AND seq > ? ORDER BY seq",
                (run_id, since_seq),
            ).fetchall()
        finally:
            conn.close()
        events: list[PAEvent] = []
        for row in rows:
            try:
                events.append(PAEvent.from_dict(json.loads(row[0])))
            except (json.JSONDecodeError, KeyError):
                continue
        return events

    def delete_run(self, run_id: str) -> None:
        """Remove all data for a run (cleanup)."""
        conn = self._connect()
        try:
            with conn:
                conn.execute("DELETE FROM events WHERE run_id = ?", (run_id,))
                conn.execute("DELETE FROM snapshots WHERE run_id = ?", (run_id,))
        finally:
            conn.close()

    def list_runs(self) -> list[str]:
        """List all run IDs with snapshots, most recently updated first."""
        conn = self._connect()
        try:
            rows = conn.execute(
                "SELECT run_id FROM snapshots ORDER BY updated_at DESC"
            ).fetchall()
        finally:
            conn.close()
        return [r[0] for r in rows]