#!/usr/bin/env python3
"""
Sprint 2 Tests — Durable Execution.
T2.1 Interrupt after node 2 in 5-node flow; resume at node 3
T2.2 Crash during tool; idempotent tool doesn't run twice on resume
T2.3 HITL checkpoint pauses and resumes with modified state
T2.4 SQLite checkpointer survives (simulated) process restart
T2.5 JSONL event log reconstructs run transcript
"""
import sys
import os
import json
import tempfile
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
PASS = 0  # count of assertions that held
FAIL = 0  # count of assertions that failed


def check(name, condition, detail=""):
    """Record one test assertion: print a pass/fail line and bump the counters.

    Args:
        name: Short human-readable label for the assertion.
        condition: Truthy when the assertion holds.
        detail: Optional extra text appended to a failure line.
    """
    global PASS, FAIL
    if condition:
        PASS += 1
        # NOTE(review): the original markers were mojibake ("β"); restored to
        # the conventional check/cross glyphs.
        print(f" ✓ {name}")
    else:
        FAIL += 1
        print(f" ✗ {name}" + (f": {detail}" if detail else ""))
from purpose_agent.runtime.events import PAEvent, EventKind, create_event
from purpose_agent.runtime.state import RunState, RunStatus, NodeState
from purpose_agent.runtime.checkpoint import (
InMemoryCheckpointer, JSONLCheckpointer, SQLiteCheckpointer,
)
# ─── T2.1: Interrupt and resume ───
print("T2.1: Interrupt after node 2, resume at node 3")
state = RunState(run_id="flow1", purpose="5 node flow", max_steps=5)
checkpointer = InMemoryCheckpointer()
# Drive nodes 1 and 2 to completion, emitting one finish event per node.
for step in (1, 2):
    node_name = f"node_{step}"
    state.mark_node_started(node_name)
    state.mark_node_completed(node_name, output={"result": f"done_{step}"})
    state.current_step = step
    checkpointer.save_event(
        create_event("flow1", EventKind.AGENT_FINISHED, seq=step, node=node_name)
    )
# Interrupt: persist a snapshot, then pretend the process restarted by
# reloading everything from the checkpointer.
checkpointer.save_snapshot("flow1", state)
restored = checkpointer.load_latest("flow1")
check("State restored", restored is not None)
check("Correct step (2)", restored.current_step == 2)
check("2 nodes completed", len(restored.completed_nodes) == 2)
check("Node 3 not started", "node_3" not in restored.completed_nodes)
# Execution resumes at the first node that never completed.
restored.mark_node_started("node_3")
check("Resume at node 3", restored.current_node == "node_3")
# ─── T2.2: Idempotent tool calls ───
# NOTE(review): the dash in this banner was mojibake ("β"); restored to "—".
print("\nT2.2: Idempotent tool — no double execution")
state2 = RunState(run_id="tool_test")
state2.current_step = 1
# First execution: nothing cached yet, so the tool would actually run.
key = state2.get_idempotency_key("calculator", "2+2")
check("No cached result initially", not state2.has_cached_result(key))
# The tool runs once and its result is cached under the idempotency key.
state2.cache_result(key, "4")
check("Result cached", state2.has_cached_result(key))
check("Cached value correct", state2.get_cached_result(key) == "4")
# On resume the same (tool, args) pair must yield the same key, so the
# cached result is reused instead of re-executing the tool.
key2 = state2.get_idempotency_key("calculator", "2+2")
check("Same key on resume", key == key2)
check("Uses cached (idempotent)", state2.has_cached_result(key2))
# ─── T2.3: HITL checkpoint pause/resume ───
print("\nT2.3: HITL pause and resume with modified state")
hitl_state = RunState(run_id="hitl_test", status=RunStatus.RUNNING)
hitl_state.data = {"task": "review code", "approval": None}
checkpointer.save_snapshot("hitl_test", hitl_state)
# Pause the run for a human decision and checkpoint the paused state.
hitl_state.status = RunStatus.PAUSED
checkpointer.save_snapshot("hitl_test", hitl_state)
# The human reviewer loads the paused run, approves it, and flips it back
# to RUNNING before saving again.
reviewed = checkpointer.load_latest("hitl_test")
check("Loaded paused state", reviewed.status == RunStatus.PAUSED)
reviewed.data["approval"] = "approved"
reviewed.status = RunStatus.RUNNING
checkpointer.save_snapshot("hitl_test", reviewed)
# A fresh load must reflect the human's modification.
resumed = checkpointer.load_latest("hitl_test")
check("Human modification preserved", resumed.data["approval"] == "approved")
check("Status back to running", resumed.status == RunStatus.RUNNING)
# ─── T2.4: SQLite survives "restart" ───
print("\nT2.4: SQLite durability")
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
    db_path = f.name
