"""
self_taught.py — Synthetic training data for the Purpose Function.
From Self-Taught Evaluators (arXiv:2408.02666):
Generate synthetic preference pairs (good vs bad state evaluations)
from the agent's own traces, then use them to improve the Purpose
Function's prompts without any human labels.
Adaptation for Purpose Agent (no weight updates):
1. Take a completed trace with Φ scores
2. For each step, generate a "worse" evaluation (modified instruction trick)
3. The correct evaluation becomes a positive example
4. The worse evaluation becomes a negative example
5. Store the pair as memories (the correct evaluation as critic_calibration,
   the wrong one as failure_pattern)
6. The Purpose Function improves via in-context learning from these examples
This is an automatic curriculum: as the Purpose Function improves,
it generates harder training pairs, which further improve it.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.trace import Trace
from purpose_agent.memory import MemoryCard, MemoryKind, MemoryStatus
from purpose_agent.v2_types import MemoryScope
from purpose_agent.memory_ci import MemoryCI
logger = logging.getLogger(__name__)
GENERATE_CONTRAST_PROMPT = """\
You are generating training data for a state evaluator (critic).
Given this CORRECT evaluation of a state transition:
State before: {state_before}
Action: {action}
State after: {state_after}
Purpose: {purpose}
Correct Φ_before: {phi_before:.1f}
Correct Φ_after: {phi_after:.1f}
Correct reasoning: {reasoning}
Generate a PLAUSIBLE BUT WRONG evaluation that makes a common mistake.
Common mistakes:
- Giving credit for intentions rather than actual state changes
- Inflating scores to be encouraging (sycophancy)
- Ignoring evidence and scoring based on action name alone
- Being inconsistent with the scoring scale
Respond with JSON:
{{
"wrong_phi_after": <a plausible but incorrect score>,
"wrong_reasoning": "<plausible but flawed reasoning>",
"mistake_type": "<which common mistake this represents>"
}}
"""
CONTRAST_SCHEMA = {
"type": "object",
"properties": {
"wrong_phi_after": {"type": "number"},
"wrong_reasoning": {"type": "string"},
"mistake_type": {"type": "string"},
},
"required": ["wrong_phi_after", "wrong_reasoning", "mistake_type"],
}
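# A response conforming to CONTRAST_SCHEMA might look like (hypothetical):
#   {"wrong_phi_after": 8.0,
#    "wrong_reasoning": "The command ran without errors, so the goal is nearly met.",
#    "mistake_type": "sycophancy"}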
class SelfTaughtEvaluator:
"""
Generates synthetic training data for the Purpose Function from traces.
Usage:
ste = SelfTaughtEvaluator(llm=model, memory_ci=ci)
# After a trace is complete:
pairs = ste.generate_from_trace(trace)
        # → Stores a critic_calibration memory (good) and a failure_pattern memory (bad)
# Iterative: as the critic improves, it generates harder pairs
for iteration in range(3):
for trace in recent_traces:
ste.generate_from_trace(trace)
"""
def __init__(
self,
llm: LLMBackend,
memory_ci: MemoryCI,
min_delta_for_training: float = 0.5,
):
self.llm = llm
self.memory_ci = memory_ci
self.min_delta = min_delta_for_training
self._pairs_generated = 0
def generate_from_trace(self, trace: Trace) -> int:
"""
Generate contrast pairs from a trace's score events.
Returns number of pairs generated.
"""
count = 0
score_events = [e for e in trace.events if e.kind == "score"]
for event in score_events:
data = event.data
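            # Expected "score" event payload (illustrative values; the exact
            # producer is outside this module):
            #   {"phi_before": 3.0, "phi_after": 5.5, "state_before": "...",
            #    "action_name": "write_file", "state_after": "...",
            #    "reasoning": "file now exists with the requested content"}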
phi_before = data.get("phi_before", 0)
phi_after = data.get("phi_after", 0)
delta = phi_after - phi_before
# Only generate pairs for meaningful transitions
if abs(delta) < self.min_delta:
continue
try:
pair = self._generate_contrast_pair(
state_before=data.get("state_before", ""),
action=data.get("action_name", ""),
state_after=data.get("state_after", ""),
purpose=trace.purpose,
phi_before=phi_before,
phi_after=phi_after,
reasoning=data.get("reasoning", ""),
)
if pair:
self._store_pair(pair, trace.trace_id)
count += 1
except Exception as e:
logger.warning(f"SelfTaught: Failed to generate pair: {e}")
self._pairs_generated += count
logger.info(f"SelfTaught: Generated {count} contrast pairs from trace {trace.trace_id}")
return count
def _generate_contrast_pair(
self,
state_before: str,
action: str,
state_after: str,
purpose: str,
phi_before: float,
phi_after: float,
reasoning: str,
) -> dict[str, Any] | None:
"""Generate a single (correct, wrong) evaluation pair."""
messages = [
ChatMessage(role="system", content="Generate a plausible but incorrect evaluation for training."),
ChatMessage(role="user", content=GENERATE_CONTRAST_PROMPT.format(
state_before=state_before[:200],
action=action,
state_after=state_after[:200],
purpose=purpose,
phi_before=phi_before,
phi_after=phi_after,
reasoning=reasoning[:200],
)),
]
        try:
            # Preferred path: backend-native structured output against the schema.
            result = self.llm.generate_structured(messages, schema=CONTRAST_SCHEMA)
        except Exception:
            # Fallback for backends without structured output: free-form
            # generation plus a best-effort JSON parse of the raw text.
            raw = self.llm.generate(messages, temperature=0.7)
            try:
                result = json.loads(raw)
            except Exception:
                return None
return {
"correct_phi_after": phi_after,
"correct_reasoning": reasoning,
"wrong_phi_after": result.get("wrong_phi_after", phi_after + 2),
"wrong_reasoning": result.get("wrong_reasoning", ""),
"mistake_type": result.get("mistake_type", "unknown"),
}
def _store_pair(self, pair: dict, trace_id: str) -> None:
"""Store a contrast pair as calibration memories."""
# Positive example
self.memory_ci.submit(MemoryCard(
kind=MemoryKind.CRITIC_CALIBRATION,
content=(
f"CORRECT scoring example: Φ_after={pair['correct_phi_after']:.1f}. "
f"Reasoning: {pair['correct_reasoning'][:200]}"
),
pattern="When evaluating state transitions",
strategy=f"Score like this: {pair['correct_reasoning'][:150]}",
trust_score=0.7,
source_trace_id=trace_id,
created_by="self_taught",
))
        # Negative example (what NOT to do). Unlike CRITIC_CALIBRATION, the
        # FAILURE_PATTERN kind is not critic-specific, hence the explicit scope.
self.memory_ci.submit(MemoryCard(
kind=MemoryKind.FAILURE_PATTERN,
content=(
f"WRONG scoring example ({pair['mistake_type']}): "
f"Incorrectly scored Φ_after={pair['wrong_phi_after']:.1f}. "
f"Flawed reasoning: {pair['wrong_reasoning'][:200]}"
),
pattern="When evaluating state transitions",
strategy=f"AVOID this mistake: {pair['mistake_type']}",
trust_score=0.7,
source_trace_id=trace_id,
created_by="self_taught",
scope=MemoryScope(agent_roles=["critic"]),
))
@property
def pairs_generated(self) -> int:
return self._pairs_generated
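
# Usage sketch (hedged): assumes concrete `backend: LLMBackend` and
# `ci: MemoryCI` instances plus a supply of completed traces exist in the
# host application; `recent_traces` is a placeholder, not part of this module.
#
#     ste = SelfTaughtEvaluator(llm=backend, memory_ci=ci)
#     for _ in range(3):                  # curriculum iterations
#         for trace in recent_traces:
#             ste.generate_from_trace(trace)
#     logger.info("pairs generated: %d", ste.pairs_generated)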