| """Reward callables for TRL GRPO training. |
| |
| These helpers consume rollout metadata and return one float reward per |
| completion, matching TRL reward function expectations. |
| """ |
|
|
| from typing import Any |
|
|
|
|
| def _coerce_bool(value: Any) -> bool: |
| """Convert common truthy/falsey values to bool.""" |
|
|
| if isinstance(value, bool): |
| return value |
| if isinstance(value, (int, float)): |
| return value != 0 |
| if isinstance(value, str): |
| normalized = value.strip().lower() |
| if normalized in {"true", "1", "yes", "y"}: |
| return True |
| if normalized in {"false", "0", "no", "n", ""}: |
| return False |
| return bool(value) |
|
|
|
|
| def _coerce_float(value: Any, default: float = 0.0) -> float: |
| """Convert numeric-like values to float with fallback.""" |
|
|
| try: |
| return float(value) |
| except (TypeError, ValueError): |
| return default |
|
|
|
|
| def _clamp(value: float, low: float, high: float) -> float: |
| """Clamp value to the closed interval [low, high].""" |
|
|
| return max(low, min(high, value)) |
|
|
|
|
| def _extract_metadata_rows( |
| completions: list[list[dict[str, str]]], |
| **kwargs: Any, |
| ) -> list[dict[str, Any]]: |
| """Resolve one metadata dict per completion. |
| |
| TRL can pass rollout metadata in different shapes depending on wrapper code. |
| We support the common variants: |
| - ``kwargs['metadata']`` as list[dict] |
| - ``kwargs['metadata']`` as dict containing list-valued keys |
| - flattened keys like ``correct``, ``progress``, ``operational`` |
| - fallback to empty dict when metadata is unavailable |
| """ |
|
|
| batch_size = len(completions) |
|
|
| metadata_kw = kwargs.get("metadata") |
| if isinstance(metadata_kw, list): |
| rows: list[dict[str, Any]] = [] |
| for idx in range(batch_size): |
| entry = metadata_kw[idx] if idx < len(metadata_kw) else {} |
| rows.append(entry if isinstance(entry, dict) else {}) |
| return rows |
|
|
| if isinstance(metadata_kw, dict): |
| rows = [] |
| for idx in range(batch_size): |
| row: dict[str, Any] = {} |
| for key, value in metadata_kw.items(): |
| if isinstance(value, list): |
| row[key] = value[idx] if idx < len(value) else None |
| else: |
| row[key] = value |
| rows.append(row) |
| return rows |
|
|
| rows = [] |
| for idx in range(batch_size): |
| row = {} |
| for key in ( |
| "answer_correct", |
| "correct", |
| "cumulative_progress", |
| "progress", |
| "operational_signals", |
| "operational", |
| ): |
| value = kwargs.get(key) |
| if isinstance(value, list): |
| row[key] = value[idx] if idx < len(value) else None |
| elif value is not None: |
| row[key] = value |
| rows.append(row) |
| return rows |
|
|
|
|
def reward_correctness(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Binary reward: 1.0 for a correct terminal answer, else 0.0.

    Prefers ``answer_correct`` and falls back to ``correct``. An explicit
    ``None`` under ``answer_correct`` — which ``_extract_metadata_rows``
    emits when a per-sample list is shorter than the batch — is treated as
    *missing* rather than as "present and falsey", so the ``correct``
    fallback is still consulted. (The previous one-expression lookup,
    ``row.get("answer_correct", row.get("correct", False))``, skipped the
    fallback whenever the key existed with a ``None`` value.)
    """
    metadata_rows = _extract_metadata_rows(completions, **kwargs)
    rewards: list[float] = []
    for row in metadata_rows:
        flag = row.get("answer_correct")
        if flag is None:
            flag = row.get("correct", False)
        rewards.append(1.0 if _coerce_bool(flag) else 0.0)
    return rewards
|
|
|
|
def reward_progress(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Progress reward normalized to [0.0, 1.0].

    Reads ``cumulative_progress`` (falling back to ``progress``) from each
    completion's metadata row, coerces it to float, and clamps the result
    into the unit interval.
    """
    return [
        _clamp(
            _coerce_float(
                row.get("cumulative_progress", row.get("progress", 0.0)),
                default=0.0,
            ),
            0.0,
            1.0,
        )
        for row in _extract_metadata_rows(completions, **kwargs)
    ]
|
|
|
|
def reward_operational(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Operational reward from per-step L1-style rollout signals.

    When ``operational_signals`` is a non-empty list, each dict-shaped step
    contributes +1 for ``exec_ok``, +1 for ``new_info``, and -1 for
    ``repeat``. Otherwise the scalar ``operational`` value (default 0.0)
    is used as-is.
    """
    metadata_rows = _extract_metadata_rows(completions, **kwargs)
    rewards: list[float] = []
    for row in metadata_rows:
        steps = row.get("operational_signals")
        if not (isinstance(steps, list) and steps):
            # No per-step signals: fall back to a precomputed scalar.
            rewards.append(_coerce_float(row.get("operational", 0.0), default=0.0))
            continue
        total = 0.0
        for step in steps:
            if not isinstance(step, dict):
                continue
            if _coerce_bool(step.get("exec_ok", False)):
                total += 1.0
            if _coerce_bool(step.get("new_info", False)):
                total += 1.0
            if _coerce_bool(step.get("repeat", False)):
                total -= 1.0
        rewards.append(float(total))
    return rewards
|
|