Rohan03 committed
Commit fedfb2e · verified · 1 Parent(s): 3110b12

Add purpose_agent/purpose_function.py

Files changed (1):
  1. purpose_agent/purpose_function.py +431 -0

purpose_agent/purpose_function.py ADDED
@@ -0,0 +1,431 @@
"""
Purpose Function — The Critic / State Evaluator.

This is the core innovation: a strictly separated LLM call that evaluates
state improvement Φ(s). It rewards the agent ONLY if Φ(s_new) > Φ(s_current).

Design principles (from literature):
1. Score AFTER environment feedback, never from expected state alone (LATS)
2. Require specific observable state changes as evidence (SPC anti-hacking)
3. Use separate LLM call / separate system prompt from the Actor (MUSE)
4. Normalize scores to prevent inflation over trajectory (novel addition)
5. V(s) = λ·LM_score + (1-λ)·consistency_score (LATS formulation)

The Purpose Function is intentionally "non-hackable" by design:
- It sees the ACTUAL new state, not the Actor's prediction
- It must cite specific evidence for every score
- Scores are bounded and normalized
- The system prompt explicitly guards against sycophancy and vague reasoning
"""

from __future__ import annotations

import json
import logging
import re
from typing import Any

from purpose_agent.types import Action, PurposeScore, State
from purpose_agent.llm_backend import ChatMessage, LLMBackend

logger = logging.getLogger(__name__)

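# A minimal sketch of the LATS-style value blend named in design principle 5
# of the module docstring. This module itself returns the raw Φ delta; the
# λ-combination and the consistency_score input are assumptions about how a
# caller might merge the two signals, not something implemented here:
#
#     def lats_value(lm_score: float, consistency_score: float, lam: float = 0.5) -> float:
#         return lam * lm_score + (1 - lam) * consistency_score
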
# ---------------------------------------------------------------------------
# Purpose Function System Prompt — The "Non-Hackable Judge"
# ---------------------------------------------------------------------------

PURPOSE_FUNCTION_SYSTEM_PROMPT = """\
You are a STATE EVALUATOR — a strict, impartial judge of progress toward a goal.
You are NOT the agent. You do NOT help the agent. You ONLY measure progress.

## Your Role
Given a state transition (state_before → action → state_after) and an ultimate purpose,
you compute two scores:
- Φ(state_before): How much progress the OLD state represents toward the purpose (0.0 = no progress, 10.0 = goal achieved)
- Φ(state_after): How much progress the NEW state represents toward the purpose (same scale)

The delta Φ(state_after) - Φ(state_before) is the ONLY signal the agent receives.

## STRICT RULES — Violation of any rule invalidates your evaluation

1. **EVIDENCE REQUIRED**: Every score MUST cite a specific, observable change in the
   state data. "The state improved" is NOT evidence. "Field 'score' changed from 3 to 7"
   IS evidence. If you cannot cite a specific change, the delta MUST be 0.0.

2. **NO CREDIT FOR INTENTIONS**: The agent's "thought" and "expected_delta" are
   provided for context only. You score based on ACTUAL state changes, never on
   what the agent intended or claimed would happen.

3. **NO SYCOPHANCY**: You are not the agent's friend. Do not inflate scores to be
   encouraging. A lateral move (no improvement) gets delta = 0.0. A regression gets
   a negative delta. Be precise.

4. **MONOTONIC SCALE**: Φ = 0.0 means the state has zero progress toward the purpose.
   Φ = 10.0 means the purpose is fully achieved. Intermediate values are proportional.
   Justify WHY you chose each specific value.

5. **ANTI-GAMING**: If the action appears to manipulate the state in a way that
   superficially looks like progress but doesn't genuinely advance the purpose
   (e.g., changing a label without doing the work), score it as delta = 0.0 or negative
   and flag it in your evidence field.

6. **CONSISTENCY**: If a state identical to one you scored before appears again,
   it MUST receive the same Φ score. Progress is objective, not relative to your mood.

7. **CONFIDENCE**: Rate your confidence 0.0–1.0. High confidence (>0.8) requires
   clear, unambiguous evidence. If the state change is ambiguous, lower your confidence.

## Scoring Guide
- Φ = 0.0: No meaningful progress toward the purpose
- Φ = 1.0–3.0: Initial setup/preparation steps completed
- Φ = 4.0–6.0: Substantive progress, key sub-goals partially achieved
- Φ = 7.0–8.0: Most of the purpose is achieved, final steps remaining
- Φ = 9.0: Purpose essentially achieved with minor polish needed
- Φ = 10.0: Purpose fully and completely achieved
"""
86
+
87
+
88
+ PURPOSE_FUNCTION_EVAL_PROMPT = """\
89
+ ## Ultimate Purpose
90
+ {purpose}
91
+
92
+ ## State BEFORE Action
93
+ {state_before}
94
+
95
+ ## Action Taken
96
+ Name: {action_name}
97
+ Parameters: {action_params}
98
+ Agent's Thought: {action_thought}
99
+ Agent's Prediction: {expected_delta}
100
+
101
+ ## State AFTER Action (this is the ACTUAL result — score based on THIS)
102
+ {state_after}
103
+
104
+ Evaluate this state transition. Remember:
105
+ - Score Φ(state_before) and Φ(state_after) on the 0.0–10.0 scale
106
+ - Cite SPECIFIC evidence from the state data
107
+ - Do NOT give credit for intentions — only actual changes
108
+ """
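
# Example of the template filled in for a hypothetical purpose ("make all unit
# tests pass") — the field values are illustrative, not from the source:
#
#     ## Ultimate Purpose
#     Make all unit tests pass.
#
#     ## State BEFORE Action
#     {"tests_passing": 3, "tests_total": 10}
#
#     ## Action Taken
#     Name: fix_off_by_one
#     ...
#
#     ## State AFTER Action (this is the ACTUAL result — score based on THIS)
#     {"tests_passing": 7, "tests_total": 10}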


# ---------------------------------------------------------------------------
# Purpose Function Schema (for structured output)
# ---------------------------------------------------------------------------

PURPOSE_SCORE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "phi_before": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 10.0,
            "description": "Φ(state_before) — progress-toward-purpose of the state before the action",
        },
        "phi_after": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 10.0,
            "description": "Φ(state_after) — progress-toward-purpose of the state after the action",
        },
        "reasoning": {
            "type": "string",
            "description": "Step-by-step justification for both scores (max 200 words)",
        },
        "evidence": {
            "type": "string",
            "description": "Specific observable state changes that justify the delta (REQUIRED)",
        },
        "confidence": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 1.0,
            "description": "Confidence in this evaluation (0.0 = pure guess, 1.0 = certain)",
        },
    },
    "required": ["phi_before", "phi_after", "reasoning", "evidence", "confidence"],
}
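
# An example payload conforming to PURPOSE_SCORE_SCHEMA (illustrative values,
# not output from any real model):
#
#     {
#         "phi_before": 2.0,
#         "phi_after": 4.5,
#         "reasoning": "Setup was complete before; the action fixed four failing tests.",
#         "evidence": "Field 'tests_passing' changed from 3 to 7 (of 10).",
#         "confidence": 0.85,
#     }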


# ---------------------------------------------------------------------------
# Purpose Function Class
# ---------------------------------------------------------------------------

class PurposeFunction:
    """
    The Critic — evaluates state transitions via Φ(s) scoring.

    Uses a SEPARATE LLM call from the Actor to prevent self-confirmation bias
    (per MUSE's Reflect Agent design, arxiv:2510.08002).

    Can optionally use a different model than the Actor (recommended for
    production — use a stronger model as the critic).

    Args:
        llm: LLM backend (can be the same as or different from the Actor's)
        score_cache_size: Max entries in the Φ score cache (for consistency)
        require_evidence: If True, force the delta to 0.0 when the evaluation
            cites insufficient evidence
        min_confidence: Minimum confidence threshold — below this, the delta is
            shrunk toward zero rather than trusted at full magnitude
    """

    def __init__(
        self,
        llm: LLMBackend,
        score_cache_size: int = 1000,
        require_evidence: bool = True,
        min_confidence: float = 0.3,
    ):
        self.llm = llm
        self.require_evidence = require_evidence
        self.min_confidence = min_confidence
        # Cache: state_hash → Φ score (for consistency rule #6)
        self._phi_cache: dict[str, float] = {}
        self._cache_size = score_cache_size
        # Running stats for normalization
        self._score_history: list[float] = []

    # ------------------------------------------------------------------
    # Core Evaluation
    # ------------------------------------------------------------------

    def evaluate(
        self,
        state_before: State,
        action: Action,
        state_after: State,
        purpose: str,
    ) -> PurposeScore:
        """
        Evaluate a state transition: did the action move closer to the purpose?

        Returns a PurposeScore with phi_before, phi_after, delta, reasoning,
        evidence, and confidence.

        Anti-hacking measures:
        1. Scores based on ACTUAL state_after (not actor's expected_delta)
        2. Evidence is required — vague scores are rejected
        3. Cached Φ values enforce consistency
        4. Confidence threshold filters uncertain evaluations
        """
        # Check cache for consistency (Rule #6)
        cached_before = self._get_cached_phi(state_before)
        cached_after = self._get_cached_phi(state_after)

        # Build evaluation prompt
        messages = [
            ChatMessage(role="system", content=PURPOSE_FUNCTION_SYSTEM_PROMPT),
            ChatMessage(role="user", content=PURPOSE_FUNCTION_EVAL_PROMPT.format(
                purpose=purpose,
                state_before=state_before.describe(),
                state_after=state_after.describe(),
                action_name=action.name,
                action_params=json.dumps(action.params, default=str),
                action_thought=action.thought,
                expected_delta=action.expected_delta,
            )),
        ]

        # Get structured evaluation from LLM
        try:
            raw_score = self.llm.generate_structured(
                messages, schema=PURPOSE_SCORE_SCHEMA, temperature=0.2
            )
        except Exception as e:
            logger.error(f"Purpose Function structured output failed: {e}")
            # Fall back to text-based evaluation
            raw_score = self._fallback_evaluate(messages)

        # Extract and validate scores
        phi_before = float(raw_score.get("phi_before", 0.0))
        phi_after = float(raw_score.get("phi_after", 0.0))
        reasoning = str(raw_score.get("reasoning", ""))
        evidence = str(raw_score.get("evidence", ""))
        confidence = float(raw_score.get("confidence", 0.5))

        # Clamp to valid range
        phi_before = max(0.0, min(10.0, phi_before))
        phi_after = max(0.0, min(10.0, phi_after))
        confidence = max(0.0, min(1.0, confidence))

        # Apply anti-hacking rules
        phi_before, phi_after, confidence = self._apply_safeguards(
            phi_before, phi_after, evidence, confidence,
            cached_before, cached_after,
        )

        delta = phi_after - phi_before

        # Update caches
        self._cache_phi(state_before, phi_before)
        self._cache_phi(state_after, phi_after)
        self._score_history.append(phi_after)

        score = PurposeScore(
            phi_before=phi_before,
            phi_after=phi_after,
            delta=delta,
            reasoning=reasoning,
            evidence=evidence,
            confidence=confidence,
        )

        logger.info(
            f"Purpose Function: Φ({phi_before:.1f}) → Φ({phi_after:.1f}), "
            f"Δ={delta:+.2f}, conf={confidence:.2f}, improved={score.improved}"
        )
        return score

    # ------------------------------------------------------------------
    # Anti-Hacking Safeguards
    # ------------------------------------------------------------------

    def _apply_safeguards(
        self,
        phi_before: float,
        phi_after: float,
        evidence: str,
        confidence: float,
        cached_before: float | None,
        cached_after: float | None,
    ) -> tuple[float, float, float]:
        """
        Apply anti-reward-hacking safeguards.

        1. Evidence requirement: no evidence → delta forced to 0
        2. Cache consistency: if we've scored this state before, use cached value
        3. Confidence threshold: low confidence → reduce delta magnitude
        4. Anomaly detection: suspiciously large jumps get confidence penalty
        """
        # Rule 1: Require evidence
        if self.require_evidence and len(evidence.strip()) < 10:
            logger.warning("Purpose Function: Insufficient evidence, forcing delta=0")
            phi_after = phi_before  # No credit without evidence
            confidence = min(confidence, 0.1)  # An unevidenced score deserves low trust

        # Rule 2: Cache consistency (allow small drift for scoring noise)
        if cached_before is not None:
            drift = abs(phi_before - cached_before)
            if drift > 1.0:
                logger.warning(
                    f"Purpose Function: Inconsistent Φ_before "
                    f"(new={phi_before:.1f}, cached={cached_before:.1f}), "
                    f"using cached value"
                )
                phi_before = cached_before

        if cached_after is not None:
            drift = abs(phi_after - cached_after)
            if drift > 1.0:
                logger.warning(
                    f"Purpose Function: Inconsistent Φ_after "
                    f"(new={phi_after:.1f}, cached={cached_after:.1f}), "
                    f"using cached value"
                )
                phi_after = cached_after

        # Rule 3: Confidence threshold — halve the delta, keeping phi_before fixed
        if confidence < self.min_confidence:
            logger.warning(
                f"Purpose Function: Low confidence ({confidence:.2f}), "
                f"reducing delta magnitude by 50%"
            )
            phi_after = phi_before + (phi_after - phi_before) * 0.5

        # Rule 4: Anomaly detection — flag suspiciously large single-step jumps
        delta = phi_after - phi_before
        if abs(delta) > 3.0:
            logger.warning(
                f"Purpose Function: Unusually large delta ({delta:+.1f}), "
                f"applying confidence penalty"
            )
            confidence = min(confidence, 0.5)

        return phi_before, phi_after, confidence
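
    # Worked example for Rule 3 (illustrative numbers): with phi_before=2.0,
    # phi_after=6.0, and confidence=0.2 below min_confidence, the update gives
    # phi_after = 2.0 + (6.0 - 2.0) * 0.5 = 4.0, halving the delta from +4.0 to +2.0.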

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _state_hash(self, state: State) -> str:
        """Canonical serialization of a state's data, used as the cache key."""
        return json.dumps(state.data, sort_keys=True, default=str)

    def _get_cached_phi(self, state: State) -> float | None:
        return self._phi_cache.get(self._state_hash(state))

    def _cache_phi(self, state: State, phi: float) -> None:
        key = self._state_hash(state)
        if len(self._phi_cache) >= self._cache_size:
            # Evict oldest (FIFO — good enough for our use case)
            oldest_key = next(iter(self._phi_cache))
            del self._phi_cache[oldest_key]
        self._phi_cache[key] = phi
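
    # Note: sort_keys=True makes the key canonical, so states whose data dicts
    # hold the same items in different insertion order (e.g. {"a": 1, "b": 2}
    # vs. {"b": 2, "a": 1}) share one cache entry, as consistency rule #6 requires.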

    # ------------------------------------------------------------------
    # Normalization (prevent score inflation over long trajectories)
    # ------------------------------------------------------------------

    def get_normalized_phi(self, raw_phi: float) -> float:
        """
        Normalize a Φ score relative to the trajectory's score distribution.

        Prevents the common failure mode where LLM scores drift upward over
        a trajectory regardless of actual progress.
        """
        if len(self._score_history) < 3:
            return raw_phi

        mean = sum(self._score_history) / len(self._score_history)
        variance = sum((x - mean) ** 2 for x in self._score_history) / len(self._score_history)
        std = max(variance ** 0.5, 0.1)  # Avoid division by zero

        # Z-score normalization mapped back to 0-10
        z = (raw_phi - mean) / std
        normalized = 5.0 + z * 2.0  # Center at 5, spread by 2
        return max(0.0, min(10.0, normalized))
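
    # Worked example (illustrative numbers): with _score_history = [4.0, 5.0, 6.0],
    # mean = 5.0 and std ≈ 0.816, so raw_phi = 6.0 gives z ≈ 1.22 and
    # normalized ≈ 5.0 + 1.22 * 2.0 ≈ 7.45 (then clamped to [0.0, 10.0]).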

    def reset_trajectory_stats(self) -> None:
        """Reset per-trajectory normalization stats. Call at trajectory start."""
        self._score_history = []

    # ------------------------------------------------------------------
    # Fallback
    # ------------------------------------------------------------------

    def _fallback_evaluate(self, messages: list[ChatMessage]) -> dict[str, Any]:
        """Text-based fallback when structured output is unavailable."""
        raw = self.llm.generate(messages, temperature=0.2)

        phi_before = 0.0
        phi_after = 0.0

        # Try to extract scores from text
        before_match = re.search(r'[Φφ]\s*\(?state_?before\)?\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)
        after_match = re.search(r'[Φφ]\s*\(?state_?after\)?\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)

        if before_match:
            phi_before = float(before_match.group(1))
        if after_match:
            phi_after = float(after_match.group(1))

        # Also try "Score: X/10" patterns
        if not before_match:
            score_matches = re.findall(r'(\d+\.?\d*)\s*/?\s*10', raw)
            if len(score_matches) >= 2:
                phi_before = float(score_matches[0])
                phi_after = float(score_matches[1])
            elif len(score_matches) == 1:
                phi_after = float(score_matches[0])

        confidence_match = re.search(r'confidence\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)
        confidence = float(confidence_match.group(1)) if confidence_match else 0.4

        return {
            "phi_before": phi_before,
            "phi_after": phi_after,
            "reasoning": raw[:500],
            "evidence": raw[500:800] if len(raw) > 500 else "",
            "confidence": confidence,
        }
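

# ---------------------------------------------------------------------------
# Usage sketch
# ---------------------------------------------------------------------------
# A minimal, illustrative driver — not part of the module's API. It assumes
# dataclass-style constructors State(data=...) and Action(name=..., params=...,
# thought=..., expected_delta=...), inferred from the attribute access in
# evaluate() above; the stub backend merely duck-types the two LLMBackend
# methods this module calls.

if __name__ == "__main__":
    class _StubBackend:
        """Stand-in for a real LLMBackend: returns canned evaluations."""

        def generate(self, messages: list[ChatMessage], temperature: float = 0.0) -> str:
            return "Φ(state_before) = 2.0, Φ(state_after) = 4.5, confidence: 0.85"

        def generate_structured(
            self, messages: list[ChatMessage], schema: dict[str, Any], temperature: float = 0.0
        ) -> dict[str, Any]:
            return {
                "phi_before": 2.0,
                "phi_after": 4.5,
                "reasoning": "Four more tests pass after the action.",
                "evidence": "Field 'tests_passing' changed from 3 to 7.",
                "confidence": 0.85,
            }

    critic = PurposeFunction(llm=_StubBackend())  # type: ignore[arg-type]
    score = critic.evaluate(
        state_before=State(data={"tests_passing": 3, "tests_total": 10}),
        action=Action(
            name="fix_off_by_one",
            params={"file": "utils.py"},
            thought="The loop bound looks wrong.",
            expected_delta=2.0,
        ),
        state_after=State(data={"tests_passing": 7, "tests_total": 10}),
        purpose="Make all unit tests pass.",
    )
    print(f"Δ = {score.delta:+.2f}, improved = {score.improved}")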