# Pre-bind both checkpointer slots so cleanup below never hits a NameError.
# (The original finally block did `del cp1` after cp1 had already been
# deleted mid-test, so it always raised NameError, hidden by a bare except.)
cp1 = cp2 = None
try:
    # First "process" writes a snapshot plus five events.
    cp1 = SQLiteCheckpointer(db_path)
    state4 = RunState(run_id="sqlite_test", purpose="durable", current_step=7)
    state4.data = {"progress": "midway"}
    cp1.save_snapshot("sqlite_test", state4)
    for i in range(5):
        cp1.save_event(create_event("sqlite_test", EventKind.AGENT_PROGRESS, seq=i + 1))
    cp1 = None  # "Process dies" — drop the only reference to the first connection
    # Second "process": a brand-new connection against the same file.
    cp2 = SQLiteCheckpointer(db_path)
    restored4 = cp2.load_latest("sqlite_test")
    events4 = cp2.list_events("sqlite_test")
    check("SQLite snapshot survives", restored4 is not None)
    check("SQLite state correct", restored4.current_step == 7)
    check("SQLite data correct", restored4.data.get("progress") == "midway")
    check("SQLite events survive", len(events4) == 5)
    check("SQLite lists runs", "sqlite_test" in cp2.list_runs())
finally:
    # Release any live checkpointers before unlinking: on Windows an open
    # database file cannot be deleted.
    cp1 = cp2 = None
    try:
        os.unlink(db_path)
    except OSError:  # PermissionError (Windows lock) or file already gone
        pass
# ─── T2.5: JSONL event log reconstruction ───
# NOTE(review): the dash in this banner was mojibake ("β"); restored to "—".
print("\nT2.5: JSONL event log — transcript reconstruction")
with tempfile.TemporaryDirectory() as tmpdir:
    cp = JSONLCheckpointer(tmpdir)
    # Append ten ordered progress events to the run's log.
    for i in range(10):
        cp.save_event(create_event("jsonl_test", EventKind.AGENT_PROGRESS, seq=i + 1,
                                   message=f"step {i + 1}"))
    cp.save_snapshot("jsonl_test", RunState(run_id="jsonl_test", current_step=10))
    # Full replay: every event comes back in seq order with payload intact.
    events = cp.list_events("jsonl_test")
    check("JSONL has all events", len(events) == 10)
    check("Events ordered", all(a.seq <= b.seq for a, b in zip(events, events[1:])))
    check("Payloads preserved", events[0].payload.get("message") == "step 1")
    check("Snapshot loads", cp.load_latest("jsonl_test") is not None)
    # Partial replay: since_seq=5 must yield only events with seq > 5 (6..10).
    partial = cp.list_events("jsonl_test", since_seq=5)
    check("Partial replay works", len(partial) == 5)
    check("Partial starts at seq 6", partial[0].seq == 6)
# ─── T2.x: RunState serialization roundtrip ───
print("\nT2.x: RunState to_dict/from_dict roundtrip")
src_state = RunState(
    run_id="rt_test", session_id="sess1", status=RunStatus.PAUSED,
    purpose="test roundtrip", current_node="node_2", current_step=3,
    data={"key": "value"}, completed_nodes=["node_1"],
)
src_state.mark_node_started("node_2")
src_state.cache_result("tool:calc:2+2", "4")
# Serialize to a plain dict and rebuild; every field must survive the trip.
restored_rt = RunState.from_dict(src_state.to_dict())
check("run_id roundtrip", restored_rt.run_id == "rt_test")
check("status roundtrip", restored_rt.status == RunStatus.PAUSED)
check("current_node roundtrip", restored_rt.current_node == "node_2")
check("data roundtrip", restored_rt.data == {"key": "value"})
check("tool_cache roundtrip", restored_rt.has_cached_result("tool:calc:2+2"))
check("completed_nodes roundtrip", "node_1" in restored_rt.completed_nodes)
# ─── Summary report and exit status ───
banner = "=" * 50
print(f"\n{banner}")
print(f" Sprint 2 Tests: {PASS} pass, {FAIL} fail")
# NOTE(review): "ALL PASS β" was mojibake; restored the intended check mark.
print(f" {'ALL PASS ✓' if FAIL == 0 else f'{FAIL} FAILURES'}")
print(banner)
# Non-zero exit code signals failure to the CI runner.
sys.exit(0 if FAIL == 0 else 1)