# forgeenv-source / forgeenv-space / forgeenv/verifier/held_out_evaluator.py
# Author: akhiilll
# Commit message: forgeenv source snapshot for training job
# Commit: b0fbec3 (verified)
"""Held-out evaluator: the deterministic ground-truth scorer.
Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees
this directly; the Drift Generator's training signal derives from
alignment between the visible verifier and this evaluator (Pearson
correlation across the K rollouts).
"""
from __future__ import annotations
import ast
import re
from forgeenv.tasks.models import ExecutionResult, Task
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Score a single rollout on the seven held-out components.

    Args:
        result: Outcome of running the repaired script; its ``exit_code``,
            ``checkpoint_exists``, ``stdout`` and ``script_content`` are read.
        task: The task being repaired. Its ``script_content`` (the original
            script) is compared with the repaired one, and the task itself is
            passed to the metric-range check.
        repair_diff: Accepted for interface compatibility; not read here.

    Returns:
        Mapping from component name to score, always with the same seven keys
        in the same order.
    """
    stdout = result.stdout
    repaired_script = result.script_content

    def as_flag(condition: object) -> float:
        # Binary components: 1.0 when the condition is truthy, else 0.0.
        return 1.0 if condition else 0.0

    return {
        "executed_cleanly": as_flag(result.exit_code == 0),
        "checkpoint_valid": as_flag(result.checkpoint_exists),
        "loss_decreased": _compute_loss_score(stdout),
        "metrics_in_range": _check_metrics(stdout, task),
        "no_forbidden_workarounds": _check_workarounds(repaired_script),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, repaired_script
        ),
        # Presumably the training script prints this sentinel when it finishes
        # successfully — TODO confirm against the task templates.
        "hidden_tests_passed": as_flag("TRAINING_COMPLETE" in stdout),
    }
# Matches a logged loss value such as "loss=0.52", "Loss: 1e-3", "loss -0.2" or
# "{'loss': 0.25}". The captured number may carry a sign and an exponent, and a
# trailing sentence period ("loss: 0.5.") is left out of the capture.
_LOSS_RE = re.compile(
    r"loss['\"]?[=:\s]+([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)",
    re.IGNORECASE,
)


def _compute_loss_score(stdout: str) -> float:
    """Score the relative decrease of the logged loss from first to last value.

    Each stdout line contributes at most one value: the first ``loss`` match
    on that line. Any token ending in ``loss`` (e.g. ``val_loss``) also
    matches. The score is ``(first - last) / max(|first|, 1e-8)`` clamped to
    [0, 1], so an increasing loss scores 0.0.

    Returns:
        The clamped relative decrease, or 0.0 when fewer than two loss values
        were found.
    """
    losses: list[float] = []
    for line in stdout.splitlines():
        match = _LOSS_RE.search(line)
        if match:
            # The pattern only captures well-formed numbers, so float() cannot fail.
            losses.append(float(match.group(1)))
    if len(losses) < 2:
        return 0.0
    first, last = losses[0], losses[-1]
    # abs() keeps the relative change meaningful for negative losses; the 1e-8
    # floor avoids dividing by zero when the first loss is 0.
    decrease = (first - last) / max(abs(first), 1e-8)
    return max(0.0, min(1.0, decrease))
# Matches reported metrics such as "accuracy: 0.91", "acc=0.9", "eval 0.88" or
# "{'eval_accuracy': 0.9}". The captured number may carry an exponent, and a
# trailing sentence period ("accuracy: 0.9.") is left out of the capture.
_METRIC_RE = re.compile(
    r"(?:accuracy|acc|eval)['\"]?[=:\s]+((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)",
    re.IGNORECASE,
)


def _check_metrics(stdout: str, task: Task) -> float:
    """Score reported accuracy/eval metrics against the task's expected range.

    Every ``accuracy``/``acc``/``eval`` value in stdout is considered, not just
    the first. One stray or early-epoch value therefore cannot decide the
    score.

    Returns:
        1.0 if any value lies within the inclusive ``task.expected_accuracy_range``.
        Otherwise, the best partial credit ``1 - distance`` to the nearest
        bound, floored at 0.0. Returns 0.5 when no metric is reported at all.
    """
    values = [
        float(match.group(1))
        for line in stdout.splitlines()
        for match in _METRIC_RE.finditer(line)
    ]
    if not values:
        return 0.5
    # Read the range only once a metric exists, as the original code did.
    low, high = task.expected_accuracy_range
    best = 0.0
    for value in values:
        if low <= value <= high:
            return 1.0
        distance = min(abs(value - low), abs(value - high))
        best = max(best, 1.0 - distance)
    return best
# Attribute names whose reassignment counts as monkey-patching an object.
_PATCHED_ATTRS = frozenset({"__getattr__", "__class__", "__dict__"})
# Exception classes broad enough that silently swallowing them hides failures.
_BROAD_EXCEPTIONS = frozenset({"Exception", "BaseException"})


def _catches_broadly(exc_type: ast.expr | None) -> bool:
    """Return True if an ``except`` clause type names Exception/BaseException."""
    if isinstance(exc_type, ast.Name):
        return exc_type.id in _BROAD_EXCEPTIONS
    if isinstance(exc_type, ast.Tuple):
        return any(_catches_broadly(elt) for elt in exc_type.elts)
    return False


def _is_silent_body(body: list[ast.stmt]) -> bool:
    """Return True if a handler body is a single ``pass``, ``return`` or ``...``."""
    if len(body) != 1:
        return False
    stmt = body[0]
    if isinstance(stmt, (ast.Pass, ast.Return)):
        return True
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Constant)
        and stmt.value.value is Ellipsis
    )


def _assignment_targets(node: ast.AST) -> list[ast.expr]:
    """Return the leaf targets written by an assignment node, unpacking tuples/lists."""
    if isinstance(node, ast.Assign):
        pending = list(node.targets)
    elif isinstance(node, (ast.AugAssign, ast.AnnAssign)):
        pending = [node.target]
    else:
        return []
    leaves: list[ast.expr] = []
    while pending:
        target = pending.pop()
        if isinstance(target, (ast.Tuple, ast.List)):
            pending.extend(target.elts)
        elif isinstance(target, ast.Starred):
            pending.append(target.value)
        else:
            leaves.append(target)
    return leaves


def _is_patching_setattr(node: ast.AST) -> bool:
    """Return True for ``setattr(obj, "<patched name>", value)`` with a literal name."""
    if not (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "setattr"
        and len(node.args) >= 2
    ):
        return False
    name = node.args[1]
    return (
        isinstance(name, ast.Constant)
        and isinstance(name.value, str)
        and name.value in _PATCHED_ATTRS
    )


def _check_workarounds(script_content: str) -> float:
    """Detect forbidden workaround patterns in a script via AST analysis.

    Each of the following counts as one violation:

    * a bare ``except:`` clause (any body);
    * an ``except Exception`` / ``except BaseException`` clause, alone or inside
      a tuple, whose body is a single ``pass``, ``return`` or ``...``
      (``except*`` clauses included);
    * a plain, augmented, annotated or tuple-unpacking assignment to an
      ``__getattr__`` / ``__class__`` / ``__dict__`` attribute;
    * ``setattr(obj, name, value)`` where ``name`` is one of those attribute
      names given as a string literal.

    Returns:
        1.0 for a clean script, minus 0.3 per violation, floored at 0.0.
        Returns 0.0 for an empty or unparseable script.
    """
    if not script_content:
        return 0.0
    try:
        tree = ast.parse(script_content)
    except (SyntaxError, ValueError, RecursionError):
        # ValueError: null bytes in the source (Python < 3.12);
        # RecursionError: pathologically nested code.
        return 0.0

    violations = 0
    for node in ast.walk(tree):
        # ExceptHandler covers handlers of both ``try`` and ``try``/``except*``.
        if isinstance(node, ast.ExceptHandler):
            if node.type is None:
                violations += 1
            elif _catches_broadly(node.type) and _is_silent_body(node.body):
                violations += 1
        for target in _assignment_targets(node):
            if isinstance(target, ast.Attribute) and target.attr in _PATCHED_ATTRS:
                violations += 1
        if _is_patching_setattr(node):
            violations += 1
    return max(0.0, 1.0 - 0.3 * violations)
def _compute_intent_preservation(original: str, repaired: str) -> float:
    """Score how close the repaired script's AST size is to the original's.

    Returns ``min(N_orig, N_repair) / max(N_orig, N_repair)``, where ``N`` is
    the number of nodes yielded by ``ast.walk`` over each parsed script. This
    is a size ratio only: node types and structure are not compared, so two
    unrelated scripts of similar size also score close to 1.0.

    Returns:
        The size ratio in (0, 1], or 0.0 if either script is empty or fails to
        parse.
    """
    if not original or not repaired:
        return 0.0
    try:
        orig_tree = ast.parse(original)
        repair_tree = ast.parse(repaired)
    except (SyntaxError, ValueError, RecursionError):
        # ValueError: null bytes in the source (Python < 3.12);
        # RecursionError: pathologically nested code.
        return 0.0
    orig_nodes = sum(1 for _ in ast.walk(orig_tree))
    repair_nodes = sum(1 for _ in ast.walk(repair_tree))
    # ast.walk always yields at least the Module node, so the max is >= 1.
    return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)