| """Reward functions for both roles, following R-Zero's Algorithm 1. |
| |
| - Repair Agent (Solver): visible verifier reward (binary-ish with partial credit) |
| - Drift Generator (Challenger): uncertainty reward + repetition penalty |
| - Alignment metric: Pearson correlation between visible and held-out scores |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu |
| from sklearn.cluster import AgglomerativeClustering |
|
|
| from forgeenv.tasks.models import ExecutionResult, Task |
| from forgeenv.verifier.held_out_evaluator import compute_held_out_scores |
| from forgeenv.verifier.visible_verifier import compute_visible_reward |
|
|
# Public API of this module: the two role rewards, their building blocks, and
# re-exported verifier helpers so callers can import every reward-related
# function from this single module.
__all__ = [
    "compute_repair_reward",
    "compute_uncertainty_reward",
    "compute_repetition_penalty",
    "compute_drift_gen_reward",
    "compute_alignment_score",
    "compute_held_out_scores",
    "compute_visible_reward",
]
|
|
|
|
def compute_repair_reward(result: ExecutionResult, task: Task) -> float:
    """Reward for the Repair Agent (Solver role).

    Delegates to the visible verifier and discards its auxiliary payload,
    returning only the scalar reward.
    """
    scalar, _details = compute_visible_reward(result, task)
    return scalar
|
|
|
|
def compute_uncertainty_reward(success_rates: list[bool]) -> float:
    """R-Zero uncertainty reward (Section 2.2, Eq. 1).

    r_uncertainty = 1 - 2 * |p_hat - 0.5|

    The reward peaks when the empirical success rate sits at 0.5 — the
    point of maximum learning signal — pushing the Drift Generator toward
    breakages at the edge of the Repair Agent's capability.

    Returns 0.0 for an empty input (no rollouts observed).
    """
    if not success_rates:
        return 0.0
    empirical_rate = sum(success_rates) / len(success_rates)
    deviation_from_half = abs(empirical_rate - 0.5)
    return 1.0 - 2.0 * deviation_from_half
|
|
|
|
def compute_repetition_penalty(
    breakage_text: str,
    batch_breakages: list[str],
    threshold: float = 0.5,
) -> float:
    """R-Zero repetition penalty.

    Clusters the batch of breakage proposals by 1 - BLEU distance using
    average-linkage agglomerative clustering, then penalizes the target
    proportionally to the size of its cluster (encouraging diverse
    proposals).

    Args:
        breakage_text: The proposal being scored.
        batch_breakages: All proposals in the current batch.
        threshold: Distance threshold at which the cluster tree is cut.

    Returns:
        Fraction of the batch that falls in the target's cluster, or 0.0
        when the batch is trivial (<= 1 entry) or the target is not part
        of the batch.
    """
    if len(batch_breakages) <= 1:
        return 0.0

    # Bug fix: previously an unseen ``breakage_text`` silently fell back to
    # index 0 and was penalized based on an unrelated sample. A proposal
    # that is not in the batch now receives no repetition penalty.
    try:
        target_idx = batch_breakages.index(breakage_text)
    except ValueError:
        return 0.0

    smoother = SmoothingFunction().method1
    n = len(batch_breakages)
    # Hoist tokenization out of the pairwise loop (was re-split per pair).
    token_lists = [text.split() for text in batch_breakages]
    distances = np.ones((n, n), dtype=np.float64)

    for i in range(n):
        distances[i][i] = 0.0
        for j in range(i + 1, n):
            tokens_i = token_lists[i]
            tokens_j = token_lists[j]
            if tokens_i and tokens_j:
                bleu = sentence_bleu(
                    [tokens_i], tokens_j, smoothing_function=smoother
                )
                dist = 1.0 - bleu
            else:
                # An empty side has no n-gram overlap: maximally dissimilar.
                dist = 1.0
            distances[i][j] = dist
            distances[j][i] = dist

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        metric="precomputed",
        linkage="average",
    )
    labels = clustering.fit_predict(distances)

    target_cluster = labels[target_idx]
    cluster_size = int(sum(1 for label in labels if label == target_cluster))
    return cluster_size / len(batch_breakages)
|
|
|
|
def compute_drift_gen_reward(
    breakage_text: str,
    repair_successes: list[bool],
    batch_breakages: list[str],
) -> float:
    """R-Zero composite Challenger reward: max(0, uncertainty - repetition).

    Combines the learning-signal term (uncertainty over repair outcomes)
    with the diversity term (repetition penalty within the batch), clamped
    at zero so the reward is never negative.
    """
    informativeness = compute_uncertainty_reward(repair_successes)
    redundancy = compute_repetition_penalty(breakage_text, batch_breakages)
    combined = informativeness - redundancy
    return combined if combined > 0.0 else 0.0
|
|
|
|
def compute_alignment_score(
    visible_scores: list[float],
    held_out_scores: list[float],
) -> float:
    """Pearson correlation between visible-verifier and held-out scores.

    Used to train the Drift Generator toward breakages where the visible
    verifier tracks ground truth — an exploitable visible verifier yields
    low correlation (anti-hacking signal).

    Returns 0.0 for mismatched lengths, fewer than two samples, a
    (near-)constant series, or a NaN correlation.
    """
    count = len(visible_scores)
    if count < 2 or count != len(held_out_scores):
        return 0.0

    visible = np.asarray(visible_scores, dtype=np.float64)
    held_out = np.asarray(held_out_scores, dtype=np.float64)

    # A constant series has undefined Pearson correlation; bail out early.
    if min(visible.std(), held_out.std()) < 1e-8:
        return 0.0

    r = np.corrcoef(visible, held_out)[0, 1]
    return float(r) if not np.isnan(r) else 0.0
|
|