File size: 2,238 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Visible verifier: the immediate reward signal the Repair Agent sees.

Four weighted components are summed into a single scalar, which drives the
Repair Agent's GRPO update on every rollout. Using several independent
components was a deliberate choice, following the reward-engineering survey
(arxiv 2408.10215) and the software-tasks survey (arxiv 2601.19100): a reward
built from a single criterion is far easier to game than a composable rubric.
"""
from __future__ import annotations

import re

from forgeenv.tasks.models import ExecutionResult, Task

# Multiplier for each reward component; compute_visible_reward returns the sum
# of component value * weight. script_executes carries the largest weight.
WEIGHTS: dict[str, float] = dict(
    script_executes=1.0,
    loss_decreased=0.5,
    checkpoint_appeared=0.3,
    # The diff_size_penalty component is always <= 0, so this positive weight
    # scales a penalty rather than awarding credit.
    diff_size_penalty=0.2,
)


def compute_visible_reward(
    result: ExecutionResult, task: Task
) -> tuple[float, dict[str, float]]:
    """Compute the scalar visible reward and its per-component breakdown.

    Args:
        result: Outcome of running the agent's script. This function reads
            ``exit_code``, ``stdout``, ``checkpoint_exists`` and
            ``script_content`` from it.
        task: The task being repaired; ``task.script_content`` is the
            baseline script the diff-size penalty is measured against.

    Returns:
        ``(total, components)``. ``components`` maps every key of ``WEIGHTS``
        to its unweighted value:

        * ``script_executes``: 1.0 if ``exit_code == 0`` else 0.0.
        * ``loss_decreased``: fraction of consecutive logged losses that
          strictly decreased (0.0-1.0).
        * ``checkpoint_appeared``: 1.0 if a checkpoint exists else 0.0.
        * ``diff_size_penalty``: ``-diff_ratio`` when the relative change in
          line count exceeds 0.5, else 0.0.

        ``total`` is ``sum(components[k] * WEIGHTS[k])``.
    """
    components: dict[str, float] = {}

    components["script_executes"] = 1.0 if result.exit_code == 0 else 0.0
    components["loss_decreased"] = _check_loss_trend(result.stdout)
    components["checkpoint_appeared"] = 1.0 if result.checkpoint_exists else 0.0

    # max(..., 1) avoids division by zero when the original script is empty.
    original_lines = max(len(task.script_content.splitlines()), 1)
    # Fall back to "unchanged" only when no script was captured at all. An
    # empty string is a real (0-line) submission: treating it as unchanged
    # let an agent delete the whole script -- which typically still exits 0 --
    # and escape the penalty entirely.
    # NOTE(review): assumes ExecutionResult.script_content is None (not "")
    # when the executor did not record the script -- confirm.
    if result.script_content is None:
        current_lines = original_lines
    else:
        current_lines = len(result.script_content.splitlines())
    # NOTE(review): only the net line-count delta is measured; an in-place
    # rewrite that keeps the same number of lines is not penalised.
    diff_ratio = abs(current_lines - original_lines) / original_lines
    # Changes up to 50% are free; past the threshold the whole ratio (not just
    # the excess) is charged, so the penalty jumps from 0 to about -0.5.
    components["diff_size_penalty"] = -diff_ratio if diff_ratio > 0.5 else 0.0

    total = sum(value * WEIGHTS[name] for name, value in components.items())
    return total, components


def _check_loss_trend(stdout: str) -> float:
    """Parse stdout for `loss=...` patterns and return the fraction of
    consecutive steps where loss strictly decreased."""

    losses: list[float] = []
    for line in stdout.splitlines():
        match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                losses.append(float(match.group(1)))
            except ValueError:
                continue

    if len(losses) < 2:
        return 0.0

    decreasing_steps = sum(
        1 for i in range(1, len(losses)) if losses[i] < losses[i - 1]
    )
    return decreasing_steps / (len(losses) - 1)