File size: 4,296 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Reward functions for both roles, following R-Zero's Algorithm 1.

- Repair Agent (Solver): visible verifier reward (binary-ish with partial credit)
- Drift Generator (Challenger): uncertainty reward + repetition penalty
- Alignment metric: Pearson correlation between visible and held-out scores
"""
from __future__ import annotations

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from sklearn.cluster import AgglomerativeClustering

from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores  # re-export
from forgeenv.verifier.visible_verifier import compute_visible_reward

__all__ = [
    "compute_repair_reward",
    "compute_uncertainty_reward",
    "compute_repetition_penalty",
    "compute_drift_gen_reward",
    "compute_alignment_score",
    "compute_held_out_scores",
    "compute_visible_reward",
]


def compute_repair_reward(result: ExecutionResult, task: Task) -> float:
    """Reward for the Repair Agent (Solver): the visible-verifier scalar.

    The visible verifier returns ``(reward, details)``; only the scalar is
    propagated to the policy update.
    """
    score, _details = compute_visible_reward(result, task)
    return score


def compute_uncertainty_reward(success_rates: list[bool]) -> float:
    """R-Zero uncertainty reward (Section 2.2, Eq. 1).

    r_uncertainty = 1 - 2 * |p_hat - 0.5|

    where ``p_hat`` is the empirical success rate over the rollouts. The
    reward peaks at ``p_hat == 0.5`` (maximum learning signal), pushing the
    Drift Generator toward breakages at the edge of the Repair Agent's
    capability. Returns 0.0 for an empty rollout list.
    """
    total = len(success_rates)
    if total == 0:
        return 0.0
    empirical_rate = sum(success_rates) / total
    return 1.0 - 2.0 * abs(empirical_rate - 0.5)


def compute_repetition_penalty(
    breakage_text: str,
    batch_breakages: list[str],
    threshold: float = 0.5,
) -> float:
    """R-Zero repetition penalty.

    Builds a pairwise distance matrix over the batch using 1 - BLEU,
    clusters it with average-linkage agglomerative clustering, and returns
    the fraction of the batch sharing the target's cluster — so near-duplicate
    proposals are penalized in proportion to how common they are.

    If ``breakage_text`` is not found in the batch, the first entry's cluster
    is used as a best-effort fallback.
    """
    n = len(batch_breakages)
    if n <= 1:
        return 0.0

    smooth = SmoothingFunction().method1
    token_lists = [text.split() for text in batch_breakages]

    # Symmetric distance matrix: 0 on the diagonal, 1 - BLEU elsewhere
    # (maximal distance 1.0 when either side tokenizes to nothing).
    dist = np.ones((n, n), dtype=np.float64)
    np.fill_diagonal(dist, 0.0)
    for a in range(n):
        for b in range(a + 1, n):
            left, right = token_lists[a], token_lists[b]
            if left and right:
                d = 1.0 - sentence_bleu([left], right, smoothing_function=smooth)
            else:
                d = 1.0
            dist[a, b] = d
            dist[b, a] = d

    labels = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        metric="precomputed",
        linkage="average",
    ).fit_predict(dist)

    try:
        anchor = batch_breakages.index(breakage_text)
    except ValueError:
        anchor = 0
    members = int(np.count_nonzero(labels == labels[anchor]))
    return members / n


def compute_drift_gen_reward(
    breakage_text: str,
    repair_successes: list[bool],
    batch_breakages: list[str],
) -> float:
    """R-Zero composite Challenger reward: ``max(0, uncertainty - repetition)``.

    Combines the uncertainty reward over the Repair Agent's rollouts with the
    batch-diversity repetition penalty, clamped at zero from below.
    """
    gap = compute_uncertainty_reward(repair_successes) - compute_repetition_penalty(
        breakage_text, batch_breakages
    )
    return gap if gap > 0.0 else 0.0


def compute_alignment_score(
    visible_scores: list[float],
    held_out_scores: list[float],
) -> float:
    """Pearson correlation between visible-verifier and held-out scores.

    Used to train the Drift Generator to propose breakages where the visible
    verifier tracks ground truth: an exploitable visible verifier yields low
    correlation (anti-hacking signal). Returns 0.0 for mismatched or
    too-short inputs, near-constant score vectors, or a NaN correlation.
    """
    n = len(visible_scores)
    if n < 2 or n != len(held_out_scores):
        return 0.0

    visible = np.asarray(visible_scores, dtype=np.float64)
    held_out = np.asarray(held_out_scores, dtype=np.float64)

    # Pearson is undefined for (near-)constant vectors; treat as no signal.
    if min(visible.std(), held_out.std()) < 1e-8:
        return 0.0

    r = float(np.corrcoef(visible, held_out)[0, 1])
    return 0.0 if np.isnan(r) else r