""" checkpoint.py — Durable execution via event sourcing and state snapshots. Provides: - Checkpointer protocol (save events, save snapshots, load, list) - InMemoryCheckpointer (testing) - JSONLCheckpointer (file-based, portable) - SQLiteCheckpointer (production, concurrent-safe) Usage: checkpointer = JSONLCheckpointer("./checkpoints") # During execution checkpointer.save_event(event) checkpointer.save_snapshot(run_id, state) # On crash recovery state = checkpointer.load_latest(run_id) events = checkpointer.list_events(run_id) """ from __future__ import annotations import json import logging import os import sqlite3 import time from pathlib import Path from typing import Any, Protocol from purpose_agent.runtime.events import PAEvent from purpose_agent.runtime.state import RunState logger = logging.getLogger(__name__) class Checkpointer(Protocol): """ Protocol for durable state persistence. Implementations must support: - Appending events (event sourcing) - Saving full state snapshots - Loading the latest snapshot for a run - Listing all events for replay """ def save_event(self, event: PAEvent) -> None: """Append an event to the durable log.""" ... def save_snapshot(self, run_id: str, state: RunState) -> None: """Save a full state snapshot (overwrites previous for same run_id).""" ... def load_latest(self, run_id: str) -> RunState | None: """Load the most recent snapshot for a run. Returns None if not found.""" ... def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]: """List all events for a run, optionally since a sequence number.""" ... # ═══════════════════════════════════════════════════════════════════ # InMemoryCheckpointer — for testing # ═══════════════════════════════════════════════════════════════════ class InMemoryCheckpointer: """In-memory checkpointer for testing. Not durable across restarts.""" def __init__(self): self._events: dict[str, list[PAEvent]] = {} self._snapshots: dict[str, RunState] = {} def save_event(self, event: PAEvent) -> None: self._events.setdefault(event.run_id, []).append(event) def save_snapshot(self, run_id: str, state: RunState) -> None: self._snapshots[run_id] = state def load_latest(self, run_id: str) -> RunState | None: return self._snapshots.get(run_id) def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]: events = self._events.get(run_id, []) if since_seq > 0: events = [e for e in events if e.seq > since_seq] return events @property def event_count(self) -> int: return sum(len(v) for v in self._events.values()) @property def snapshot_count(self) -> int: return len(self._snapshots) # ═══════════════════════════════════════════════════════════════════ # JSONLCheckpointer — file-based, portable # ═══════════════════════════════════════════════════════════════════ class JSONLCheckpointer: """ File-based checkpointer using JSONL for events and JSON for snapshots. Directory structure: base_dir/ {run_id}/ events.jsonl snapshot.json """ def __init__(self, base_dir: str): self.base_dir = Path(base_dir) self.base_dir.mkdir(parents=True, exist_ok=True) def _run_dir(self, run_id: str) -> Path: d = self.base_dir / run_id d.mkdir(parents=True, exist_ok=True) return d def save_event(self, event: PAEvent) -> None: path = self._run_dir(event.run_id) / "events.jsonl" with open(path, "a") as f: f.write(json.dumps(event.to_dict(), default=str) + "\n") def save_snapshot(self, run_id: str, state: RunState) -> None: path = self._run_dir(run_id) / "snapshot.json" with open(path, "w") as f: json.dump(state.to_dict(), f, indent=2, default=str) def load_latest(self, run_id: str) -> RunState | None: path = self._run_dir(run_id) / "snapshot.json" if not path.exists(): return None try: with open(path) as f: data = json.load(f) return RunState.from_dict(data) except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Corrupt snapshot for {run_id}: {e}") # Try previous backup backup = self._run_dir(run_id) / "snapshot.backup.json" if backup.exists(): with open(backup) as f: return RunState.from_dict(json.load(f)) return None def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]: path = self._run_dir(run_id) / "events.jsonl" if not path.exists(): return [] events = [] with open(path) as f: for line in f: line = line.strip() if not line: continue try: d = json.loads(line) event = PAEvent.from_dict(d) if event.seq > since_seq: events.append(event) except (json.JSONDecodeError, KeyError): continue return events # ═══════════════════════════════════════════════════════════════════ # SQLiteCheckpointer — production, concurrent-safe # ═══════════════════════════════════════════════════════════════════ class SQLiteCheckpointer: """ SQLite-based checkpointer. Durable, concurrent-safe, supports multiple runs. Uses stdlib sqlite3 — no extra dependencies. """ def __init__(self, db_path: str = "purpose_agent_checkpoints.db"): self.db_path = db_path self._init_db() def _init_db(self) -> None: with sqlite3.connect(self.db_path) as conn: conn.execute(""" CREATE TABLE IF NOT EXISTS events ( id INTEGER PRIMARY KEY AUTOINCREMENT, run_id TEXT NOT NULL, seq INTEGER NOT NULL, ts REAL NOT NULL, kind TEXT NOT NULL, lane_id TEXT DEFAULT 'main', data TEXT NOT NULL, created_at REAL DEFAULT (strftime('%s','now')) ) """) conn.execute(""" CREATE TABLE IF NOT EXISTS snapshots ( run_id TEXT PRIMARY KEY, state TEXT NOT NULL, updated_at REAL DEFAULT (strftime('%s','now')) ) """) conn.execute("CREATE INDEX IF NOT EXISTS idx_events_run ON events(run_id, seq)") conn.commit() def save_event(self, event: PAEvent) -> None: with sqlite3.connect(self.db_path) as conn: conn.execute( "INSERT INTO events (run_id, seq, ts, kind, lane_id, data) VALUES (?, ?, ?, ?, ?, ?)", (event.run_id, event.seq, event.ts, event.kind.value, event.lane_id, json.dumps(event.to_dict(), default=str)), ) conn.commit() def save_snapshot(self, run_id: str, state: RunState) -> None: data = json.dumps(state.to_dict(), default=str) with sqlite3.connect(self.db_path) as conn: conn.execute( "INSERT OR REPLACE INTO snapshots (run_id, state, updated_at) VALUES (?, ?, ?)", (run_id, data, time.time()), ) conn.commit() def load_latest(self, run_id: str) -> RunState | None: with sqlite3.connect(self.db_path) as conn: row = conn.execute( "SELECT state FROM snapshots WHERE run_id = ?", (run_id,) ).fetchone() if not row: return None try: return RunState.from_dict(json.loads(row[0])) except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Corrupt snapshot in SQLite for {run_id}: {e}") return None def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]: with sqlite3.connect(self.db_path) as conn: rows = conn.execute( "SELECT data FROM events WHERE run_id = ? AND seq > ? ORDER BY seq", (run_id, since_seq), ).fetchall() events = [] for row in rows: try: events.append(PAEvent.from_dict(json.loads(row[0]))) except (json.JSONDecodeError, KeyError): continue return events def delete_run(self, run_id: str) -> None: """Remove all data for a run (cleanup).""" with sqlite3.connect(self.db_path) as conn: conn.execute("DELETE FROM events WHERE run_id = ?", (run_id,)) conn.execute("DELETE FROM snapshots WHERE run_id = ?", (run_id,)) conn.commit() def list_runs(self) -> list[str]: """List all run IDs with snapshots.""" with sqlite3.connect(self.db_path) as conn: rows = conn.execute("SELECT run_id FROM snapshots ORDER BY updated_at DESC").fetchall() return [r[0] for r in rows]