# purpose-agent / tests/test_sprint2_checkpoint.py
# Last change by Rohan03 (commit 36d2671): "v3.0.0 Production Release: Hardened
# framework, strict tool validation, test suite robustification"
#!/usr/bin/env python3
"""
Sprint 2 Tests — Durable Execution.

Scenarios covered:
  T2.1 Interrupt after node 2 in 5-node flow; resume at node 3
  T2.2 Crash during tool; idempotent tool doesn't run twice on resume
  T2.3 HITL checkpoint pauses and resumes with modified state
  T2.4 SQLite checkpointer survives (simulated) process restart
  T2.5 JSONL event log reconstructs run transcript
  T2.x RunState to_dict/from_dict roundtrip

Every check prints one pass/fail line, and a summary is printed at the end.
When run as a script, the exit status is 0 only if every check passed.
"""
import sys
import os
import json
import tempfile
import time
# Put this file's parent directory (presumably the repo root) on the import
# path so `purpose_agent` resolves when the script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

# Running tallies updated by check() and read by the final summary.
PASS = 0
FAIL = 0


def check(name, condition, detail=""):
    """Record and print the outcome of one named check.

    Args:
        name: Label printed next to the pass/fail marker.
        condition: Truthy for a pass, falsy for a failure.
        detail: Optional context appended to the line, on failure only.

    Returns:
        ``bool(condition)``. Callers may use it to skip checks that depend
        on this one.
    """
    global PASS, FAIL
    passed = bool(condition)  # evaluate truthiness exactly once
    if passed:
        PASS += 1
        print(f" ✓ {name}")
    else:
        FAIL += 1
        print(f" ✗ {name}" + (f": {detail}" if detail else ""))
    return passed
from purpose_agent.runtime.events import PAEvent, EventKind, create_event
from purpose_agent.runtime.state import RunState, RunStatus, NodeState
from purpose_agent.runtime.checkpoint import (
InMemoryCheckpointer, JSONLCheckpointer, SQLiteCheckpointer,
)
# ═══ T2.1: Interrupt and resume ═══
print("T2.1: Interrupt after node 2, resume at node 3")
state = RunState(run_id="flow1", purpose="5 node flow", max_steps=5)
# Shared in-memory checkpointer; later sections of this script reuse it.
checkpointer = InMemoryCheckpointer()

# Simulate nodes 1-2 completing, logging an AGENT_FINISHED event for each.
for i in range(1, 3):
    state.mark_node_started(f"node_{i}")
    state.mark_node_completed(f"node_{i}", output={"result": f"done_{i}"})
    state.current_step = i
    checkpointer.save_event(
        create_event("flow1", EventKind.AGENT_FINISHED, seq=i, node=f"node_{i}")
    )

# Save checkpoint (simulating interrupt)
checkpointer.save_snapshot("flow1", state)

# "Restart" — load from checkpoint
restored = checkpointer.load_latest("flow1")
check("State restored", restored is not None)
# The remaining checks dereference `restored`. If nothing was restored, that
# failure is already counted above. Skip them rather than letting an
# AttributeError abort every later section of the script.
if restored is not None:
    check("Correct step (2)", restored.current_step == 2,
          f"got {restored.current_step!r}")
    check("2 nodes completed", len(restored.completed_nodes) == 2,
          f"got {restored.completed_nodes!r}")
    # "Not started" means neither finished nor currently in progress.
    check("Node 3 not started",
          "node_3" not in restored.completed_nodes
          and restored.current_node != "node_3",
          f"current_node={restored.current_node!r}")
    # Resume from node 3
    restored.mark_node_started("node_3")
    check("Resume at node 3", restored.current_node == "node_3",
          f"got {restored.current_node!r}")
# ═══ T2.2: Idempotent tool calls ═══
print("\nT2.2: Idempotent tool — no double execution")
state2 = RunState(run_id="tool_test")
state2.current_step = 1

# First execution — tool runs
key = state2.get_idempotency_key("calculator", "2+2")
check("No cached result initially", not state2.has_cached_result(key))

