"""Reward functions for GRPO: env-backed verification + cheap shaping signals. GRPO takes a list of `reward_funcs`. Each must accept `completions` and any columns from the dataset as kwargs, and return one scalar per completion. """ from __future__ import annotations import logging import re from typing import List from .client import EnvClient from .prompt import extract_code log = logging.getLogger("opensleuth.reward") def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0) -> callable: """Verifier-backed reward. Calls the env's `submit` and returns the env's reward divided by `scale` (default: divide by 100 so a perfect submission is ~+1.5 and a bad one is around -0.5; this keeps GRPO advantages well behaved without needing reward normalisation). """ def env_reward(completions, target_function_name=None, row_seed=None, **kwargs): # noqa: ANN001 rewards: List[float] = [] # GRPO calls the reward fn once per completion; both target_function_name # and row_seed come in as lists of length len(completions). for i, completion in enumerate(completions): text = _extract_text(completion) code = extract_code(text) tname = _index(target_function_name, i, default="fibonacci") seed = _index(row_seed, i, default=0) try: env_reward_value = client.score_submission(tname, code, seed=seed) except Exception as e: # noqa: BLE001 log.warning("env scoring failed for %s: %s", tname, e) env_reward_value = -50.0 rewards.append(env_reward_value * scale) return rewards return env_reward _FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE) def format_reward(completions, target_function_name=None, **kwargs): # noqa: ANN001 """Cheap shaping reward: +0.2 if the completion contains a fenced python block AND defines a function with the right name. Encourages the model to converge on the expected output format quickly so the env reward becomes informative early in training.""" rewards: List[float] = [] for i, completion in enumerate(completions): text = _extract_text(completion) score = 0.0 if "```python" in text or "```\n" in text: score += 0.1 code = extract_code(text) m = _FUNC_RE.search(code) tname = _index(target_function_name, i, default=None) if m and (tname is None or m.group(1) == tname): score += 0.1 rewards.append(score) return rewards def _extract_text(completion): # noqa: ANN001 """GRPO can pass either a string or an OpenAI-style chat list of dicts. Normalise to a single string.""" if isinstance(completion, str): return completion if isinstance(completion, list): # [{role:..., content:...}, ...] parts = [] for msg in completion: if isinstance(msg, dict) and "content" in msg: parts.append(str(msg["content"])) else: parts.append(str(msg)) return "\n".join(parts) return str(completion) def _index(value, i: int, default): if value is None: return default if isinstance(value, list): return value[i] if i < len(value) else default return value