| """Held-out evaluator: the deterministic ground-truth scorer. | |
| Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees | |
| this directly; the Drift Generator's training signal derives from | |
| alignment between the visible verifier and this evaluator (Pearson | |
| correlation across the K rollouts). | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| import re | |
| from forgeenv.tasks.models import ExecutionResult, Task | |
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Score one rollout on the seven independent held-out components.

    Components (keys of the returned dict, in this order):

    * ``executed_cleanly`` -- 1.0 when ``result.exit_code`` is 0.
    * ``checkpoint_valid`` -- 1.0 when ``result.checkpoint_exists`` is truthy.
    * ``loss_decreased`` -- relative loss decrease parsed from stdout.
    * ``metrics_in_range`` -- reported metric vs. the task's expected range.
    * ``no_forbidden_workarounds`` -- AST scan of the repaired script.
    * ``intent_preserved`` -- structural similarity between the task's
      original script and the repaired script.
    * ``hidden_tests_passed`` -- 1.0 when stdout contains the
      ``TRAINING_COMPLETE`` marker.

    Args:
        result: Execution outcome of the repaired script.
        task: The task definition, including the original script.
        repair_diff: Accepted for interface compatibility; not consulted here.

    Returns:
        Mapping from component name to its score.
    """
    log = result.stdout
    repaired_source = result.script_content
    return {
        "executed_cleanly": float(bool(result.exit_code == 0)),
        "checkpoint_valid": float(bool(result.checkpoint_exists)),
        "loss_decreased": _compute_loss_score(log),
        "metrics_in_range": _check_metrics(log, task),
        "no_forbidden_workarounds": _check_workarounds(repaired_source),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, repaired_source
        ),
        "hidden_tests_passed": float(bool("TRAINING_COMPLETE" in log)),
    }
# A "loss" label followed by a decimal number, optionally in scientific
# notation. Requiring a digit keeps a sentence-final "." out of the capture.
_LOSS_PATTERN = re.compile(
    r"loss[=:\s]+((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)",
    re.IGNORECASE,
)


def _compute_loss_score(stdout: str) -> float:
    """Score the relative loss decrease from the first to the last reported loss.

    Each stdout line contributes at most one value: the first ``loss`` label
    (case-insensitive) followed by ``=``, ``:`` or whitespace and a number.
    Values such as ``2.3e-05`` and ``0.25.`` (trailing period) parse
    correctly. Non-finite values (e.g. an overflowing ``1e999``) are ignored.

    Args:
        stdout: Captured standard output of the training run.

    Returns:
        ``(first - last) / first`` clamped to [0, 1]; 0.0 when fewer than two
        losses were reported or the loss did not decrease.
    """
    import math

    losses: list[float] = []
    for line in stdout.splitlines():
        match = _LOSS_PATTERN.search(line)
        if match is None:
            continue
        # The pattern only captures float-parsable text, so no ValueError here.
        value = float(match.group(1))
        if math.isfinite(value):
            losses.append(value)
    if len(losses) < 2:
        return 0.0
    first, last = losses[0], losses[-1]
    # The floor avoids division by zero when the first loss is 0.
    decrease = (first - last) / max(first, 1e-8)
    return max(0.0, min(1.0, decrease))
# An "accuracy"/"acc"/"eval" label followed by a decimal number, optionally
# in scientific notation. Requiring a digit keeps a sentence-final "." out.
_METRIC_PATTERN = re.compile(
    r"(?:accuracy|acc|eval)[=:\s]+((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)",
    re.IGNORECASE,
)


def _check_metrics(stdout: str, task: Task) -> float:
    """Score reported accuracy/eval metrics against the task's expected range.

    Every stdout line contributes at most one value: the first
    ``accuracy``/``acc``/``eval`` label (case-insensitive) followed by ``=``,
    ``:`` or whitespace and a number. Each value scores 1.0 when it lies
    inside ``task.expected_accuracy_range`` (inclusive). Otherwise it scores
    ``1 - distance`` to the nearer bound, floored at 0.0. The best score over
    all reported values is returned, so the run gets 1.0 if *any* reported
    metric is in range.

    Args:
        stdout: Captured standard output of the training run.
        task: Task whose ``expected_accuracy_range`` is a ``(low, high)`` pair.

    Returns:
        The best per-metric score, or 0.5 (neutral) when no metric was found.
    """
    best: float | None = None
    for line in stdout.splitlines():
        match = _METRIC_PATTERN.search(line)
        if match is None:
            continue
        # The pattern only captures float-parsable text, so no ValueError here.
        value = float(match.group(1))
        # Read lazily (as before) so output without metrics never touches it.
        low, high = task.expected_accuracy_range
        if low <= value <= high:
            return 1.0  # Cannot be beaten; stop scanning.
        distance = min(abs(value - low), abs(value - high))
        score = max(0.0, 1.0 - distance)
        if best is None or score > best:
            best = score
    return 0.5 if best is None else best
