# sql_env/training/rewards.py
"""Reward callables for TRL GRPO training.
These helpers consume rollout metadata and return one float reward per
completion, matching TRL reward function expectations.
"""
from typing import Any
def _coerce_bool(value: Any) -> bool:
"""Convert common truthy/falsey values to bool."""
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value != 0
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes", "y"}:
return True
if normalized in {"false", "0", "no", "n", ""}:
return False
return bool(value)
def _coerce_float(value: Any, default: float = 0.0) -> float:
"""Convert numeric-like values to float with fallback."""
try:
return float(value)
except (TypeError, ValueError):
return default
def _clamp(value: float, low: float, high: float) -> float:
"""Clamp value to the closed interval [low, high]."""
return max(low, min(high, value))
def _extract_metadata_rows(
completions: list[list[dict[str, str]]],
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Resolve one metadata dict per completion.
TRL can pass rollout metadata in different shapes depending on wrapper code.
We support the common variants:
- ``kwargs['metadata']`` as list[dict]
- ``kwargs['metadata']`` as dict containing list-valued keys
- flattened keys like ``correct``, ``progress``, ``operational``
- fallback to empty dict when metadata is unavailable
"""
batch_size = len(completions)
metadata_kw = kwargs.get("metadata")
if isinstance(metadata_kw, list):
rows: list[dict[str, Any]] = []
for idx in range(batch_size):
entry = metadata_kw[idx] if idx < len(metadata_kw) else {}
rows.append(entry if isinstance(entry, dict) else {})
return rows
if isinstance(metadata_kw, dict):
rows = []
for idx in range(batch_size):
row: dict[str, Any] = {}
for key, value in metadata_kw.items():
if isinstance(value, list):
row[key] = value[idx] if idx < len(value) else None
else:
row[key] = value
rows.append(row)
return rows
rows = []
for idx in range(batch_size):
row = {}
for key in (
"answer_correct",
"correct",
"cumulative_progress",
"progress",
"operational_signals",
"operational",
):
value = kwargs.get(key)
if isinstance(value, list):
row[key] = value[idx] if idx < len(value) else None
elif value is not None:
row[key] = value
rows.append(row)
return rows
def reward_correctness(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Binary reward: 1.0 when the terminal answer was correct, else 0.0.

    Reads ``answer_correct`` from each metadata row, falling back to
    ``correct`` and then to ``False``.
    """
    rows = _extract_metadata_rows(completions, **kwargs)
    return [
        1.0 if _coerce_bool(row.get("answer_correct", row.get("correct", False))) else 0.0
        for row in rows
    ]
def reward_progress(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Progress reward clamped into [0.0, 1.0].

    Reads ``cumulative_progress`` from each metadata row, falling back to
    ``progress`` and then to ``0.0``; non-numeric values also become 0.0.
    """
    rows = _extract_metadata_rows(completions, **kwargs)
    return [
        _clamp(
            _coerce_float(row.get("cumulative_progress", row.get("progress", 0.0)), default=0.0),
            0.0,
            1.0,
        )
        for row in rows
    ]
def reward_operational(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Operational reward aggregated from per-step rollout signals.

    Each signal dict contributes +1.0 for ``exec_ok``, +1.0 for
    ``new_info``, and -1.0 for ``repeat``; non-dict entries are skipped.
    When no per-step signals are present, the precomputed scalar under
    ``operational`` is used instead (default 0.0).
    """
    rows = _extract_metadata_rows(completions, **kwargs)
    rewards: list[float] = []
    for row in rows:
        steps = row.get("operational_signals")
        if not (isinstance(steps, list) and steps):
            # No per-step signals: fall back to a scalar operational score.
            rewards.append(_coerce_float(row.get("operational", 0.0), default=0.0))
            continue
        total = 0.0
        for step in steps:
            if not isinstance(step, dict):
                continue  # Malformed entries contribute nothing.
            if _coerce_bool(step.get("exec_ok", False)):
                total += 1.0
            if _coerce_bool(step.get("new_info", False)):
                total += 1.0
            if _coerce_bool(step.get("repeat", False)):
                total -= 1.0
        rewards.append(total)
    return rewards