File size: 3,370 Bytes
d597642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Reward functions for GRPO: env-backed verification + cheap shaping signals.

GRPO takes a list of `reward_funcs`. Each must accept `completions` and any
columns from the dataset as kwargs, and return one scalar per completion.
"""

from __future__ import annotations

import logging
import re
from typing import Callable, List

from .client import EnvClient
from .prompt import extract_code

log = logging.getLogger("opensleuth.reward")


def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0) -> Callable[..., List[float]]:
    """Build a verifier-backed reward function bound to *client*.

    The returned callable submits each completion's extracted code to the env
    and multiplies the env's raw reward by ``scale`` (default 1/100, i.e. a
    divide-by-100: a perfect submission lands around +1.5 and a bad one around
    -0.5). This keeps GRPO advantages well behaved without needing explicit
    reward normalisation.

    Args:
        client: Env client exposing ``score_submission(name, code, seed=...)``.
        scale: Multiplier applied to the env's raw reward before returning.

    Returns:
        A GRPO-compatible reward function mapping ``completions`` (plus the
        dataset columns ``target_function_name`` and ``row_seed``) to one
        float per completion.
    """

    def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):  # noqa: ANN001
        rewards: List[float] = []
        # GRPO calls the reward fn once per completion; both target_function_name
        # and row_seed come in as lists of length len(completions).
        for i, completion in enumerate(completions):
            text = _extract_text(completion)
            code = extract_code(text)
            tname = _index(target_function_name, i, default="fibonacci")
            seed = _index(row_seed, i, default=0)
            try:
                env_reward_value = client.score_submission(tname, code, seed=seed)
            except Exception as e:  # noqa: BLE001
                # Scoring failures (env down, timeout, malformed code) get a
                # strong negative raw reward (-50 -> -0.5 after default scaling)
                # instead of crashing the training loop.
                log.warning("env scoring failed for %s: %s", tname, e)
                env_reward_value = -50.0
            rewards.append(env_reward_value * scale)
        return rewards

    return env_reward


_FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE)


def format_reward(completions, target_function_name=None, **kwargs):  # noqa: ANN001
    """Cheap shaping reward for output format.

    Grants +0.1 for a fenced code block and a further +0.1 for defining a
    top-level function with the expected name (max +0.2). This nudges the
    model toward the expected output format early, so the env reward becomes
    informative sooner in training.
    """
    scores: List[float] = []
    for idx, completion in enumerate(completions):
        text = _extract_text(completion)
        # Fence bonus: either an explicit python fence or any closed fence.
        bonus = 0.1 if ("```python" in text or "```\n" in text) else 0.0
        # Name bonus: the extracted code defines a function, and it matches
        # the target name when one was provided for this row.
        match = _FUNC_RE.search(extract_code(text))
        expected = _index(target_function_name, idx, default=None)
        if match and (expected is None or match.group(1) == expected):
            bonus += 0.1
        scores.append(bonus)
    return scores


def _extract_text(completion):  # noqa: ANN001
    """GRPO can pass either a string or an OpenAI-style chat list of dicts.
    Normalise to a single string."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        # [{role:..., content:...}, ...]
        parts = []
        for msg in completion:
            if isinstance(msg, dict) and "content" in msg:
                parts.append(str(msg["content"]))
            else:
                parts.append(str(msg))
        return "\n".join(parts)
    return str(completion)


def _index(value, i: int, default):
    if value is None:
        return default
    if isinstance(value, list):
        return value[i] if i < len(value) else default
    return value