def _is_forbidden_handler(handler: ast.ExceptHandler) -> bool:
    """Return True if *handler* catches errors in a forbidden, swallowing way."""
    exc_type = handler.type
    if exc_type is None:
        return True  # bare ``except:``
    if not isinstance(exc_type, ast.Name):
        return False
    if exc_type.id == "BaseException":
        return True  # semantically identical to a bare ``except:``
    if exc_type.id != "Exception" or len(handler.body) != 1:
        return False
    stmt = handler.body[0]
    if isinstance(stmt, (ast.Pass, ast.Return)):
        return True
    # A lone ``...`` statement is a no-op, exactly like ``pass``.
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Constant)
        and stmt.value.value is Ellipsis
    )


def _check_workarounds(script_content: str) -> float:
    """Detect forbidden workaround patterns via AST analysis.

    Each occurrence of the following counts as one violation:

    * a bare ``except:`` or an ``except BaseException:`` handler, whatever
      its body;
    * an ``except Exception:`` handler whose body is a single ``pass``,
      ``return`` or ``...`` statement;
    * an assignment to a ``__getattr__``, ``__class__`` or ``__dict__``
      attribute (monkey-patching).

    Handlers of ``try``/``except*`` blocks (``ast.TryStar``, Python 3.11+)
    are inspected too.

    Args:
        script_content: Source code of the repaired script.

    Returns:
        1.0 with no violations, otherwise ``1 - 0.3 * violations`` floored at
        0.0. Returns 0.0 for an empty or unparseable script.
    """
    if not script_content:
        return 0.0
    try:
        tree = ast.parse(script_content)
    except (SyntaxError, ValueError):
        # Python < 3.12 raises ValueError for source containing null bytes.
        return 0.0

    try_types = tuple(
        node_type
        for node_type in (ast.Try, getattr(ast, "TryStar", None))
        if node_type is not None
    )
    patched_attrs = {"__getattr__", "__class__", "__dict__"}
    violations = 0
    for node in ast.walk(tree):
        if isinstance(node, try_types):
            violations += sum(
                1 for handler in node.handlers if _is_forbidden_handler(handler)
            )
        if isinstance(node, ast.Assign):
            violations += sum(
                1
                for target in node.targets
                if isinstance(target, ast.Attribute)
                and target.attr in patched_attrs
            )
    return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)
def _compute_intent_preservation(original: str, repaired: str) -> float:
    """Measure how much of the original AST structure is preserved.

    Both sources are parsed and their nodes tallied by node type (``Assign``,
    ``Call``, ``Name``, ...). The score is the size of the multiset
    intersection of the two tallies divided by the larger node count: the
    fraction of nodes "shared" by type.

    Properties of the score:

    * It is 1.0 for structurally identical programs; renaming identifiers or
      changing literal values does not affect it.
    * It never exceeds ``min(N_orig, N_repair) / max(N_orig, N_repair)``.
    * An unrelated rewrite of the same size therefore no longer scores 1.0.

    Args:
        original: Source of the task's original script.
        repaired: Source of the repaired script.

    Returns:
        A score in [0, 1]; 0.0 when either source is empty or unparseable.
    """
    from collections import Counter

    if not original or not repaired:
        return 0.0
    try:
        orig_tree = ast.parse(original)
        repair_tree = ast.parse(repaired)
    except (SyntaxError, ValueError):
        # Python < 3.12 raises ValueError for source containing null bytes.
        return 0.0
    orig_counts = Counter(type(node) for node in ast.walk(orig_tree))
    repair_counts = Counter(type(node) for node in ast.walk(repair_tree))
    # Counter "&" keeps min(count_a, count_b) per node type.
    shared = sum((orig_counts & repair_counts).values())
    # Each tree contains at least its root Module node, so this is >= 1.
    largest = max(sum(orig_counts.values()), sum(repair_counts.values()))
    return shared / largest