File size: 10,157 Bytes
b0e5b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
checkpoint.py — Durable execution via event sourcing and state snapshots.

Provides:
  - Checkpointer protocol (save events, save snapshots, load, list)
  - InMemoryCheckpointer (testing)
  - JSONLCheckpointer (file-based, portable)
  - SQLiteCheckpointer (production, concurrent-safe)

Usage:
    checkpointer = JSONLCheckpointer("./checkpoints")
    
    # During execution
    checkpointer.save_event(event)
    checkpointer.save_snapshot(run_id, state)
    
    # On crash recovery
    state = checkpointer.load_latest(run_id)
    events = checkpointer.list_events(run_id)
"""
from __future__ import annotations

import json
import logging
import os
import sqlite3
import time
from contextlib import closing
from pathlib import Path
from typing import Any, Protocol

from purpose_agent.runtime.events import PAEvent
from purpose_agent.runtime.state import RunState

logger = logging.getLogger(__name__)


class Checkpointer(Protocol):
    """
    Structural interface every checkpoint backend satisfies.

    A backend offers four operations:
      - save_event:    append one event to a run's log (event sourcing)
      - save_snapshot: store the full state of a run
      - load_latest:   fetch the most recent stored state of a run
      - list_events:   read a run's events back for replay
    """

    def save_event(self, event: PAEvent) -> None:
        """Append ``event`` to the durable event log."""

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Store ``state`` as the snapshot for ``run_id``, replacing any earlier one."""

    def load_latest(self, run_id: str) -> RunState | None:
        """Return the newest snapshot for ``run_id``, or ``None`` when there is none."""

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """Return the events of ``run_id``, optionally only those after ``since_seq``."""


# ═══════════════════════════════════════════════════════════════════
# InMemoryCheckpointer — for testing
# ═══════════════════════════════════════════════════════════════════

class InMemoryCheckpointer:
    """
    In-memory checkpointer for testing. Not durable across restarts.

    Events are kept per run in the order they were saved. Snapshots are
    stored by reference (no copy is made), one per run ID.
    """

    def __init__(self):
        # run_id -> events in insertion order
        self._events: dict[str, list[PAEvent]] = {}
        # run_id -> most recently saved state object
        self._snapshots: dict[str, RunState] = {}

    def save_event(self, event: PAEvent) -> None:
        """Append ``event`` to the log of the run named by ``event.run_id``."""
        self._events.setdefault(event.run_id, []).append(event)

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Remember ``state`` as the snapshot for ``run_id``, replacing any earlier one."""
        self._snapshots[run_id] = state

    def load_latest(self, run_id: str) -> RunState | None:
        """Return the stored snapshot for ``run_id``, or ``None`` if none was saved."""
        return self._snapshots.get(run_id)

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Return the events saved for ``run_id``.

        When ``since_seq`` is greater than 0 only events whose ``seq`` exceeds
        it are returned; otherwise every event is returned. The result is
        always a new list, so callers may mutate it without affecting the
        stored log.
        """
        events = self._events.get(run_id, [])
        if since_seq > 0:
            return [e for e in events if e.seq > since_seq]
        # Copy so the caller never gets a handle on the internal list.
        return list(events)

    @property
    def event_count(self) -> int:
        """Total number of events stored across all runs."""
        return sum(len(v) for v in self._events.values())

    @property
    def snapshot_count(self) -> int:
        """Number of runs that currently have a snapshot."""
        return len(self._snapshots)


# ═══════════════════════════════════════════════════════════════════
# JSONLCheckpointer — file-based, portable
# ═══════════════════════════════════════════════════════════════════