# Tool executes, result cached
state2.cache_result(key, "4")
check("Result cached", state2.has_cached_result(key))
check("Cached value correct", state2.get_cached_result(key) == "4")

# On "resume" — same key, should use cached result
key2 = state2.get_idempotency_key("calculator", "2+2")
check("Same key on resume", key == key2)
check("Uses cached (idempotent)", state2.has_cached_result(key2))

# A real resume starts from a checkpoint, not from the live object. Verify that
# both the idempotency key and the cached result survive a snapshot round
# trip, so the tool would not run a second time after a crash.
tool_checkpointer = InMemoryCheckpointer()
tool_checkpointer.save_snapshot("tool_test", state2)
resumed2 = tool_checkpointer.load_latest("tool_test")
check("Snapshot restored for resume", resumed2 is not None)
if resumed2 is not None:
    key3 = resumed2.get_idempotency_key("calculator", "2+2")
    check("Same key after checkpoint resume", key3 == key,
          f"{key3!r} != {key!r}")
    check("Cached result survives checkpoint",
          resumed2.has_cached_result(key3)
          and resumed2.get_cached_result(key3) == "4")
# ═══ T2.3: HITL checkpoint pause/resume ═══
print("\nT2.3: HITL pause and resume with modified state")
state3 = RunState(run_id="hitl_test", status=RunStatus.RUNNING)
state3.data = {"task": "review code", "approval": None}
checkpointer.save_snapshot("hitl_test", state3)

# Simulate pause for human
state3.status = RunStatus.PAUSED
checkpointer.save_snapshot("hitl_test", state3)

# "Human" modifies state
loaded = checkpointer.load_latest("hitl_test")
check("Loaded paused state",
      loaded is not None and loaded.status == RunStatus.PAUSED,
      f"got {getattr(loaded, 'status', None)!r}")
# Without a loaded snapshot there is nothing to modify. That failure is
# already counted, so skip the rest instead of crashing the script.
if loaded is not None:
    loaded.data["approval"] = "approved"
    loaded.status = RunStatus.RUNNING
    checkpointer.save_snapshot("hitl_test", loaded)

    # Resume
    final = checkpointer.load_latest("hitl_test")
    # .get(): a dropped key should fail this check, not raise KeyError.
    check("Human modification preserved",
          final is not None and final.data.get("approval") == "approved")
    check("Status back to running",
          final is not None and final.status == RunStatus.RUNNING)
# ═══ T2.4: SQLite survives "restart" ═══
print("\nT2.4: SQLite durability")


def _release_checkpointer(cp):
    """Close *cp* if it exposes a ``close()`` method.

    Whether SQLiteCheckpointer keeps an open connection, or offers close(),
    is not visible from this test, hence the getattr probe. Releasing the
    handle matters on Windows, where an open SQLite file cannot be unlinked.
    """
    close = getattr(cp, "close", None)
    if callable(close):
        close()


with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
    db_path = f.name

# Bind both names up front so cleanup never hits an unbound name, even if a
# constructor raises.
cp1 = cp2 = None
try:
    # First "process"
    cp1 = SQLiteCheckpointer(db_path)
    state4 = RunState(run_id="sqlite_test", purpose="durable", current_step=7)
    state4.data = {"progress": "midway"}
    cp1.save_snapshot("sqlite_test", state4)
    for i in range(5):
        cp1.save_event(create_event("sqlite_test", EventKind.AGENT_PROGRESS, seq=i + 1))

    # "Process dies": release the first handle so the second checkpointer
    # must read back from the database file rather than share it.
    _release_checkpointer(cp1)
    cp1 = None

    # Second "process" — new connection
    cp2 = SQLiteCheckpointer(db_path)
    restored4 = cp2.load_latest("sqlite_test")
    events4 = cp2.list_events("sqlite_test")
    check("SQLite snapshot survives", restored4 is not None)
    if restored4 is not None:
        check("SQLite state correct", restored4.current_step == 7,
              f"got {restored4.current_step!r}")
        check("SQLite data correct", restored4.data.get("progress") == "midway")
    check("SQLite events survive", len(events4) == 5, f"got {len(events4)}")
    check("SQLite lists runs", "sqlite_test" in cp2.list_runs())
