# anugrah55's picture
# Overhaul trainer: TRL GRPO with env-backed reward, Qwen2.5-0.5B 4bit+LoRA, slim PyTorch CUDA base, heartbeat HTTP for HF Spaces health probe
# d597642 verified
"""Reward functions for GRPO: env-backed verification + cheap shaping signals.
GRPO takes a list of `reward_funcs`. Each must accept `completions` and any
columns from the dataset as kwargs, and return one scalar per completion.
"""
from __future__ import annotations
import logging
import re
from typing import List
from .client import EnvClient
from .prompt import extract_code
log = logging.getLogger("opensleuth.reward")
def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0) -> callable:
    """Build a verifier-backed reward function for GRPO.

    The returned callable extracts code from each completion, submits it to
    the environment via ``client.score_submission``, and multiplies the env's
    raw reward by ``scale`` (default 1/100, i.e. divide the raw reward by
    100).  That keeps a perfect submission around +1.5 and a bad one around
    -0.5, which keeps GRPO advantages well behaved without needing explicit
    reward normalisation.

    Parameters
    ----------
    client : EnvClient
        Client exposing ``score_submission(target_name, code, seed=...)``.
    scale : float, keyword-only
        Multiplier applied to the env's raw reward before it is returned.

    Returns
    -------
    callable
        A GRPO-compatible reward fn: ``(completions, **cols) -> List[float]``.
    """
    def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):  # noqa: ANN001
        rewards: List[float] = []
        # GRPO passes dataset columns (target_function_name, row_seed) as
        # lists aligned one-to-one with `completions`.
        for i, completion in enumerate(completions):
            text = _extract_text(completion)
            code = extract_code(text)
            tname = _index(target_function_name, i, default="fibonacci")
            seed = _index(row_seed, i, default=0)
            try:
                env_reward_value = client.score_submission(tname, code, seed=seed)
            except Exception as e:  # noqa: BLE001
                # Scoring failures (env down, malformed payload) get a strong
                # negative reward instead of crashing the training loop.
                log.warning("env scoring failed for %s: %s", tname, e)
                env_reward_value = -50.0
            rewards.append(env_reward_value * scale)
        return rewards
    return env_reward
# Matches `def name(` at the start of any line (MULTILINE) and captures the
# defined function's name — used by format_reward to check the target name.
_FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE)
def format_reward(completions, target_function_name=None, **kwargs):  # noqa: ANN001
    """Cheap shaping reward (max +0.2) for well-formed output.

    +0.1 if the completion contains a fenced code block, and a further +0.1
    if the extracted code defines a function with the expected name.  This
    encourages the model to converge on the expected output format quickly
    so the env reward becomes informative early in training.
    """
    rewards: List[float] = []
    for i, completion in enumerate(completions):
        text = _extract_text(completion)
        score = 0.0
        if "```python" in text or "```\n" in text:
            score += 0.1
        code = extract_code(text)
        tname = _index(target_function_name, i, default=None)
        # Scan every line-start `def`, not just the first match: completions
        # that define helper functions before the target function should
        # still earn the name-match bonus.
        names = _FUNC_RE.findall(code)
        if names and (tname is None or tname in names):
            score += 0.1
        rewards.append(score)
    return rewards
def _extract_text(completion): # noqa: ANN001
"""GRPO can pass either a string or an OpenAI-style chat list of dicts.
Normalise to a single string."""
if isinstance(completion, str):
return completion
if isinstance(completion, list):
# [{role:..., content:...}, ...]
parts = []
for msg in completion:
if isinstance(msg, dict) and "content" in msg:
parts.append(str(msg["content"]))
else:
parts.append(str(msg))
return "\n".join(parts)
return str(completion)
def _index(value, i: int, default):
if value is None:
return default
if isinstance(value, list):
return value[i] if i < len(value) else default
return value