class JSONLCheckpointer:
    """
    File-based checkpointer using JSONL for events and JSON for snapshots.

    Directory structure:
        base_dir/
            {run_id}/
                events.jsonl
                snapshot.json
                snapshot.backup.json   (previous snapshot, kept as a fallback)

    Snapshots are written to a temporary file and renamed into place, so a
    crash during a write never leaves a half-written ``snapshot.json``.
    """

    _EVENTS_FILE = "events.jsonl"
    _SNAPSHOT_FILE = "snapshot.json"
    _BACKUP_FILE = "snapshot.backup.json"

    def __init__(self, base_dir: str):
        """Use ``base_dir`` as the storage root, creating it if necessary."""
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _run_dir(self, run_id: str, *, create: bool = True) -> Path:
        """
        Return the directory holding the files of ``run_id``.

        The directory is created when ``create`` is true (the default); read
        paths pass ``create=False`` so that looking up an unknown run does not
        leave an empty directory behind.
        """
        # NOTE(review): run_id is used verbatim as a path component; callers
        # are presumably expected to pass a filesystem-safe identifier.
        d = self.base_dir / run_id
        if create:
            d.mkdir(parents=True, exist_ok=True)
        return d

    def save_event(self, event: PAEvent) -> None:
        """Append ``event`` as one JSON line to the run's ``events.jsonl``."""
        path = self._run_dir(event.run_id) / self._EVENTS_FILE
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(event.to_dict(), default=str) + "\n")

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """
        Persist ``state`` as the snapshot of ``run_id``.

        The new snapshot is written to a temporary file first. The previous
        snapshot (if any) is then rotated to ``snapshot.backup.json`` and the
        temporary file is renamed to ``snapshot.json``.
        """
        run_dir = self._run_dir(run_id)
        path = run_dir / self._SNAPSHOT_FILE
        tmp_path = run_dir / (self._SNAPSHOT_FILE + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(state.to_dict(), f, indent=2, default=str)
            # Push the bytes to disk before the rename makes them "current".
            f.flush()
            os.fsync(f.fileno())
        if path.exists():
            os.replace(path, run_dir / self._BACKUP_FILE)
        os.replace(tmp_path, path)

    def load_latest(self, run_id: str) -> RunState | None:
        """
        Load the snapshot of ``run_id``.

        ``snapshot.json`` is tried first; if it is missing or cannot be
        decoded, ``snapshot.backup.json`` is tried. Returns ``None`` when
        neither file yields a state.
        """
        run_dir = self._run_dir(run_id, create=False)
        for name in (self._SNAPSHOT_FILE, self._BACKUP_FILE):
            path = run_dir / name
            if not path.exists():
                continue
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                return RunState.from_dict(data)
            except (json.JSONDecodeError, KeyError) as e:
                logger.warning("Corrupt snapshot for %s (%s): %s", run_id, name, e)
        return None

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Read the events of ``run_id`` whose ``seq`` is greater than ``since_seq``.

        Blank lines are ignored. Lines that cannot be decoded are skipped with
        a warning so that one damaged record does not block replay.
        """
        # NOTE(review): with the default since_seq=0 an event whose seq is 0
        # is filtered out here, unlike InMemoryCheckpointer; confirm that
        # sequence numbers start at 1.
        path = self._run_dir(run_id, create=False) / self._EVENTS_FILE
        if not path.exists():
            return []
        events = []
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    event = PAEvent.from_dict(json.loads(line))
                except (json.JSONDecodeError, KeyError) as e:
                    logger.warning(
                        "Skipping unreadable event at %s:%d: %s", path, lineno, e
                    )
                    continue
                if event.seq > since_seq:
                    events.append(event)
        return events


# ═══════════════════════════════════════════════════════════════════
# SQLiteCheckpointer — production, concurrent-safe
# ═══════════════════════════════════════════════════════════════════

class SQLiteCheckpointer:
    """
    SQLite-based checkpointer. Durable, concurrent-safe, supports multiple runs.

    Uses stdlib sqlite3 — no extra dependencies.

    Every operation opens its own connection to ``db_path`` and closes it
    before returning.
    """

    def __init__(self, db_path: str = "purpose_agent_checkpoints.db"):
        """Remember ``db_path`` and make sure the tables and index exist."""
        # NOTE(review): because each call opens a fresh connection, a
        # per-connection database such as ":memory:" presumably will not
        # work here; confirm before using one in tests.
        self.db_path = db_path
        self._init_db()

    def _connect(self):
        """
        Return a context manager yielding a new connection that is closed on exit.

        ``sqlite3.Connection`` used directly in a ``with`` statement only
        commits or rolls back; it does not close the connection, hence the
        ``closing`` wrapper.
        """
        return closing(sqlite3.connect(self.db_path))

    def _init_db(self) -> None:
        """Create the ``events`` and ``snapshots`` tables and the events index if missing."""
        # ``with ... as conn, conn`` = close on exit + commit/rollback on exit.
        with self._connect() as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS events (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id TEXT NOT NULL,
                    seq INTEGER NOT NULL,
                    ts REAL NOT NULL,
                    kind TEXT NOT NULL,
                    lane_id TEXT DEFAULT 'main',
                    data TEXT NOT NULL,
                    created_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    run_id TEXT PRIMARY KEY,
                    state TEXT NOT NULL,
                    updated_at REAL DEFAULT (strftime('%s','now'))
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_events_run ON events(run_id, seq)")

    def save_event(self, event: PAEvent) -> None:
        """Insert ``event`` into the ``events`` table; the full event is stored as JSON in ``data``."""
        with self._connect() as conn, conn:
            conn.execute(
                "INSERT INTO events (run_id, seq, ts, kind, lane_id, data) VALUES (?, ?, ?, ?, ?, ?)",
                (event.run_id, event.seq, event.ts, event.kind.value, event.lane_id,
                 json.dumps(event.to_dict(), default=str)),
            )

    def save_snapshot(self, run_id: str, state: RunState) -> None:
        """Store ``state`` as JSON for ``run_id``, replacing any existing row."""
        data = json.dumps(state.to_dict(), default=str)
        with self._connect() as conn, conn:
            conn.execute(
                "INSERT OR REPLACE INTO snapshots (run_id, state, updated_at) VALUES (?, ?, ?)",
                (run_id, data, time.time()),
            )

    def load_latest(self, run_id: str) -> RunState | None:
        """
        Load the snapshot stored for ``run_id``.

        Returns ``None`` when no row exists or when the stored JSON cannot be
        turned back into a state (a warning is logged in that case).
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT state FROM snapshots WHERE run_id = ?", (run_id,)
            ).fetchone()
        if not row:
            return None
        try:
            return RunState.from_dict(json.loads(row[0]))
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Corrupt snapshot in SQLite for {run_id}: {e}")
            return None

    def list_events(self, run_id: str, since_seq: int = 0) -> list[PAEvent]:
        """
        Return the events of ``run_id`` with ``seq`` greater than ``since_seq``, ordered by ``seq``.

        Rows whose payload cannot be decoded are skipped with a warning.
        """
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT data FROM events WHERE run_id = ? AND seq > ? ORDER BY seq",
                (run_id, since_seq),
            ).fetchall()
        events = []
        for row in rows:
            try:
                events.append(PAEvent.from_dict(json.loads(row[0])))
            except (json.JSONDecodeError, KeyError) as e:
                logger.warning("Skipping unreadable event for %s: %s", run_id, e)
                continue
        return events

    def delete_run(self, run_id: str) -> None:
        """Remove all data for a run (cleanup)."""
        # Both deletes run in one transaction: committed together or rolled back.
        with self._connect() as conn, conn:
            conn.execute("DELETE FROM events WHERE run_id = ?", (run_id,))
            conn.execute("DELETE FROM snapshots WHERE run_id = ?", (run_id,))

    def list_runs(self) -> list[str]:
        """List all run IDs with snapshots, most recently updated first."""
        with self._connect() as conn:
            rows = conn.execute("SELECT run_id FROM snapshots ORDER BY updated_at DESC").fetchall()
        return [r[0] for r in rows]