"""
self_taught.py — Synthetic training data for the Purpose Function.

From Self-Taught Evaluators (arXiv:2408.02666):
  Generate synthetic preference pairs (good vs bad state evaluations)
  from the agent's own traces, then use them to improve the Purpose
  Function's prompts without any human labels.

Adaptation for Purpose Agent (no weight updates):
  1. Take a completed trace with Φ scores
  2. For each scored step, generate a plausible-but-wrong evaluation
     (the paper's modified-instruction trick, adapted here to prompt the
     LLM directly for a flawed judgment)
  3. The correct evaluation becomes a positive example
  4. The wrong evaluation becomes a negative example
  5. Store the positive as a critic_calibration memory and the negative
     as a failure_pattern memory
  6. The Purpose Function improves via in-context learning from these examples

This is an automatic curriculum: as the Purpose Function improves,
it generates harder training pairs, which further improve it.
"""
from __future__ import annotations

import json
import logging
from typing import Any

from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.trace import Trace
from purpose_agent.memory import MemoryCard, MemoryKind
from purpose_agent.v2_types import MemoryScope
from purpose_agent.memory_ci import MemoryCI

logger = logging.getLogger(__name__)

GENERATE_CONTRAST_PROMPT = """\
You are generating training data for a state evaluator (critic).

Given this CORRECT evaluation of a state transition:
  State before: {state_before}
  Action: {action}
  State after: {state_after}
  Purpose: {purpose}
  Correct Φ_before: {phi_before:.1f}
  Correct Φ_after: {phi_after:.1f}
  Correct reasoning: {reasoning}

Generate a PLAUSIBLE BUT WRONG evaluation that makes a common mistake.
Common mistakes:
- Giving credit for intentions rather than actual state changes
- Inflating scores to be encouraging (sycophancy)
- Ignoring evidence and scoring based on action name alone
- Being inconsistent with the scoring scale

Respond with JSON:
{{
  "wrong_phi_after": <a plausible but incorrect score>,
  "wrong_reasoning": "<plausible but flawed reasoning>",
  "mistake_type": "<which common mistake this represents>"
}}
"""

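# Passed to LLMBackend.generate_structured() below; backends with
# constrained decoding can use it to guarantee well-formed JSON.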
CONTRAST_SCHEMA = {
    "type": "object",
    "properties": {
        "wrong_phi_after": {"type": "number"},
        "wrong_reasoning": {"type": "string"},
        "mistake_type": {"type": "string"},
    },
    "required": ["wrong_phi_after", "wrong_reasoning", "mistake_type"],
}


class SelfTaughtEvaluator:
    """
    Generates synthetic training data for the Purpose Function from traces.

    Usage:
        ste = SelfTaughtEvaluator(llm=model, memory_ci=ci)

        # After a trace is complete:
        pairs = ste.generate_from_trace(trace)
        # → Creates critic_calibration memories with good/bad examples

        # Iterative: as the critic improves, it generates harder pairs
        for iteration in range(3):
            for trace in recent_traces:
                ste.generate_from_trace(trace)
    """

    def __init__(
        self,
        llm: LLMBackend,
        memory_ci: MemoryCI,
        min_delta_for_training: float = 0.5,
    ):
        self.llm = llm
        self.memory_ci = memory_ci
        self.min_delta = min_delta_for_training  # skip transitions with |ΔΦ| below this
        self._pairs_generated = 0

    def generate_from_trace(self, trace: Trace) -> int:
        """
        Generate contrast pairs from a trace's score events.

        Returns number of pairs generated.
        """
        count = 0
        score_events = [e for e in trace.events if e.kind == "score"]

        for event in score_events:
            data = event.data
            phi_before = data.get("phi_before", 0)
            phi_after = data.get("phi_after", 0)
            delta = phi_after - phi_before

            # Only generate pairs for meaningful transitions; a near-zero
            # ΔΦ leaves nothing to contrast against.
            if abs(delta) < self.min_delta:
                continue

            try:
                pair = self._generate_contrast_pair(
                    state_before=data.get("state_before", ""),
                    action=data.get("action_name", ""),
                    state_after=data.get("state_after", ""),
                    purpose=trace.purpose,
                    phi_before=phi_before,
                    phi_after=phi_after,
                    reasoning=data.get("reasoning", ""),
                )
                if pair:
                    self._store_pair(pair, trace.trace_id)
                    count += 1
            except Exception as e:
                logger.warning(f"SelfTaught: Failed to generate pair: {e}")

        self._pairs_generated += count
        logger.info(f"SelfTaught: Generated {count} contrast pairs from trace {trace.trace_id}")
        return count

    def _generate_contrast_pair(
        self,
        state_before: str,
        action: str,
        state_after: str,
        purpose: str,
        phi_before: float,
        phi_after: float,
        reasoning: str,
    ) -> dict[str, Any] | None:
        """Generate a single (correct, wrong) evaluation pair."""
        messages = [
            ChatMessage(role="system", content="Generate a plausible but incorrect evaluation for training."),
            ChatMessage(role="user", content=GENERATE_CONTRAST_PROMPT.format(
                state_before=state_before[:200],
                action=action,
                state_after=state_after[:200],
                purpose=purpose,
                phi_before=phi_before,
                phi_after=phi_after,
                reasoning=reasoning[:200],
            )),
        ]

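        # Prefer structured decoding when the backend supports it; otherwise
        # fall back to free-form sampling and attempt to parse the raw output
        # as JSON. If neither path yields a dict, skip this transition.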
        try:
            result = self.llm.generate_structured(messages, schema=CONTRAST_SCHEMA)
        except Exception:
            raw = self.llm.generate(messages, temperature=0.7)
            try:
                result = json.loads(raw)
            except Exception:
                return None

        return {
            "correct_phi_after": phi_after,
            "correct_reasoning": reasoning,
            "wrong_phi_after": result.get("wrong_phi_after", phi_after + 2),
            "wrong_reasoning": result.get("wrong_reasoning", ""),
            "mistake_type": result.get("mistake_type", "unknown"),
        }

    def _store_pair(self, pair: dict, trace_id: str) -> None:
        """Store a contrast pair as calibration memories."""
        # Positive example
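        # Stored as CRITIC_CALIBRATION so it is retrieved as a worked
        # example of correct scoring.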
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.CRITIC_CALIBRATION,
            content=(
                f"CORRECT scoring example: Φ_after={pair['correct_phi_after']:.1f}. "
                f"Reasoning: {pair['correct_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"Score like this: {pair['correct_reasoning'][:150]}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
            scope=MemoryScope(agent_roles=["critic"]),
        ))

        # Negative example (what NOT to do)
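        # Stored as FAILURE_PATTERN, scoped to the critic role, so retrieval
        # surfaces it as an anti-pattern rather than a template to imitate.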
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.FAILURE_PATTERN,
            content=(
                f"WRONG scoring example ({pair['mistake_type']}): "
                f"Incorrectly scored Φ_after={pair['wrong_phi_after']:.1f}. "
                f"Flawed reasoning: {pair['wrong_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"AVOID this mistake: {pair['mistake_type']}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
            scope=MemoryScope(agent_roles=["critic"]),
        ))

    @property
    def pairs_generated(self) -> int:
        return self._pairs_generated
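

if __name__ == "__main__":  # pragma: no cover
    # Ad-hoc smoke-test sketch, an assumption for illustration rather than
    # part of the module's API. The stubs fake only the methods this module
    # calls (LLMBackend.generate_structured/generate, MemoryCI.submit), and
    # Trace and its events are mimicked with SimpleNamespace.
    from types import SimpleNamespace

    class _StubLLM:
        def generate_structured(self, messages, schema):
            return {
                "wrong_phi_after": 9.0,
                "wrong_reasoning": "The action sounds productive.",
                "mistake_type": "scoring based on action name alone",
            }

        def generate(self, messages, temperature=0.0):
            return "{}"

    class _StubCI:
        def submit(self, card):
            print(f"stored {card.kind}: {card.content[:60]}")

    demo_trace = SimpleNamespace(
        trace_id="t-demo",
        purpose="keep the test suite green",
        events=[
            SimpleNamespace(kind="score", data={
                "phi_before": 3.0,
                "phi_after": 6.0,
                "state_before": "2 failing tests",
                "state_after": "0 failing tests",
                "action_name": "fix_tests",
                "reasoning": "Failures dropped from 2 to 0.",
            }),
        ],
    )

    ste = SelfTaughtEvaluator(llm=_StubLLM(), memory_ci=_StubCI())
    print("pairs generated:", ste.generate_from_trace(demo_trace))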