| """ |
| self_taught.py — Synthetic training data for the Purpose Function. |
| |
| From Self-Taught Evaluators (arxiv:2408.02666): |
| Generate synthetic preference pairs (good vs bad state evaluations) |
| from the agent's own traces, then use them to improve the Purpose |
| Function's prompts without any human labels. |
| |
| Adaptation for Purpose Agent (no weight updates): |
| 1. Take a completed trace with Φ scores |
| 2. For each step, generate a "worse" evaluation (modified instruction trick) |
| 3. The correct evaluation becomes a positive example |
| 4. The worse evaluation becomes a negative example |
| 5. Store both as critic_calibration memories |
| 6. The Purpose Function improves via in-context learning from these examples |
| |
| This is an automatic curriculum: as the Purpose Function improves, |
| it generates harder training pairs, which further improve it. |
| """ |
from __future__ import annotations

import json
import logging
from typing import Any

from purpose_agent.llm_backend import ChatMessage, LLMBackend
from purpose_agent.memory import MemoryCard, MemoryKind
from purpose_agent.memory_ci import MemoryCI
from purpose_agent.trace import Trace
from purpose_agent.v2_types import MemoryScope

logger = logging.getLogger(__name__)

GENERATE_CONTRAST_PROMPT = """\
You are generating training data for a state evaluator (critic).

Given this CORRECT evaluation of a state transition:
State before: {state_before}
Action: {action}
State after: {state_after}
Purpose: {purpose}
Correct Φ_before: {phi_before:.1f}
Correct Φ_after: {phi_after:.1f}
Correct reasoning: {reasoning}

Generate a PLAUSIBLE BUT WRONG evaluation that makes a common mistake.
Common mistakes:
- Giving credit for intentions rather than actual state changes
- Inflating scores to be encouraging (sycophancy)
- Ignoring evidence and scoring based on the action name alone
- Being inconsistent with the scoring scale

Respond with JSON:
{{
    "wrong_phi_after": <a plausible but incorrect score>,
    "wrong_reasoning": "<plausible but flawed reasoning>",
    "mistake_type": "<which common mistake this represents>"
}}
"""

CONTRAST_SCHEMA = {
    "type": "object",
    "properties": {
        "wrong_phi_after": {"type": "number"},
        "wrong_reasoning": {"type": "string"},
        "mistake_type": {"type": "string"},
    },
    "required": ["wrong_phi_after", "wrong_reasoning", "mistake_type"],
}
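
# Illustrative shape of the response the prompt and schema above elicit
# (values are hypothetical, shown for documentation only):
#   {"wrong_phi_after": 7.5,
#    "wrong_reasoning": "The agent clearly intended to fix the bug, so the
#                        state must have improved.",
#    "mistake_type": "credit for intentions"}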

class SelfTaughtEvaluator:
    """
    Generates synthetic training data for the Purpose Function from traces.

    Usage:
        ste = SelfTaughtEvaluator(llm=model, memory_ci=ci)

        # After a trace is complete:
        pairs = ste.generate_from_trace(trace)
        # → creates calibration and failure-pattern memories (good/bad examples)

        # Iterative: as the critic improves, it generates harder pairs
        for iteration in range(3):
            for trace in recent_traces:
                ste.generate_from_trace(trace)
    """

    def __init__(
        self,
        llm: LLMBackend,
        memory_ci: MemoryCI,
        min_delta_for_training: float = 0.5,
    ):
        self.llm = llm
        self.memory_ci = memory_ci
        # Only transitions with |ΔΦ| at or above this threshold become
        # training pairs; near-neutral steps carry little signal.
        self.min_delta = min_delta_for_training
        self._pairs_generated = 0

    def generate_from_trace(self, trace: Trace) -> int:
        """
        Generate contrast pairs from a trace's score events.

        Returns the number of pairs generated.
        """
        count = 0
        score_events = [e for e in trace.events if e.kind == "score"]
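        # Shape of the score-event payload this method assumes (illustrative;
        # the actual keys come from whatever emits "score" events):
        #   {"phi_before": 3.0, "phi_after": 6.0,
        #    "state_before": "...", "state_after": "...",
        #    "action_name": "fix_tests", "reasoning": "..."}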

        for event in score_events:
            data = event.data
            phi_before = data.get("phi_before", 0.0)
            phi_after = data.get("phi_after", 0.0)
            delta = phi_after - phi_before

            # Skip near-neutral transitions: contrast pairs are only
            # informative when Φ actually moved.
            if abs(delta) < self.min_delta:
                continue

            try:
                pair = self._generate_contrast_pair(
                    state_before=data.get("state_before", ""),
                    action=data.get("action_name", ""),
                    state_after=data.get("state_after", ""),
                    purpose=trace.purpose,
                    phi_before=phi_before,
                    phi_after=phi_after,
                    reasoning=data.get("reasoning", ""),
                )
                if pair:
                    self._store_pair(pair, trace.trace_id)
                    count += 1
            except Exception as e:
                logger.warning("SelfTaught: Failed to generate pair: %s", e)

        self._pairs_generated += count
        logger.info(
            "SelfTaught: Generated %d contrast pairs from trace %s",
            count, trace.trace_id,
        )
        return count

    def _generate_contrast_pair(
        self,
        state_before: str,
        action: str,
        state_after: str,
        purpose: str,
        phi_before: float,
        phi_after: float,
        reasoning: str,
    ) -> dict[str, Any] | None:
        """Generate a single (correct, wrong) evaluation pair."""
        messages = [
            ChatMessage(role="system", content="Generate a plausible but incorrect evaluation for training."),
            ChatMessage(role="user", content=GENERATE_CONTRAST_PROMPT.format(
                # Truncate long fields to keep the prompt compact.
                state_before=state_before[:200],
                action=action,
                state_after=state_after[:200],
                purpose=purpose,
                phi_before=phi_before,
                phi_after=phi_after,
                reasoning=reasoning[:200],
            )),
        ]

        try:
            result = self.llm.generate_structured(messages, schema=CONTRAST_SCHEMA)
        except Exception:
            # Fall back to free-form generation and best-effort JSON parsing.
            raw = self.llm.generate(messages, temperature=0.7)
            try:
                result = json.loads(raw)
            except Exception:
                return None

        return {
            "correct_phi_after": phi_after,
            "correct_reasoning": reasoning,
            # If the model omits a score, default to an inflated one,
            # mimicking the sycophancy mistake.
            "wrong_phi_after": result.get("wrong_phi_after", phi_after + 2),
            "wrong_reasoning": result.get("wrong_reasoning", ""),
            "mistake_type": result.get("mistake_type", "unknown"),
        }

    def _store_pair(self, pair: dict[str, Any], trace_id: str) -> None:
        """Store a contrast pair as calibration memories."""
        # Positive example: how the transition SHOULD be scored.
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.CRITIC_CALIBRATION,
            content=(
                f"CORRECT scoring example: Φ_after={pair['correct_phi_after']:.1f}. "
                f"Reasoning: {pair['correct_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"Score like this: {pair['correct_reasoning'][:150]}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
        ))

        # Negative example: the flawed evaluation, stored as a failure
        # pattern and scoped to the critic role.
        self.memory_ci.submit(MemoryCard(
            kind=MemoryKind.FAILURE_PATTERN,
            content=(
                f"WRONG scoring example ({pair['mistake_type']}): "
                f"Incorrectly scored Φ_after={pair['wrong_phi_after']:.1f}. "
                f"Flawed reasoning: {pair['wrong_reasoning'][:200]}"
            ),
            pattern="When evaluating state transitions",
            strategy=f"AVOID this mistake: {pair['mistake_type']}",
            trust_score=0.7,
            source_trace_id=trace_id,
            created_by="self_taught",
            scope=MemoryScope(agent_roles=["critic"]),
        ))

    @property
    def pairs_generated(self) -> int:
        return self._pairs_generated
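
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's API. The stubs
    # below are hypothetical: they duck-type only what this module calls
    # (LLMBackend.generate_structured, MemoryCI.submit, and a Trace-shaped
    # object), so no real backend or memory store is needed — only the
    # purpose_agent package must be importable.
    from types import SimpleNamespace

    class _StubLLM:
        def generate_structured(self, messages, schema):
            # Canned "plausible but wrong" evaluation.
            return {
                "wrong_phi_after": 9.0,
                "wrong_reasoning": "The agent clearly tried hard, so score generously.",
                "mistake_type": "sycophancy",
            }

        def generate(self, messages, temperature=0.0):
            return "{}"

    class _StubCI:
        def submit(self, card):
            print(f"submitted {card.kind}: {card.content[:60]}")

    trace = SimpleNamespace(
        trace_id="demo-1",
        purpose="Keep the test suite green",
        events=[SimpleNamespace(kind="score", data={
            "phi_before": 3.0,
            "phi_after": 6.0,
            "state_before": "2 tests failing",
            "state_after": "0 tests failing",
            "action_name": "fix_tests",
            "reasoning": "Both failures resolved; the suite passes.",
        })],
    )

    ste = SelfTaughtEvaluator(llm=_StubLLM(), memory_ci=_StubCI())
    print("pairs generated:", ste.generate_from_trace(trace))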
|