finally:
    for handle in (cp1, cp2):
        if handle is None:
            continue
        try:
            _release_checkpointer(handle)
        except Exception as exc:  # cleanup must not mask a real test failure
            print(f" ! could not close checkpointer: {exc!r}")
    cp1 = cp2 = None
    try:
        os.unlink(db_path)
    except OSError:
        # Best effort: a lingering handle (e.g. on Windows) may keep the file
        # locked, and a stray temp file is preferable to failing the run.
        pass
# ═══ T2.5: JSONL event log reconstruction ═══
print("\nT2.5: JSONL event log → transcript reconstruction")
with tempfile.TemporaryDirectory() as tmpdir:
    cp = JSONLCheckpointer(tmpdir)

    # Write events
    for i in range(10):
        cp.save_event(create_event("jsonl_test", EventKind.AGENT_PROGRESS, seq=i + 1,
                                   message=f"step {i + 1}"))

    # Save snapshot
    s = RunState(run_id="jsonl_test", current_step=10)
    cp.save_snapshot("jsonl_test", s)

    # Reconstruct
    events = cp.list_events("jsonl_test")
    check("JSONL has all events", len(events) == 10, f"got {len(events)}")
    # Compare neighbours pairwise over whatever came back. A short log then
    # fails a check instead of raising IndexError, and extra events are
    # still covered.
    check("Events ordered",
          all(prev.seq <= nxt.seq for prev, nxt in zip(events, events[1:])))
    check("Payloads preserved",
          bool(events) and events[0].payload.get("message") == "step 1")
    check("Snapshot loads", cp.load_latest("jsonl_test") is not None)

    # Partial replay (since_seq)
    partial = cp.list_events("jsonl_test", since_seq=5)
    check("Partial replay works", len(partial) == 5, f"got {len(partial)}")
    check("Partial starts at seq 6", bool(partial) and partial[0].seq == 6)
# ═══ T2.x: RunState serialization roundtrip ═══
print("\nT2.x: RunState to_dict/from_dict roundtrip")
original = RunState(
    run_id="rt_test",
    session_id="sess1",
    status=RunStatus.PAUSED,
    purpose="test roundtrip",
    current_node="node_2",
    current_step=3,
    data={"key": "value"},
    completed_nodes=["node_1"],
)
original.mark_node_started("node_2")
original.cache_result("tool:calc:2+2", "4")

# Serialize to a plain dict, then rebuild a fresh RunState from it.
d = original.to_dict()
restored_rt = RunState.from_dict(d)

# Each expectation is a thunk, evaluated immediately before its check(), so
# the evaluation and output order match one check() call per line.
roundtrip_expectations = [
    ("run_id roundtrip", lambda: restored_rt.run_id == "rt_test"),
    ("status roundtrip", lambda: restored_rt.status == RunStatus.PAUSED),
    ("current_node roundtrip", lambda: restored_rt.current_node == "node_2"),
    ("data roundtrip", lambda: restored_rt.data == {"key": "value"}),
    ("tool_cache roundtrip", lambda: restored_rt.has_cached_result("tool:calc:2+2")),
    ("completed_nodes roundtrip", lambda: "node_1" in restored_rt.completed_nodes),
]
for label, expectation in roundtrip_expectations:
    check(label, expectation())
# ═══ REPORT ═══
print(f"\n{'=' * 50}")
print(f" Sprint 2 Tests: {PASS} pass, {FAIL} fail")
print(f" {'ALL PASS ✓' if FAIL == 0 else f'{FAIL} FAILURES'}")
print(f"{'=' * 50}")

if __name__ == "__main__":
    # Script mode: the exit status is the verdict (0 = every check passed).
    sys.exit(0 if FAIL == 0 else 1)
elif FAIL:
    # Imported instead of executed. As tests/test_*.py this module may be
    # collected by pytest, which reports a module-level SystemExit as a
    # collection error even for exit code 0. Only failures are signalled,
    # by raising.
    raise AssertionError(f"Sprint 2 Tests: {FAIL} failure(s)")