"""Reward functions for both roles, following R-Zero's Algorithm 1.
- Repair Agent (Solver): visible verifier reward (binary-ish with partial credit)
- Drift Generator (Challenger): uncertainty reward + repetition penalty
- Alignment metric: Pearson correlation between visible and held-out scores
"""
from __future__ import annotations
import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from sklearn.cluster import AgglomerativeClustering
from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores # re-export
from forgeenv.verifier.visible_verifier import compute_visible_reward
__all__ = [
"compute_repair_reward",
"compute_uncertainty_reward",
"compute_repetition_penalty",
"compute_drift_gen_reward",
"compute_alignment_score",
"compute_held_out_scores",
"compute_visible_reward",
]
def compute_repair_reward(result: ExecutionResult, task: Task) -> float:
    """Reward for the Repair Agent (Solver): the visible verifier's scalar.

    The verifier also returns diagnostic details, which are discarded here;
    only the scalar score feeds the Solver's training signal.
    """
    score, _details = compute_visible_reward(result, task)
    return score
def compute_uncertainty_reward(success_rates: list[bool]) -> float:
    """R-Zero uncertainty reward (Section 2.2, Eq. 1).

    r_uncertainty = 1 - 2 * |p_hat - 0.5|

    Maximal when the empirical success rate p_hat is exactly 0.5 — the
    point of maximum learning signal — which pushes the Drift Generator to
    propose breakages right at the edge of the Repair Agent's capability.
    An empty outcome list yields 0.0.
    """
    total = len(success_rates)
    if total == 0:
        return 0.0
    p_hat = sum(success_rates) / total
    distance_from_half = abs(p_hat - 0.5)
    return 1.0 - 2.0 * distance_from_half
def compute_repetition_penalty(
    breakage_text: str,
    batch_breakages: list[str],
    threshold: float = 0.5,
) -> float:
    """R-Zero repetition penalty.

    Builds a precomputed 1 - BLEU distance matrix over the batch, runs
    agglomerative clustering on it, and penalizes the target in proportion
    to its cluster's share of the batch — so near-duplicate proposals score
    a larger penalty and diversity is encouraged.
    """
    n = len(batch_breakages)
    if n <= 1:
        return 0.0
    smoothing = SmoothingFunction().method1
    # Tokenize once up front instead of re-splitting inside the pair loop.
    token_lists = [text.split() for text in batch_breakages]
    distance_matrix = np.ones((n, n), dtype=np.float64)
    np.fill_diagonal(distance_matrix, 0.0)
    for i in range(n):
        for j in range(i + 1, n):
            left, right = token_lists[i], token_lists[j]
            if left and right:
                bleu_score = sentence_bleu(
                    [left], right, smoothing_function=smoothing
                )
                separation = 1.0 - bleu_score
            else:
                # An empty side gives no overlap signal: maximal distance.
                separation = 1.0
            distance_matrix[i, j] = separation
            distance_matrix[j, i] = separation
    labels = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        metric="precomputed",
        linkage="average",
    ).fit_predict(distance_matrix)
    try:
        target_idx = batch_breakages.index(breakage_text)
    except ValueError:
        # Target not in the batch: fall back to the first proposal.
        target_idx = 0
    cluster_label = labels[target_idx]
    cluster_size = int(np.count_nonzero(labels == cluster_label))
    return cluster_size / n
def compute_drift_gen_reward(
    breakage_text: str,
    repair_successes: list[bool],
    batch_breakages: list[str],
) -> float:
    """R-Zero composite Challenger reward: max(0, uncertainty - repetition).

    Combines the uncertainty reward over the Repair Agent's outcomes with
    the batch-level repetition penalty, clipped at zero from below.
    """
    raw = compute_uncertainty_reward(repair_successes) - compute_repetition_penalty(
        breakage_text, batch_breakages
    )
    return raw if raw > 0.0 else 0.0
def compute_alignment_score(
    visible_scores: list[float],
    held_out_scores: list[float],
) -> float:
    """Pearson correlation between visible-verifier and held-out scores.

    Computed across rollouts and used to train the Drift Generator toward
    breakages where the visible verifier tracks ground truth. An exploitable
    visible verifier yields low correlation (anti-hacking signal).
    Returns 0.0 for mismatched lengths, fewer than two samples, or
    (near-)constant score series, where correlation is undefined.
    """
    if len(visible_scores) != len(held_out_scores) or len(visible_scores) < 2:
        return 0.0
    visible = np.asarray(visible_scores, dtype=np.float64)
    held_out = np.asarray(held_out_scores, dtype=np.float64)
    # A near-constant series has no well-defined Pearson correlation.
    if min(visible.std(), held_out.std()) < 1e-8:
        return 0.0
    r = float(np.corrcoef(visible, held_out)[0, 1])
    return r if not np.isnan(r) else 0.0