"""Held-out evaluator: the deterministic ground-truth scorer.
Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees
this directly; the Drift Generator's training signal derives from
alignment between the visible verifier and this evaluator (Pearson
correlation across the K rollouts).
"""
from __future__ import annotations
import ast
import re
from forgeenv.tasks.models import ExecutionResult, Task
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Compute the 7 independent held-out components, each in [0, 1].

    Delegates each axis to a dedicated helper so the components stay
    independent; ``repair_diff`` is accepted for interface compatibility
    but not currently used by any component.
    """
    ran_cleanly = 1.0 if result.exit_code == 0 else 0.0
    checkpoint_ok = 1.0 if result.checkpoint_exists else 0.0
    # The completion sentinel doubles as the "hidden tests" signal.
    completed = 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0
    return {
        "executed_cleanly": ran_cleanly,
        "checkpoint_valid": checkpoint_ok,
        "loss_decreased": _compute_loss_score(result.stdout),
        "metrics_in_range": _check_metrics(result.stdout, task),
        "no_forbidden_workarounds": _check_workarounds(result.script_content),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, result.script_content
        ),
        "hidden_tests_passed": completed,
    }
def _compute_loss_score(stdout: str) -> float:
"""Continuous score based on relative loss decrease from first to last step."""
losses: list[float] = []
for line in stdout.splitlines():
match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
if match:
try:
losses.append(float(match.group(1)))
except ValueError:
continue
if len(losses) < 2:
return 0.0
decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8)
return max(0.0, min(1.0, decrease))
def _check_metrics(stdout: str, task: Task) -> float:
"""Return 1.0 if any reported accuracy/eval metric falls in the task's
expected range; partial credit otherwise; 0.5 if no metric was found."""
for line in stdout.splitlines():
match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE)
if match:
try:
val = float(match.group(1))
low, high = task.expected_accuracy_range
if low <= val <= high:
return 1.0
distance = min(abs(val - low), abs(val - high))
return max(0.0, 1.0 - distance)
except ValueError:
continue
return 0.5
def _check_workarounds(script_content: str) -> float:
"""Detect forbidden workaround patterns via AST analysis.
Catches: bare except, `except Exception: pass`, `except Exception: return`,
monkey-patching of `__getattr__` / `__class__` / `__dict__`.
"""
if not script_content:
return 0.0
try:
tree = ast.parse(script_content)
except SyntaxError:
return 0.0
violations = 0
for node in ast.walk(tree):
if isinstance(node, ast.Try):
for handler in node.handlers:
if handler.type is None:
violations += 1
elif (
isinstance(handler.type, ast.Name)
and handler.type.id == "Exception"
):
if len(handler.body) == 1 and isinstance(
handler.body[0], (ast.Pass, ast.Return)
):
violations += 1
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Attribute):
if target.attr in ("__getattr__", "__class__", "__dict__"):
violations += 1
return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)
def _compute_intent_preservation(original: str, repaired: str) -> float:
"""Measure how much of the original AST structure is preserved.
Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...).
"""
if not original or not repaired:
return 0.0
try:
orig_tree = ast.parse(original)
repair_tree = ast.parse(repaired)
except SyntaxError:
return 0.0
orig_nodes = len(list(ast.walk(orig_tree)))
repair_nodes = len(list(ast.walk(repair_tree)))
if orig_nodes == 0:
return 0.0
return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)