"""Held-out evaluator: the deterministic ground-truth scorer. Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees this directly; the Drift Generator's training signal derives from alignment between the visible verifier and this evaluator (Pearson correlation across the K rollouts). """ from __future__ import annotations import ast import re from forgeenv.tasks.models import ExecutionResult, Task def compute_held_out_scores( result: ExecutionResult, task: Task, repair_diff: str = "" ) -> dict[str, float]: """Compute 7 independent held-out components.""" scores: dict[str, float] = { "executed_cleanly": 1.0 if result.exit_code == 0 else 0.0, "checkpoint_valid": 1.0 if result.checkpoint_exists else 0.0, "loss_decreased": _compute_loss_score(result.stdout), "metrics_in_range": _check_metrics(result.stdout, task), "no_forbidden_workarounds": _check_workarounds(result.script_content), "intent_preserved": _compute_intent_preservation( task.script_content, result.script_content ), "hidden_tests_passed": 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0, } return scores def _compute_loss_score(stdout: str) -> float: """Continuous score based on relative loss decrease from first to last step.""" losses: list[float] = [] for line in stdout.splitlines(): match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) if match: try: losses.append(float(match.group(1))) except ValueError: continue if len(losses) < 2: return 0.0 decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8) return max(0.0, min(1.0, decrease)) def _check_metrics(stdout: str, task: Task) -> float: """Return 1.0 if any reported accuracy/eval metric falls in the task's expected range; partial credit otherwise; 0.5 if no metric was found.""" for line in stdout.splitlines(): match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE) if match: try: val = float(match.group(1)) low, high = task.expected_accuracy_range if low <= val <= high: return 1.0 distance = min(abs(val - low), abs(val - high)) return max(0.0, 1.0 - distance) except ValueError: continue return 0.5 def _check_workarounds(script_content: str) -> float: """Detect forbidden workaround patterns via AST analysis. Catches: bare except, `except Exception: pass`, `except Exception: return`, monkey-patching of `__getattr__` / `__class__` / `__dict__`. """ if not script_content: return 0.0 try: tree = ast.parse(script_content) except SyntaxError: return 0.0 violations = 0 for node in ast.walk(tree): if isinstance(node, ast.Try): for handler in node.handlers: if handler.type is None: violations += 1 elif ( isinstance(handler.type, ast.Name) and handler.type.id == "Exception" ): if len(handler.body) == 1 and isinstance( handler.body[0], (ast.Pass, ast.Return) ): violations += 1 if isinstance(node, ast.Assign): for target in node.targets: if isinstance(target, ast.Attribute): if target.attr in ("__getattr__", "__class__", "__dict__"): violations += 1 return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3) def _compute_intent_preservation(original: str, repaired: str) -> float: """Measure how much of the original AST structure is preserved. Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...). """ if not original or not repaired: return 0.0 try: orig_tree = ast.parse(original) repair_tree = ast.parse(repaired) except SyntaxError: return 0.0 orig_nodes = len(list(ast.walk(orig_tree))) repair_nodes = len(list(ast.walk(repair_tree))) if orig_nodes == 0: return 0.0 return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)