Rohan03 commited on
Commit
d1da4ac
·
verified ·
1 Parent(s): fdbceb4

first-principles: state_delta.py — O(1) Markovian critic (no history needed)

Browse files
Files changed (1) hide show
  1. purpose_agent/state_delta.py +176 -0
purpose_agent/state_delta.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ state_delta.py — Markovian State-Delta Evaluation.
3
+
4
+ PROBLEM: Passing full conversation history to the Critic costs O(N²) tokens
5
+ as the trajectory grows. This makes long tasks unaffordable.
6
+
7
+ SOLUTION: Recognize this IS a Markov Decision Process. The Critic only needs:
8
+ - The delta between state(t) and state(t+1)
9
+ - The action taken
10
+ - The purpose (constant)
11
+
12
+ This reduces the Critic's input to ~300 tokens regardless of trajectory length.
13
+
14
+ Mathematical basis:
15
+ In an MDP, the value of state s' depends only on s' itself, not on the path
16
+ taken to reach it. Therefore Φ(s') can be computed from s' alone.
17
+
18
+ We compute: diff = state_diff(s_t, s_{t+1})
19
+ Critic sees: (purpose, action, diff) → Φ score
20
+ Token cost: O(1) per step, not O(N)
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import difflib
25
+ import json
26
+ from dataclasses import dataclass
27
+ from typing import Any
28
+
29
+ from purpose_agent.types import State
30
+
31
+
32
+ @dataclass
33
+ class StateDelta:
34
+ """
35
+ The computed difference between two states.
36
+
37
+ This is ALL the Critic needs to see — not the full history.
38
+ Keeps token cost constant (~100-300 tokens) regardless of trajectory length.
39
+ """
40
+ added_keys: dict[str, Any] # New keys in s_{t+1} that weren't in s_t
41
+ removed_keys: list[str] # Keys in s_t that are gone in s_{t+1}
42
+ changed_keys: dict[str, tuple[Any, Any]] # key → (old_value, new_value)
43
+ summary_diff: str # Human-readable diff (what the Critic sees)
44
+ token_estimate: int = 0 # Estimated tokens for this delta
45
+
46
+ @property
47
+ def is_empty(self) -> bool:
48
+ """No state change occurred."""
49
+ return not self.added_keys and not self.removed_keys and not self.changed_keys
50
+
51
+ @property
52
+ def change_magnitude(self) -> int:
53
+ """How much changed — useful for detecting no-ops."""
54
+ return len(self.added_keys) + len(self.removed_keys) + len(self.changed_keys)
55
+
56
+
57
+ def compute_state_delta(s_before: State, s_after: State) -> StateDelta:
58
+ """
59
+ Compute the minimal diff between two states.
60
+
61
+ This is the core O(1) optimization: instead of passing both full states
62
+ to the Critic (which grows with trajectory), we pass only what changed.
63
+
64
+ Token cost: ~100-300 tokens regardless of state size or trajectory length.
65
+ """
66
+ d_before = s_before.data
67
+ d_after = s_after.data
68
+
69
+ added = {}
70
+ removed = []
71
+ changed = {}
72
+
73
+ # Find additions and changes
74
+ for key, val in d_after.items():
75
+ if key.startswith("_"): # Skip internal keys
76
+ continue
77
+ if key not in d_before:
78
+ added[key] = val
79
+ elif d_before[key] != val:
80
+ changed[key] = (d_before[key], val)
81
+
82
+ # Find removals
83
+ for key in d_before:
84
+ if key.startswith("_"):
85
+ continue
86
+ if key not in d_after:
87
+ removed.append(key)
88
+
89
+ # Build human-readable summary (what the Critic actually sees)
90
+ lines = []
91
+ if added:
92
+ for k, v in added.items():
93
+ lines.append(f"+ {k} = {_truncate(v)}")
94
+ if removed:
95
+ for k in removed:
96
+ lines.append(f"- {k} (removed)")
97
+ if changed:
98
+ for k, (old, new) in changed.items():
99
+ lines.append(f"~ {k}: {_truncate(old)} → {_truncate(new)}")
100
+
101
+ # If summaries available, use text diff
102
+ if s_before.summary and s_after.summary and s_before.summary != s_after.summary:
103
+ text_diff = _text_diff(s_before.summary, s_after.summary)
104
+ if text_diff and not lines:
105
+ lines.append(text_diff)
106
+
107
+ summary = "\n".join(lines) if lines else "(no observable change)"
108
+ token_est = len(summary) // 4 # ~4 chars per token
109
+
110
+ return StateDelta(
111
+ added_keys=added,
112
+ removed_keys=removed,
113
+ changed_keys=changed,
114
+ summary_diff=summary,
115
+ token_estimate=token_est,
116
+ )
117
+
118
+
119
+ def format_critic_input(
120
+ purpose: str,
121
+ action_name: str,
122
+ action_thought: str,
123
+ delta: StateDelta,
124
+ max_tokens: int = 300,
125
+ ) -> str:
126
+ """
127
+ Format the minimal input for the Markovian Critic.
128
+
129
+ Total budget: ~300 tokens. This is ALL the Critic gets.
130
+ No history. No full state. Just: purpose + action + what changed.
131
+
132
+ This is the fundamental insight: in an MDP, Φ(s') depends only on s',
133
+ not on the path taken to reach it.
134
+ """
135
+ parts = [
136
+ f"PURPOSE: {purpose[:100]}",
137
+ f"ACTION: {action_name}",
138
+ f"THOUGHT: {action_thought[:80]}",
139
+ f"STATE CHANGE:\n{delta.summary_diff[:400]}",
140
+ ]
141
+
142
+ text = "\n".join(parts)
143
+
144
+ # Hard cap
145
+ char_budget = max_tokens * 4
146
+ if len(text) > char_budget:
147
+ text = text[:char_budget] + "\n...(truncated)"
148
+
149
+ return text
150
+
151
+
152
+ def _truncate(val: Any, max_len: int = 80) -> str:
153
+ """Truncate a value for display."""
154
+ s = str(val)
155
+ return s[:max_len] + "..." if len(s) > max_len else s
156
+
157
+
158
+ def _text_diff(before: str, after: str) -> str:
159
+ """Compute a concise text diff between two summaries."""
160
+ if before == after:
161
+ return ""
162
+
163
+ # Use unified diff for short texts
164
+ before_lines = before.split("\n")
165
+ after_lines = after.split("\n")
166
+
167
+ diff_lines = []
168
+ for line in difflib.unified_diff(before_lines, after_lines, lineterm="", n=0):
169
+ if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
170
+ continue
171
+ if line.startswith("+"):
172
+ diff_lines.append(f"+ {line[1:]}")
173
+ elif line.startswith("-"):
174
+ diff_lines.append(f"- {line[1:]}")
175
+
176
+ return "\n".join(diff_lines[:5]) # Cap at 5 diff lines