File size: 4,553 Bytes
b0fbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Held-out evaluator: the deterministic ground-truth scorer.



Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees

this directly; the Drift Generator's training signal derives from

alignment between the visible verifier and this evaluator (Pearson

correlation across the K rollouts).

"""
from __future__ import annotations

import ast
import re

from forgeenv.tasks.models import ExecutionResult, Task


def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Assemble the 7 independent held-out components, each in [0, 1].

    Args:
        result: Execution outcome (exit code, stdout, checkpoint flag,
            final script content) of the repaired run.
        task: The task definition the run was scored against.
        repair_diff: Accepted for interface compatibility; not used here.

    Returns:
        Mapping from component name to its score in [0, 1].
    """
    ran_cleanly = 1.0 if result.exit_code == 0 else 0.0
    checkpoint_ok = 1.0 if result.checkpoint_exists else 0.0
    # The completion marker doubles as the "hidden tests" signal.
    completed = 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0

    return {
        "executed_cleanly": ran_cleanly,
        "checkpoint_valid": checkpoint_ok,
        "loss_decreased": _compute_loss_score(result.stdout),
        "metrics_in_range": _check_metrics(result.stdout, task),
        "no_forbidden_workarounds": _check_workarounds(result.script_content),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, result.script_content
        ),
        "hidden_tests_passed": completed,
    }


def _compute_loss_score(stdout: str) -> float:
    """Continuous score based on relative loss decrease from first to last step."""

    losses: list[float] = []
    for line in stdout.splitlines():
        match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                losses.append(float(match.group(1)))
            except ValueError:
                continue

    if len(losses) < 2:
        return 0.0

    decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8)
    return max(0.0, min(1.0, decrease))


def _check_metrics(stdout: str, task: Task) -> float:
    """Return 1.0 if any reported accuracy/eval metric falls in the task's

    expected range; partial credit otherwise; 0.5 if no metric was found."""

    for line in stdout.splitlines():
        match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                val = float(match.group(1))
                low, high = task.expected_accuracy_range
                if low <= val <= high:
                    return 1.0
                distance = min(abs(val - low), abs(val - high))
                return max(0.0, 1.0 - distance)
            except ValueError:
                continue
    return 0.5


def _check_workarounds(script_content: str) -> float:
    """Detect forbidden workaround patterns via AST analysis.



    Catches: bare except, `except Exception: pass`, `except Exception: return`,

    monkey-patching of `__getattr__` / `__class__` / `__dict__`.

    """

    if not script_content:
        return 0.0

    try:
        tree = ast.parse(script_content)
    except SyntaxError:
        return 0.0

    violations = 0

    for node in ast.walk(tree):
        if isinstance(node, ast.Try):
            for handler in node.handlers:
                if handler.type is None:
                    violations += 1
                elif (
                    isinstance(handler.type, ast.Name)
                    and handler.type.id == "Exception"
                ):
                    if len(handler.body) == 1 and isinstance(
                        handler.body[0], (ast.Pass, ast.Return)
                    ):
                        violations += 1

        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Attribute):
                    if target.attr in ("__getattr__", "__class__", "__dict__"):
                        violations += 1

    return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)


def _compute_intent_preservation(original: str, repaired: str) -> float:
    """Measure how much of the original AST structure is preserved.



    Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...).

    """

    if not original or not repaired:
        return 0.0

    try:
        orig_tree = ast.parse(original)
        repair_tree = ast.parse(repaired)
    except SyntaxError:
        return 0.0

    orig_nodes = len(list(ast.walk(orig_tree)))
    repair_nodes = len(list(ast.walk(repair_tree)))

    if orig_nodes == 0:
        return 0.0

    return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)