File size: 5,567 Bytes
d1da4ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
state_delta.py — Markovian State-Delta Evaluation.

PROBLEM: Passing full conversation history to the Critic costs O(N²) tokens
as the trajectory grows. This makes long tasks unaffordable.

SOLUTION: Recognize this IS a Markov Decision Process. The Critic only needs:
  - The delta between state(t) and state(t+1)
  - The action taken
  - The purpose (constant)

This reduces the Critic's input to ~300 tokens regardless of trajectory length.

Mathematical basis:
  In an MDP, the value of state s' depends only on s' itself, not on the path
  taken to reach it. Therefore Φ(s') can be computed from s' alone.
  
  We compute: diff = state_diff(s_t, s_{t+1})
  Critic sees: (purpose, action, diff) → Φ score
  Token cost: O(1) per step, not O(N)
"""
from __future__ import annotations

import difflib
import json
from dataclasses import dataclass
from typing import Any

from purpose_agent.types import State


@dataclass
class StateDelta:
    """
    Minimal difference between two consecutive states.

    This delta — not the full conversation history — is everything the
    Critic consumes, which keeps the token cost roughly constant
    (~100-300 tokens) no matter how long the trajectory grows.
    """
    added_keys: dict[str, Any]      # New keys in s_{t+1} that weren't in s_t
    removed_keys: list[str]          # Keys in s_t that are gone in s_{t+1}
    changed_keys: dict[str, tuple[Any, Any]]  # key → (old_value, new_value)
    summary_diff: str                # Human-readable diff (what the Critic sees)
    token_estimate: int = 0          # Estimated tokens for this delta

    @property
    def change_magnitude(self) -> int:
        """Total count of touched keys — handy for detecting no-ops."""
        return sum(map(len, (self.added_keys, self.removed_keys, self.changed_keys)))

    @property
    def is_empty(self) -> bool:
        """True when no observable state change occurred."""
        return self.change_magnitude == 0


def compute_state_delta(s_before: State, s_after: State) -> StateDelta:
    """
    Compute the minimal diff between two states.

    This is the core O(1) optimization: instead of passing both full states
    to the Critic (which grows with trajectory), we pass only what changed.

    Keys beginning with "_" are treated as internal bookkeeping and ignored.
    When the structured diff finds nothing but the free-text summaries
    differ, a text diff of the summaries is used instead.

    Args:
        s_before: State at time t.
        s_after: State at time t+1.

    Returns:
        StateDelta describing additions, removals, and value changes, plus
        a human-readable summary and a rough token estimate (~4 chars/token).
    """
    d_before = s_before.data
    d_after = s_after.data

    added: dict[str, Any] = {}
    removed: list[str] = []
    changed: dict[str, tuple[Any, Any]] = {}

    # Additions and value changes.
    for key, val in d_after.items():
        if key.startswith("_"):  # Skip internal keys
            continue
        if key not in d_before:
            added[key] = val
        elif d_before[key] != val:
            changed[key] = (d_before[key], val)

    # Removals.
    for key in d_before:
        if not key.startswith("_") and key not in d_after:
            removed.append(key)

    # Build human-readable summary (what the Critic actually sees).
    lines = []
    for k, v in added.items():
        lines.append(f"+ {k} = {_truncate(v)}")
    for k in removed:
        lines.append(f"- {k} (removed)")
    for k, (old, new) in changed.items():
        # BUG FIX: old and new values previously ran together with no
        # separator; add an explicit arrow so the Critic can tell them apart.
        lines.append(f"~ {k}: {_truncate(old)} → {_truncate(new)}")

    # Fall back to a text diff of the summaries when available.
    if s_before.summary and s_after.summary and s_before.summary != s_after.summary:
        text_diff = _text_diff(s_before.summary, s_after.summary)
        if text_diff and not lines:
            lines.append(text_diff)

    summary = "\n".join(lines) if lines else "(no observable change)"
    token_est = len(summary) // 4  # heuristic: ~4 chars per token

    return StateDelta(
        added_keys=added,
        removed_keys=removed,
        changed_keys=changed,
        summary_diff=summary,
        token_estimate=token_est,
    )


def format_critic_input(
    purpose: str,
    action_name: str,
    action_thought: str,
    delta: StateDelta,
    max_tokens: int = 300,
) -> str:
    """
    Format the minimal input for the Markovian Critic.

    Total budget: ~300 tokens. This is ALL the Critic gets — no history,
    no full state. Just purpose + action + what changed. This works because
    in an MDP, Φ(s') depends only on s', not on the path taken to reach it.

    Args:
        purpose: The agent's goal (clipped to 100 chars).
        action_name: Name of the action just taken.
        action_thought: The agent's reasoning (clipped to 80 chars).
        delta: The state change to evaluate (summary clipped to 400 chars).
        max_tokens: Hard cap on the output, assuming ~4 chars per token.

    Returns:
        The formatted prompt text, truncated to the token budget.
    """
    text = "\n".join(
        [
            f"PURPOSE: {purpose[:100]}",
            f"ACTION: {action_name}",
            f"THOUGHT: {action_thought[:80]}",
            f"STATE CHANGE:\n{delta.summary_diff[:400]}",
        ]
    )

    # Enforce the hard character budget.
    char_budget = max_tokens * 4
    if len(text) > char_budget:
        text = text[:char_budget] + "\n...(truncated)"

    return text


def _truncate(val: Any, max_len: int = 80) -> str:
    """Truncate a value for display."""
    s = str(val)
    return s[:max_len] + "..." if len(s) > max_len else s


def _text_diff(before: str, after: str) -> str:
    """Compute a concise text diff between two summaries."""
    if before == after:
        return ""

    # Use unified diff for short texts
    before_lines = before.split("\n")
    after_lines = after.split("\n")

    diff_lines = []
    for line in difflib.unified_diff(before_lines, after_lines, lineterm="", n=0):
        if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
            continue
        if line.startswith("+"):
            diff_lines.append(f"+ {line[1:]}")
        elif line.startswith("-"):
            diff_lines.append(f"- {line[1:]}")

    return "\n".join(diff_lines[:5])  # Cap at 5 diff lines