Rohan03 committed
Commit 28b87a7 · verified · 1 Parent(s): ccbf192

V2 merge: purpose_agent/self_taught.py

Files changed (1)
  1. purpose_agent/self_taught.py +214 -0
purpose_agent/self_taught.py ADDED
@@ -0,0 +1,214 @@
+ """
+ self_taught.py — Synthetic training data for the Purpose Function.
+ 
+ From Self-Taught Evaluators (arXiv:2408.02666):
+ Generate synthetic preference pairs (good vs bad state evaluations)
+ from the agent's own traces, then use them to improve the Purpose
+ Function's prompts without any human labels.
+ 
+ Adaptation for Purpose Agent (no weight updates):
+ 1. Take a completed trace with Φ scores
+ 2. For each step, generate a "worse" evaluation (modified instruction trick)
+ 3. The correct evaluation becomes a positive example
+ 4. The worse evaluation becomes a negative example
+ 5. Store both as critic_calibration memories
+ 6. The Purpose Function improves via in-context learning from these examples
+ 
+ This is an automatic curriculum: as the Purpose Function improves,
+ it generates harder training pairs, which further improve it.
+ """
+ from __future__ import annotations
+ 
+ import json
+ import logging
+ from typing import Any
+ 
+ from purpose_agent.llm_backend import LLMBackend, ChatMessage
+ from purpose_agent.trace import Trace
+ from purpose_agent.memory import MemoryCard, MemoryKind, MemoryStatus
+ from purpose_agent.v2_types import MemoryScope
+ from purpose_agent.memory_ci import MemoryCI
+ 
+ logger = logging.getLogger(__name__)
+ 
+ GENERATE_CONTRAST_PROMPT = """\
+ You are generating training data for a state evaluator (critic).
+ 
+ Given this CORRECT evaluation of a state transition:
+ State before: {state_before}
+ Action: {action}
+ State after: {state_after}
+ Purpose: {purpose}
+ Correct Φ_before: {phi_before:.1f}
+ Correct Φ_after: {phi_after:.1f}
+ Correct reasoning: {reasoning}
+ 
+ Generate a PLAUSIBLE BUT WRONG evaluation that makes a common mistake.
+ Common mistakes:
+ - Giving credit for intentions rather than actual state changes
+ - Inflating scores to be encouraging (sycophancy)
+ - Ignoring evidence and scoring based on the action name alone
+ - Being inconsistent with the scoring scale
+ 
+ Respond with JSON:
+ {{
+   "wrong_phi_after": <a plausible but incorrect score>,
+   "wrong_reasoning": "<plausible but flawed reasoning>",
+   "mistake_type": "<which common mistake this represents>"
+ }}
+ """
+ 
+ CONTRAST_SCHEMA = {
+     "type": "object",
+     "properties": {
+         "wrong_phi_after": {"type": "number"},
+         "wrong_reasoning": {"type": "string"},
+         "mistake_type": {"type": "string"},
+     },
+     "required": ["wrong_phi_after", "wrong_reasoning", "mistake_type"],
+ }
+ 
+ 
+ class SelfTaughtEvaluator:
+     """
+     Generates synthetic training data for the Purpose Function from traces.
+ 
+     Usage:
+         ste = SelfTaughtEvaluator(llm=model, memory_ci=ci)
+ 
+         # After a trace is complete:
+         pairs = ste.generate_from_trace(trace)
+         # → Creates critic_calibration memories with good/bad examples
+ 
+         # Iterative: as the critic improves, it generates harder pairs
+         for iteration in range(3):
+             for trace in recent_traces:
+                 ste.generate_from_trace(trace)
+     """
+ 
+     def __init__(
+         self,
+         llm: LLMBackend,
+         memory_ci: MemoryCI,
+         min_delta_for_training: float = 0.5,
+     ):
+         self.llm = llm
+         self.memory_ci = memory_ci
+         self.min_delta = min_delta_for_training
+         self._pairs_generated = 0
+ 
+     def generate_from_trace(self, trace: Trace) -> int:
+         """
+         Generate contrast pairs from a trace's score events.
+ 
+         Returns the number of pairs generated.
+         """
+         count = 0
+         score_events = [e for e in trace.events if e.kind == "score"]
+ 
+         for event in score_events:
+             data = event.data
+             phi_before = data.get("phi_before", 0)
+             phi_after = data.get("phi_after", 0)
+             delta = phi_after - phi_before
+ 
+             # Only generate pairs for meaningful transitions
+             if abs(delta) < self.min_delta:
+                 continue
+ 
+             try:
+                 pair = self._generate_contrast_pair(
+                     state_before=data.get("state_before", ""),
+                     action=data.get("action_name", ""),
+                     state_after=data.get("state_after", ""),
+                     purpose=trace.purpose,
+                     phi_before=phi_before,
+                     phi_after=phi_after,
+                     reasoning=data.get("reasoning", ""),
+                 )
+                 if pair:
+                     self._store_pair(pair, trace.trace_id)
+                     count += 1
+             except Exception as e:
+                 logger.warning(f"SelfTaught: Failed to generate pair: {e}")
+ 
+         self._pairs_generated += count
+         logger.info(f"SelfTaught: Generated {count} contrast pairs from trace {trace.trace_id}")
+         return count
+ 
+     def _generate_contrast_pair(
+         self,
+         state_before: str,
+         action: str,
+         state_after: str,
+         purpose: str,
+         phi_before: float,
+         phi_after: float,
+         reasoning: str,
+     ) -> dict[str, Any] | None:
+         """Generate a single (correct, wrong) evaluation pair."""
+         messages = [
+             ChatMessage(role="system", content="Generate a plausible but incorrect evaluation for training."),
+             ChatMessage(role="user", content=GENERATE_CONTRAST_PROMPT.format(
+                 # Truncate long fields to keep the prompt compact
+                 state_before=state_before[:200],
+                 action=action,
+                 state_after=state_after[:200],
+                 purpose=purpose,
+                 phi_before=phi_before,
+                 phi_after=phi_after,
+                 reasoning=reasoning[:200],
+             )),
+         ]
+ 
+         # Prefer structured output; fall back to free-form generation
+         # plus a JSON parse if the backend doesn't support it.
+         try:
+             result = self.llm.generate_structured(messages, schema=CONTRAST_SCHEMA)
+         except Exception:
+             raw = self.llm.generate(messages, temperature=0.7)
+             try:
+                 result = json.loads(raw)
+             except Exception:
+                 return None
+ 
+         return {
+             "correct_phi_after": phi_after,
+             "correct_reasoning": reasoning,
+             "wrong_phi_after": result.get("wrong_phi_after", phi_after + 2),
+             "wrong_reasoning": result.get("wrong_reasoning", ""),
+             "mistake_type": result.get("mistake_type", "unknown"),
+         }
+ 
+     def _store_pair(self, pair: dict, trace_id: str) -> None:
+         """Store a contrast pair as calibration memories."""
+         # Positive example
+         self.memory_ci.submit(MemoryCard(
+             kind=MemoryKind.CRITIC_CALIBRATION,
+             content=(
+                 f"CORRECT scoring example: Φ_after={pair['correct_phi_after']:.1f}. "
+                 f"Reasoning: {pair['correct_reasoning'][:200]}"
+             ),
+             pattern="When evaluating state transitions",
+             strategy=f"Score like this: {pair['correct_reasoning'][:150]}",
+             trust_score=0.7,
+             source_trace_id=trace_id,
+             created_by="self_taught",
+         ))
+ 
+         # Negative example (what NOT to do), scoped to the critic role
+         self.memory_ci.submit(MemoryCard(
+             kind=MemoryKind.FAILURE_PATTERN,
+             content=(
+                 f"WRONG scoring example ({pair['mistake_type']}): "
+                 f"Incorrectly scored Φ_after={pair['wrong_phi_after']:.1f}. "
+                 f"Flawed reasoning: {pair['wrong_reasoning'][:200]}"
+             ),
+             pattern="When evaluating state transitions",
+             strategy=f"AVOID this mistake: {pair['mistake_type']}",
+             trust_score=0.7,
+             source_trace_id=trace_id,
+             created_by="self_taught",
+             scope=MemoryScope(agent_roles=["critic"]),
+         ))
+ 
+     @property
+     def pairs_generated(self) -> int:
+         return self._pairs_generated