| """Reward functions for GRPO: env-backed verification + cheap shaping signals. |
| |
| GRPO takes a list of `reward_funcs`. Each must accept `completions` and any |
| columns from the dataset as kwargs, and return one scalar per completion. |
| """ |
|
|
from __future__ import annotations

import logging
import re
from typing import Callable, List

from .client import EnvClient
from .prompt import extract_code
|
|
# Module-level logger; used to report env scoring failures.
log = logging.getLogger("opensleuth.reward")
|
|
|
|
def make_env_reward(
    client: EnvClient,
    *,
    scale: float = 1.0 / 100.0,
    failure_reward: float = -50.0,
) -> Callable[..., List[float]]:
    """Build a verifier-backed GRPO reward function.

    The returned callable extracts the code from each completion, submits it
    through ``client.score_submission`` and returns the env's reward
    *multiplied* by ``scale``. With the default ``scale`` of 1/100 a failed
    scoring call (``failure_reward`` = -50) becomes -0.5; the original note
    states a perfect submission lands around +1.5, which keeps GRPO
    advantages well behaved without extra reward normalisation.

    Args:
        client: Env client exposing ``score_submission(name, code, seed=...)``.
        scale: Multiplier applied to every raw env reward.
        failure_reward: Raw (pre-scale) reward used when the scoring call
            raises, so one failing request does not abort the whole batch.

    Returns:
        A reward function accepting ``completions`` plus the optional dataset
        columns ``target_function_name`` and ``row_seed`` (per-row lists or a
        single shared value), returning one float per completion.
    """

    def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):
        """Score each completion against the env; one scaled reward each."""
        rewards: List[float] = []
        for i, completion in enumerate(completions):
            code = extract_code(_extract_text(completion))
            tname = _index(target_function_name, i, default="fibonacci")
            seed = _index(row_seed, i, default=0)
            try:
                env_reward_value = client.score_submission(tname, code, seed=seed)
            except Exception as e:
                # Deliberate best-effort: log and fall back to the penalty
                # instead of crashing the training step.
                log.warning("env scoring failed for %s: %s", tname, e)
                env_reward_value = failure_reward
            rewards.append(env_reward_value * scale)
        return rewards

    return env_reward
|
|
|
|
# Captures the name of a function defined at the start of a line
# (``def name(``). MULTILINE makes ``^`` match at every line start, so only
# unindented ``def`` statements are matched.
_FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE)
|
|
|
|
def format_reward(completions, target_function_name=None, **kwargs):
    """Cheap shaping reward for completions in the expected output format.

    Two independent bonuses of +0.1 each (so at most +0.2 per completion):

    * the completion text contains a code fence (a ```python fence, or a
      bare ``` immediately followed by a newline);
    * the extracted code defines at least one top-level function and, when a
      target name is known for the row, one of those functions has that name.

    This nudges the model towards the expected format quickly so the env
    reward becomes informative early in training.

    Args:
        completions: Completions as strings or chat-style message lists.
        target_function_name: Per-row list of expected function names, a
            single shared name, or ``None`` to accept any function name.
        **kwargs: Other dataset columns; ignored.

    Returns:
        One float per completion.
    """
    rewards: List[float] = []
    for i, completion in enumerate(completions):
        text = _extract_text(completion)
        score = 0.0
        if "```python" in text or "```\n" in text:
            score += 0.1
        # Collect every top-level def, not just the first one, so a helper
        # defined before the target function does not forfeit the bonus.
        defined = _FUNC_RE.findall(extract_code(text))
        tname = _index(target_function_name, i, default=None)
        if defined and (tname is None or tname in defined):
            score += 0.1
        rewards.append(score)
    return rewards
|
|
|
|
| def _extract_text(completion): |
| """GRPO can pass either a string or an OpenAI-style chat list of dicts. |
| Normalise to a single string.""" |
| if isinstance(completion, str): |
| return completion |
| if isinstance(completion, list): |
| |
| parts = [] |
| for msg in completion: |
| if isinstance(msg, dict) and "content" in msg: |
| parts.append(str(msg["content"])) |
| else: |
| parts.append(str(msg)) |
| return "\n".join(parts) |
| return str(completion) |
|
|
|
|
| def _index(value, i: int, default): |
| if value is None: |
| return default |
| if isinstance(value, list): |
| return value[i] if i < len(value) else default |
| return value |
|
|