"""
Purpose Function — The Critic / State Evaluator.

This is the core innovation: a strictly separated LLM call that evaluates
state improvement Φ(s). It rewards the agent ONLY if Φ(s_new) > Φ(s_current).

Design principles (from literature):
  1. Score AFTER environment feedback, never from expected state alone (LATS)
  2. Require specific observable state changes as evidence (SPC anti-hacking)
  3. Use separate LLM call / separate system prompt from the Actor (MUSE)
  4. Normalize scores to prevent inflation over trajectory (novel addition)
  5. V(s) = λ·LM_score + (1-λ)·consistency_score (LATS formulation; sketched below)

The Purpose Function is intentionally "non-hackable" by design:
  - It sees the ACTUAL new state, not the Actor's prediction
  - It must cite specific evidence for every score
  - Scores are bounded and normalized
  - The system prompt explicitly guards against sycophancy and vague reasoning
"""

from __future__ import annotations

import hashlib
import json
import logging
from typing import Any

from purpose_agent.types import Action, PurposeScore, State
from purpose_agent.llm_backend import ChatMessage, LLMBackend

logger = logging.getLogger(__name__)
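

# ---------------------------------------------------------------------------
# LATS-style value combination (hedged sketch)
# ---------------------------------------------------------------------------
# Design principle #5 in the module docstring combines the LLM-judged score
# with a consistency score. The helper below only illustrates that formula;
# the default lam=0.5 and the clamp to [0, 10] are assumptions of this
# sketch, not prescriptions from the LATS paper.

def combine_value(lm_score: float, consistency_score: float, lam: float = 0.5) -> float:
    """V(s) = lam * lm_score + (1 - lam) * consistency_score, clamped to [0, 10]."""
    value = lam * lm_score + (1.0 - lam) * consistency_score
    return max(0.0, min(10.0, value))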


# ---------------------------------------------------------------------------
# Purpose Function System Prompt — The "Non-Hackable Judge"
# ---------------------------------------------------------------------------

PURPOSE_FUNCTION_SYSTEM_PROMPT = """\
You are a STATE EVALUATOR — a strict, impartial judge of progress toward a goal.
You are NOT the agent. You do NOT help the agent. You ONLY measure progress.

## Your Role
Given a state transition (state_before → action → state_after) and an ultimate purpose,
you compute two scores:
  - Φ(state_before): How much progress the OLD state had made toward the purpose (0.0 = no progress, 10.0 = goal achieved)
  - Φ(state_after):  How much progress the NEW state has made toward the purpose (same scale)

The delta Φ(state_after) - Φ(state_before) is the ONLY signal the agent receives.

## STRICT RULES — Violation of any rule invalidates your evaluation

1. **EVIDENCE REQUIRED**: Every score MUST cite a specific, observable change in the
   state data. "The state improved" is NOT evidence. "Field 'score' changed from 3 to 7"
   IS evidence. If you cannot cite a specific change, the delta MUST be 0.0.

2. **NO CREDIT FOR INTENTIONS**: The agent's "thought" and "expected_delta" are
   provided for context only. You score based on ACTUAL state changes, never on
   what the agent intended or claimed would happen.

3. **NO SYCOPHANCY**: You are not the agent's friend. Do not inflate scores to be
   encouraging. A lateral move (no improvement) gets delta = 0.0. A regression gets
   negative delta. Be precise.

4. **MONOTONIC SCALE**: Φ = 0.0 means the state has zero progress toward the purpose.
   Φ = 10.0 means the purpose is fully achieved. Intermediate values are proportional.
   Justify WHY you chose each specific value.

5. **ANTI-GAMING**: If the action appears to manipulate the state in a way that
   superficially looks like progress but doesn't genuinely advance the purpose
   (e.g., changing a label without doing the work), score it as delta = 0.0 or negative
   and flag it in your evidence field.

6. **CONSISTENCY**: If a state identical to one you scored before appears again,
   it MUST receive the same Φ score. Progress is objective, not relative to your mood.

7. **CONFIDENCE**: Rate your confidence 0.0–1.0. High confidence (>0.8) requires
   clear, unambiguous evidence. If the state change is ambiguous, lower your confidence.

## Scoring Guide
- Φ = 0.0: No meaningful progress toward the purpose
- Φ = 1.0–3.0: Initial setup/preparation steps completed
- Φ = 4.0–6.0: Substantive progress, key sub-goals partially achieved
- Φ = 7.0–8.0: Most of the purpose is achieved, final steps remaining
- Φ = 9.0: Purpose essentially achieved with minor polish needed
- Φ = 10.0: Purpose fully and completely achieved
"""


PURPOSE_FUNCTION_EVAL_PROMPT = """\
## Ultimate Purpose
{purpose}

## State BEFORE Action
{state_before}

## Action Taken
Name: {action_name}
Parameters: {action_params}
Agent's Thought: {action_thought}
Agent's Prediction: {expected_delta}

## State AFTER Action (this is the ACTUAL result — score based on THIS)
{state_after}

Evaluate this state transition. Remember:
- Score Φ(state_before) and Φ(state_after) on the 0.0–10.0 scale
- Cite SPECIFIC evidence from the state data
- Do NOT give credit for intentions — only actual changes
"""


# ---------------------------------------------------------------------------
# Purpose Function Schema (for structured output)
# ---------------------------------------------------------------------------

PURPOSE_SCORE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "phi_before": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 10.0,
            "description": "Φ(state_before) — distance-to-purpose of the state before the action",
        },
        "phi_after": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 10.0,
            "description": "Φ(state_after) — distance-to-purpose of the state after the action",
        },
        "reasoning": {
            "type": "string",
            "description": "Step-by-step justification for both scores (max 200 words)",
        },
        "evidence": {
            "type": "string",
            "description": "Specific observable state changes that justify the delta (REQUIRED)",
        },
        "confidence": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 1.0,
            "description": "Confidence in this evaluation (0.0 = pure guess, 1.0 = certain)",
        },
    },
    "required": ["phi_before", "phi_after", "reasoning", "evidence", "confidence"],
}
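
# Illustrative instance of PURPOSE_SCORE_SCHEMA (all values made up). If the
# optional `jsonschema` package is installed, conformance can be checked like
# this (a sketch, not a dependency of this module):
#
#   import jsonschema
#   sample = {
#       "phi_before": 3.0,
#       "phi_after": 4.5,
#       "reasoning": "Setup complete (3.0); 'tests_passing' rose from 3 to 7 (4.5).",
#       "evidence": "Field 'tests_passing' changed from 3 to 7 in state_after.",
#       "confidence": 0.85,
#   }
#   jsonschema.validate(instance=sample, schema=PURPOSE_SCORE_SCHEMA)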


# ---------------------------------------------------------------------------
# Purpose Function Class
# ---------------------------------------------------------------------------

class PurposeFunction:
    """
    The Critic — evaluates state transitions via Φ(s) scoring.
    
    Uses a SEPARATE LLM call from the Actor to prevent self-confirmation bias
    (per MUSE's Reflect Agent design, arxiv:2510.08002).
    
    Can optionally use a different model than the Actor (recommended for
    production — use a stronger model as the critic).
    
    Args:
        llm: LLM backend (can be same or different from Actor's)
        score_cache_size: Max entries in the Φ score cache (for consistency)
        require_evidence: If True, reject scores with empty evidence
        min_confidence: Minimum confidence threshold; below this, the delta magnitude is reduced (see _apply_safeguards)
    """

    def __init__(
        self,
        llm: LLMBackend,
        score_cache_size: int = 1000,
        require_evidence: bool = True,
        min_confidence: float = 0.3,
    ):
        self.llm = llm
        self.require_evidence = require_evidence
        self.min_confidence = min_confidence
        # Cache: state_hash → Φ score (for consistency rule #6)
        self._phi_cache: dict[str, float] = {}
        self._cache_size = score_cache_size
        # Running stats for normalization
        self._score_history: list[float] = []

    # ------------------------------------------------------------------
    # Core Evaluation
    # ------------------------------------------------------------------

    def evaluate(
        self,
        state_before: State,
        action: Action,
        state_after: State,
        purpose: str,
    ) -> PurposeScore:
        """
        Evaluate a state transition: did the action move closer to the purpose?
        
        Returns a PurposeScore with phi_before, phi_after, delta, reasoning,
        evidence, and confidence.
        
        Anti-hacking measures:
        1. Scores based on ACTUAL state_after (not actor's expected_delta)
        2. Evidence is required — vague scores are rejected
        3. Cached Φ values enforce consistency
        4. Confidence threshold dampens the delta of uncertain evaluations
        """
        # Check cache for consistency (Rule #6)
        cached_before = self._get_cached_phi(state_before)
        cached_after = self._get_cached_phi(state_after)

        # Build evaluation prompt
        messages = [
            ChatMessage(role="system", content=PURPOSE_FUNCTION_SYSTEM_PROMPT),
            ChatMessage(role="user", content=PURPOSE_FUNCTION_EVAL_PROMPT.format(
                purpose=purpose,
                state_before=state_before.describe(),
                state_after=state_after.describe(),
                action_name=action.name,
                action_params=json.dumps(action.params, default=str),
                action_thought=action.thought,
                expected_delta=action.expected_delta,
            )),
        ]

        # Get structured evaluation from LLM
        from purpose_agent.robust_parser import parse_critic_response

        try:
            raw_score = self.llm.generate_structured(
                messages, schema=PURPOSE_SCORE_SCHEMA, temperature=0.2
            )
        except Exception as exc:
            # Structured output not available; fall back to the universal text parser
            logger.debug(f"Structured output failed ({exc!r}); using text fallback")
            raw = self.llm.generate(messages, temperature=0.2, max_tokens=2000)
            raw_score = parse_critic_response(raw)

        # Extract and coerce scores (defensive: both paths should return these
        # keys, but the LLM may have emitted the numbers as strings)
        def _safe_float(value: Any, default: float = 0.0) -> float:
            try:
                return float(str(value).rstrip("."))
            except (ValueError, TypeError):
                return default

        phi_before = _safe_float(raw_score.get("phi_before", 0.0))
        phi_after = _safe_float(raw_score.get("phi_after", 0.0))
        reasoning = str(raw_score.get("reasoning", ""))
        evidence = str(raw_score.get("evidence", ""))
        confidence = _safe_float(raw_score.get("confidence", 0.5))

        # Clamp to valid range
        phi_before = max(0.0, min(10.0, phi_before))
        phi_after = max(0.0, min(10.0, phi_after))
        confidence = max(0.0, min(1.0, confidence))

        # Apply anti-hacking rules
        phi_before, phi_after, confidence = self._apply_safeguards(
            phi_before, phi_after, evidence, confidence,
            cached_before, cached_after,
        )

        delta = phi_after - phi_before

        # Update caches
        self._cache_phi(state_before, phi_before)
        self._cache_phi(state_after, phi_after)
        self._score_history.append(phi_after)

        score = PurposeScore(
            phi_before=phi_before,
            phi_after=phi_after,
            delta=delta,
            reasoning=reasoning,
            evidence=evidence,
            confidence=confidence,
        )

        logger.info(
            f"Purpose Function: Φ({phi_before:.1f}) → Φ({phi_after:.1f}), "
            f"Δ={delta:+.2f}, conf={confidence:.2f}, improved={score.improved}"
        )
        return score

    # ------------------------------------------------------------------
    # Anti-Hacking Safeguards
    # ------------------------------------------------------------------

    def _apply_safeguards(
        self,
        phi_before: float,
        phi_after: float,
        evidence: str,
        confidence: float,
        cached_before: float | None,
        cached_after: float | None,
    ) -> tuple[float, float, float]:
        """
        Apply anti-reward-hacking safeguards.
        
        1. Evidence requirement: no evidence → delta forced to 0
        2. Cache consistency: if we've scored this state before, use cached value
        3. Confidence threshold: low confidence → reduce delta magnitude
        4. Anomaly detection: suspiciously large jumps get confidence penalty
        """
        # Rule 1: Require evidence
        if self.require_evidence and len(evidence.strip()) < 10:
            logger.warning("Purpose Function: Insufficient evidence, forcing delta=0")
            phi_after = phi_before  # No credit without evidence
            confidence = min(confidence, 0.1)  # An unevidenced score is also untrustworthy

        # Rule 2: Cache consistency (allow small drift for scoring noise)
        if cached_before is not None:
            drift = abs(phi_before - cached_before)
            if drift > 1.0:
                logger.warning(
                    f"Purpose Function: Inconsistent Φ_before "
                    f"(new={phi_before:.1f}, cached={cached_before:.1f}), "
                    f"using cached value"
                )
                phi_before = cached_before

        if cached_after is not None:
            drift = abs(phi_after - cached_after)
            if drift > 1.0:
                logger.warning(
                    f"Purpose Function: Inconsistent Φ_after "
                    f"(new={phi_after:.1f}, cached={cached_after:.1f}), "
                    f"using cached value"
                )
                phi_after = cached_after

        # Rule 3: Confidence threshold
        if confidence < self.min_confidence:
            logger.warning(
                f"Purpose Function: Low confidence ({confidence:.2f}), "
                f"reducing delta magnitude by 50%"
            )
            # Pull phi_after halfway back toward phi_before so the delta is
            # actually halved, matching the log message above. (The previous
            # midpoint formula only shrank the delta by 25%.)
            phi_after = phi_before + (phi_after - phi_before) * 0.5

        # Rule 4: Anomaly detection — flag suspiciously large single-step jumps
        delta = phi_after - phi_before
        if abs(delta) > 3.0:
            logger.warning(
                f"Purpose Function: Unusually large delta ({delta:+.1f}), "
                f"applying confidence penalty"
            )
            confidence = min(confidence, 0.5)

        return phi_before, phi_after, confidence
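
    # Worked example for Rule 3 (illustrative numbers): phi_before=3.0,
    # phi_after=5.0, confidence=0.2 with min_confidence=0.3 pulls phi_after
    # back to 4.0, so the reported delta is +1.0 instead of +2.0.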

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _state_hash(self, state: State) -> str:
        """Hash a state for cache lookup (based on data content)."""
        # Hash the canonical JSON rather than using it directly as the key,
        # so large states don't bloat the cache's memory footprint.
        canonical = json.dumps(state.data, sort_keys=True, default=str)
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

    def _get_cached_phi(self, state: State) -> float | None:
        return self._phi_cache.get(self._state_hash(state))

    def _cache_phi(self, state: State, phi: float) -> None:
        key = self._state_hash(state)
        if len(self._phi_cache) >= self._cache_size:
            # Evict oldest (FIFO — good enough for our use case)
            oldest_key = next(iter(self._phi_cache))
            del self._phi_cache[oldest_key]
        self._phi_cache[key] = phi

    # ------------------------------------------------------------------
    # Normalization (prevent score inflation over long trajectories)
    # ------------------------------------------------------------------

    def get_normalized_phi(self, raw_phi: float) -> float:
        """
        Normalize a Φ score relative to the trajectory's score distribution.
        
        Prevents the common failure mode where LLM scores drift upward over
        a trajectory regardless of actual progress.
        """
        if len(self._score_history) < 3:
            return raw_phi

        mean = sum(self._score_history) / len(self._score_history)
        variance = sum((x - mean) ** 2 for x in self._score_history) / len(self._score_history)
        std = max(variance ** 0.5, 0.1)  # Avoid division by zero

        # Z-score normalization mapped back to 0-10
        z = (raw_phi - mean) / std
        normalized = 5.0 + z * 2.0  # Center at 5, spread by 2
        return max(0.0, min(10.0, normalized))
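
    # Worked example (illustrative numbers): with history [4.0, 5.0, 6.0] the
    # mean is 5.0 and the std is ~0.816, so get_normalized_phi(6.0) yields
    # z ≈ 1.22 and ≈ 7.45; a raw 8.5 against a drifted history centered at
    # 8.5 maps back to 5.0, undoing the inflation.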

    def reset_trajectory_stats(self) -> None:
        """Reset per-trajectory normalization stats. Call at trajectory start."""
        self._score_history = []

    # ------------------------------------------------------------------
    # Fallback
    # ------------------------------------------------------------------

    def _fallback_evaluate(self, messages: list[ChatMessage]) -> dict[str, Any]:
        """Text-based fallback when structured output is unavailable."""
        raw = self.llm.generate(messages, temperature=0.2, max_tokens=2000)

        import re

        def safe_float(s, default=0.0):
            """Parse float from string, handling trailing dots and garbage."""
            try:
                return float(s.rstrip('.'))
            except (ValueError, TypeError):
                return default

        phi_before = 0.0
        phi_after = 0.0

        # Try to extract JSON block first (most reliable)
        json_match = re.search(r'\{[^{}]*"phi_before"[^{}]*\}', raw, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
                return parsed
            except (json.JSONDecodeError, ValueError):
                pass

        # Try to extract scores from text
        before_match = re.search(r'[Φφ]\s*\(?state_?before\)?\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)
        after_match = re.search(r'[Φφ]\s*\(?state_?after\)?\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)

        if before_match:
            phi_before = safe_float(before_match.group(1))
        if after_match:
            phi_after = safe_float(after_match.group(1))

        # Also try "Score: X/10" patterns (only if we found Φ markers)
        if not before_match and not after_match:
            score_matches = re.findall(r'(\d+\.?\d*)\s*/\s*10', raw)  # require explicit /10
            if len(score_matches) >= 2:
                phi_before = safe_float(score_matches[0])
                phi_after = safe_float(score_matches[1])
            elif len(score_matches) == 1:
                phi_after = safe_float(score_matches[0])

        # If no scores found, return conservative defaults (don't guess from random numbers)
        # This is honest: if the LLM didn't produce parseable scores, admit uncertainty

        confidence_match = re.search(r'confidence\s*[=:]\s*([\d.]+)', raw, re.IGNORECASE)
        confidence = safe_float(confidence_match.group(1), 0.4) if confidence_match else 0.4

        return {
            "phi_before": phi_before,
            "phi_after": phi_after,
            "reasoning": raw[:500],
            "evidence": raw[500:800] if len(raw) > 500 else "",
            "confidence": confidence,
        }
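

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only)
# ---------------------------------------------------------------------------
# The backend and the State/Action constructor arguments below are
# assumptions about the purpose_agent API, shown only to indicate the call
# pattern:
#
#   llm = SomeLLMBackend()                     # any LLMBackend implementation
#   critic = PurposeFunction(llm, min_confidence=0.3)
#   critic.reset_trajectory_stats()            # at the start of each trajectory
#   score = critic.evaluate(
#       state_before=State(data={"tests_passing": 3}),
#       action=Action(name="fix_bug", params={}, thought="...", expected_delta=1.0),
#       state_after=State(data={"tests_passing": 7}),
#       purpose="Make all 10 unit tests pass",
#   )
#   if score.improved:                         # reward only genuine progress
#       reward = score.delta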