| """Held-out evaluator: the deterministic ground-truth scorer. |
| |
| Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees |
| this directly; the Drift Generator's training signal derives from |
| alignment between the visible verifier and this evaluator (Pearson |
| correlation across the K rollouts). |
| """ |
| from __future__ import annotations |
|
|
| import ast |
| import re |
|
|
| from forgeenv.tasks.models import ExecutionResult, Task |
|
|
|
|
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Score one repair rollout on the 7 held-out components.

    Args:
        result: Outcome of running the repaired script; its ``exit_code``,
            ``checkpoint_exists``, ``stdout`` and ``script_content`` are read.
        task: The task under repair. Its ``script_content`` is the baseline
            for intent preservation, and it is forwarded to the metric check.
        repair_diff: Not read by this function; kept for interface
            compatibility with existing callers.

    Returns:
        Mapping of component name to score. Keys are produced in a fixed
        order: executed_cleanly, checkpoint_valid, loss_decreased,
        metrics_in_range, no_forbidden_workarounds, intent_preserved,
        hidden_tests_passed.
    """

    def as_score(flag: object) -> float:
        # Truthy condition -> 1.0, falsy -> 0.0.
        return 1.0 if flag else 0.0

    log_text = result.stdout
    repaired_script = result.script_content

    components: dict[str, float] = {}
    components["executed_cleanly"] = as_score(result.exit_code == 0)
    components["checkpoint_valid"] = as_score(result.checkpoint_exists)
    components["loss_decreased"] = _compute_loss_score(log_text)
    components["metrics_in_range"] = _check_metrics(log_text, task)
    components["no_forbidden_workarounds"] = _check_workarounds(repaired_script)
    components["intent_preserved"] = _compute_intent_preservation(
        task.script_content, repaired_script
    )
    # Hidden-test proxy: the run must have printed the TRAINING_COMPLETE marker.
    components["hidden_tests_passed"] = as_score("TRAINING_COMPLETE" in log_text)
    return components
|
|
|
|
| def _compute_loss_score(stdout: str) -> float: |
| """Continuous score based on relative loss decrease from first to last step.""" |
|
|
| losses: list[float] = [] |
| for line in stdout.splitlines(): |
| match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) |
| if match: |
| try: |
| losses.append(float(match.group(1))) |
| except ValueError: |
| continue |
|
|
| if len(losses) < 2: |
| return 0.0 |
|
|
| decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8) |
| return max(0.0, min(1.0, decrease)) |
|
|
|
|
| def _check_metrics(stdout: str, task: Task) -> float: |
| """Return 1.0 if any reported accuracy/eval metric falls in the task's |
| expected range; partial credit otherwise; 0.5 if no metric was found.""" |
|
|
| for line in stdout.splitlines(): |
| match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE) |
| if match: |
| try: |
| val = float(match.group(1)) |
| low, high = task.expected_accuracy_range |
| if low <= val <= high: |
| return 1.0 |
| distance = min(abs(val - low), abs(val - high)) |
| return max(0.0, 1.0 - distance) |
| except ValueError: |
| continue |
| return 0.5 |
|
|
|
|
| def _check_workarounds(script_content: str) -> float: |
| """Detect forbidden workaround patterns via AST analysis. |
| |
| Catches: bare except, `except Exception: pass`, `except Exception: return`, |
| monkey-patching of `__getattr__` / `__class__` / `__dict__`. |
| """ |
|
|
| if not script_content: |
| return 0.0 |
|
|
| try: |
| tree = ast.parse(script_content) |
| except SyntaxError: |
| return 0.0 |
|
|
| violations = 0 |
|
|
| for node in ast.walk(tree): |
| if isinstance(node, ast.Try): |
| for handler in node.handlers: |
| if handler.type is None: |
| violations += 1 |
| elif ( |
| isinstance(handler.type, ast.Name) |
| and handler.type.id == "Exception" |
| ): |
| if len(handler.body) == 1 and isinstance( |
| handler.body[0], (ast.Pass, ast.Return) |
| ): |
| violations += 1 |
|
|
| if isinstance(node, ast.Assign): |
| for target in node.targets: |
| if isinstance(target, ast.Attribute): |
| if target.attr in ("__getattr__", "__class__", "__dict__"): |
| violations += 1 |
|
|
| return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3) |
|
|
|
|
| def _compute_intent_preservation(original: str, repaired: str) -> float: |
| """Measure how much of the original AST structure is preserved. |
| |
| Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...). |
| """ |
|
|
| if not original or not repaired: |
| return 0.0 |
|
|
| try: |
| orig_tree = ast.parse(original) |
| repair_tree = ast.parse(repaired) |
| except SyntaxError: |
| return 0.0 |
|
|
| orig_nodes = len(list(ast.walk(orig_tree))) |
| repair_nodes = len(list(ast.walk(repair_tree))) |
|
|
| if orig_nodes == 0: |
| return 0.0 |
|
|
| return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes) |
|
|