purpose-agent / purpose_agent /self_taught.py
Rohan03's picture
V2 merge: purpose_agent/self_taught.py
28b87a7 verified
"""
self_taught.py — Synthetic training data for the Purpose Function.
From Self-Taught Evaluators (arxiv:2408.02666):
Generate synthetic preference pairs (good vs bad state evaluations)
from the agent's own traces, then use them to improve the Purpose
Function's prompts without any human labels.
Adaptation for Purpose Agent (no weight updates):
1. Take a completed trace with Φ scores
2. For each step, generate a "worse" evaluation (modified instruction trick)
3. The correct evaluation becomes a positive example
4. The worse evaluation becomes a negative example
5. Store the correct example as a critic_calibration memory and the wrong one as a critic-scoped failure_pattern memory
6. The Purpose Function improves via in-context learning from these examples
This is an automatic curriculum: as the Purpose Function improves,
it generates harder training pairs, which further improve it.
"""
from __future__ import annotations

import json
import logging
import math
from typing import Any

from purpose_agent.llm_backend import LLMBackend, ChatMessage
from purpose_agent.trace import Trace
from purpose_agent.memory import MemoryCard, MemoryKind, MemoryStatus
from purpose_agent.v2_types import MemoryScope
from purpose_agent.memory_ci import MemoryCI
# Module-level logger used by SelfTaughtEvaluator.
logger = logging.getLogger(__name__)
# User-prompt template, filled via str.format in
# SelfTaughtEvaluator._generate_contrast_pair.
# Placeholders: state_before, action, state_after, purpose and reasoning are
# text; phi_before and phi_after are rendered with ":.1f", so they must be
# numeric. The doubled braces in the JSON example render as literal "{" / "}"
# after formatting.
GENERATE_CONTRAST_PROMPT = """\
You are generating training data for a state evaluator (critic).
Given this CORRECT evaluation of a state transition:
State before: {state_before}
Action: {action}
State after: {state_after}
Purpose: {purpose}
Correct Φ_before: {phi_before:.1f}
Correct Φ_after: {phi_after:.1f}
Correct reasoning: {reasoning}
Generate a PLAUSIBLE BUT WRONG evaluation that makes a common mistake.
Common mistakes:
- Giving credit for intentions rather than actual state changes
- Inflating scores to be encouraging (sycophancy)
- Ignoring evidence and scoring based on action name alone
- Being inconsistent with the scoring scale
Respond with JSON:
{{
"wrong_phi_after": <a plausible but incorrect score>,
"wrong_reasoning": "<plausible but flawed reasoning>",
"mistake_type": "<which common mistake this represents>"
}}
"""
# JSON schema passed to LLMBackend.generate_structured. Its keys mirror the
# JSON reply the prompt above asks for, and all three are required.
CONTRAST_SCHEMA = {
"type": "object",
"properties": {
"wrong_phi_after": {"type": "number"},
"wrong_reasoning": {"type": "string"},
"mistake_type": {"type": "string"},
},
"required": ["wrong_phi_after", "wrong_reasoning", "mistake_type"],
}
class SelfTaughtEvaluator:
    """
    Generates synthetic training data for the Purpose Function from traces.

    For every ``"score"`` event in a trace where |Φ_after − Φ_before| is at
    least ``min_delta_for_training``, the LLM is asked for a plausible but
    wrong evaluation of the same transition. The recorded evaluation is
    submitted to ``memory_ci`` as a CRITIC_CALIBRATION card. The wrong one is
    submitted as a critic-scoped FAILURE_PATTERN card.

    Usage:
        ste = SelfTaughtEvaluator(llm=model, memory_ci=ci)
        # After a trace is complete:
        n_pairs = ste.generate_from_trace(trace)
        # → Creates calibration memories with good/bad examples
        # Iterative: as the critic improves, it generates harder pairs
        for iteration in range(3):
            for trace in recent_traces:
                ste.generate_from_trace(trace)
    """

    def __init__(
        self,
        llm: LLMBackend,
        memory_ci: MemoryCI,
        min_delta_for_training: float = 0.5,
    ):
        """
        Args:
            llm: Backend used to produce the wrong ("contrast") evaluations.
            memory_ci: Receives the generated MemoryCards via ``submit``.
            min_delta_for_training: Minimum absolute Φ change for a score
                event to be turned into a pair; smaller changes are skipped.
        """
        self.llm = llm
        self.memory_ci = memory_ci
        self.min_delta = min_delta_for_training
        # Running total across every generate_from_trace call.
        self._pairs_generated = 0

    def generate_from_trace(self, trace: Trace) -> int:
        """
        Generate contrast pairs from a trace's score events.

        Events whose Φ values are None or non-numeric are skipped with a
        warning instead of aborting the whole trace. Failures while generating
        or storing a single pair are logged and skipped.

        Returns:
            Number of pairs generated (and stored) from this trace.
        """
        count = 0
        score_events = [e for e in trace.events if e.kind == "score"]
        for event in score_events:
            data = event.data
            phi_before = self._to_float(data.get("phi_before", 0))
            phi_after = self._to_float(data.get("phi_after", 0))
            if phi_before is None or phi_after is None:
                logger.warning(
                    "SelfTaught: Skipping score event with non-numeric Φ "
                    "(phi_before=%r, phi_after=%r)",
                    data.get("phi_before"),
                    data.get("phi_after"),
                )
                continue
            # Only generate pairs for meaningful transitions
            if abs(phi_after - phi_before) < self.min_delta:
                continue
            try:
                pair = self._generate_contrast_pair(
                    state_before=self._as_text(data.get("state_before")),
                    action=self._as_text(data.get("action_name")),
                    state_after=self._as_text(data.get("state_after")),
                    purpose=trace.purpose,
                    phi_before=phi_before,
                    phi_after=phi_after,
                    reasoning=self._as_text(data.get("reasoning")),
                )
                if pair:
                    self._store_pair(pair, trace.trace_id)
                    count += 1
            except Exception as e:
                # Best effort: one bad event must not abort the rest of the trace.
                logger.warning("SelfTaught: Failed to generate pair: %s", e)
        self._pairs_generated += count
        logger.info(
            "SelfTaught: Generated %d contrast pairs from trace %s",
            count,
            trace.trace_id,
        )
        return count

    def _generate_contrast_pair(
        self,
        state_before: str,
        action: str,
        state_after: str,
        purpose: str,
        phi_before: float,
        phi_after: float,
        reasoning: str,
    ) -> dict[str, Any] | None:
        """
        Generate a single (correct, wrong) evaluation pair.

        Structured generation is tried first. If it raises, the method falls
        back to free-form generation and extracts a JSON object from the text.

        Returns:
            The pair dict (see ``_build_pair``), or None if the LLM output
            could not be parsed or did not yield a usable contrast.
        """
        messages = [
            ChatMessage(role="system", content="Generate a plausible but incorrect evaluation for training."),
            ChatMessage(role="user", content=GENERATE_CONTRAST_PROMPT.format(
                # Truncate long fields to keep the prompt small.
                state_before=state_before[:200],
                action=action,
                state_after=state_after[:200],
                purpose=purpose,
                phi_before=phi_before,
                phi_after=phi_after,
                reasoning=reasoning[:200],
            )),
        ]
        try:
            result: Any = self.llm.generate_structured(messages, schema=CONTRAST_SCHEMA)
        except Exception:
            # Backend may not support structured output; use plain generation.
            result = self.llm.generate(messages, temperature=0.7)
        parsed = self._parse_json_object(result)
        if parsed is None:
            return None
        return self._build_pair(parsed, phi_after, reasoning)

    @classmethod
    def _build_pair(
        cls,
        result: dict[str, Any],
        phi_after: float,
        reasoning: str,
    ) -> dict[str, Any] | None:
        """
        Validate the LLM's contrast output and assemble a (correct, wrong) pair.

        A missing ``wrong_phi_after`` defaults to ``phi_after + 2``. Returns
        None when the wrong score is non-numeric, or when it rounds to the
        same one-decimal value as the correct score. Such a "wrong" example
        would be indistinguishable from the correct one.
        """
        wrong_phi = cls._to_float(result.get("wrong_phi_after", phi_after + 2))
        if wrong_phi is None:
            return None
        if round(wrong_phi, 1) == round(phi_after, 1):
            return None
        return {
            "correct_phi_after": phi_after,
            "correct_reasoning": reasoning,
            "wrong_phi_after": wrong_phi,
            "wrong_reasoning": cls._as_text(result.get("wrong_reasoning")),
            "mistake_type": cls._as_text(result.get("mistake_type")) or "unknown",
        }

    @staticmethod
    def _parse_json_object(raw: Any) -> dict[str, Any] | None:
        """
        Return *raw* as a dict, extracting a JSON object from text if needed.

        Text replies may wrap the object in Markdown fences or prose, so the
        span from the first "{" to the last "}" is decoded. Returns None if no
        JSON object can be obtained.
        """
        if isinstance(raw, dict):
            return raw
        if not isinstance(raw, str):
            return None
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            parsed = json.loads(raw[start:end + 1])
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _to_float(value: Any) -> float | None:
        """Coerce *value* to a finite float; return None if that is impossible."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping None to the empty string."""
        return "" if value is None else str(value)

    def _store_pair(self, pair: dict, trace_id: str) -> None:
        """Submit a contrast pair to ``memory_ci`` as two memory cards."""
        # Positive example
        # NOTE(review): unlike the negative card below, no scope is set here.
        # Confirm whether CRITIC_CALIBRATION cards are already critic-only.
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.CRITIC_CALIBRATION,
            content=(
                f"CORRECT scoring example: Φ_after={pair['correct_phi_after']:.1f}. "
                f"Reasoning: {pair['correct_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"Score like this: {pair['correct_reasoning'][:150]}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
        ))
        # Negative example (what NOT to do)
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.FAILURE_PATTERN,
            content=(
                f"WRONG scoring example ({pair['mistake_type']}): "
                f"Incorrectly scored Φ_after={pair['wrong_phi_after']:.1f}. "
                f"Flawed reasoning: {pair['wrong_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"AVOID this mistake: {pair['mistake_type']}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
            scope=MemoryScope(agent_roles=["critic"]),
        ))

    @property
    def pairs_generated(self) -> int:
        """Total number of pairs generated across all calls so far."""
        return self._pairs_generated