"""Reward callables for TRL GRPO training. These helpers consume rollout metadata and return one float reward per completion, matching TRL reward function expectations. """ from typing import Any def _coerce_bool(value: Any) -> bool: """Convert common truthy/falsey values to bool.""" if isinstance(value, bool): return value if isinstance(value, (int, float)): return value != 0 if isinstance(value, str): normalized = value.strip().lower() if normalized in {"true", "1", "yes", "y"}: return True if normalized in {"false", "0", "no", "n", ""}: return False return bool(value) def _coerce_float(value: Any, default: float = 0.0) -> float: """Convert numeric-like values to float with fallback.""" try: return float(value) except (TypeError, ValueError): return default def _clamp(value: float, low: float, high: float) -> float: """Clamp value to the closed interval [low, high].""" return max(low, min(high, value)) def _extract_metadata_rows( completions: list[list[dict[str, str]]], **kwargs: Any, ) -> list[dict[str, Any]]: """Resolve one metadata dict per completion. TRL can pass rollout metadata in different shapes depending on wrapper code. We support the common variants: - ``kwargs['metadata']`` as list[dict] - ``kwargs['metadata']`` as dict containing list-valued keys - flattened keys like ``correct``, ``progress``, ``operational`` - fallback to empty dict when metadata is unavailable """ batch_size = len(completions) metadata_kw = kwargs.get("metadata") if isinstance(metadata_kw, list): rows: list[dict[str, Any]] = [] for idx in range(batch_size): entry = metadata_kw[idx] if idx < len(metadata_kw) else {} rows.append(entry if isinstance(entry, dict) else {}) return rows if isinstance(metadata_kw, dict): rows = [] for idx in range(batch_size): row: dict[str, Any] = {} for key, value in metadata_kw.items(): if isinstance(value, list): row[key] = value[idx] if idx < len(value) else None else: row[key] = value rows.append(row) return rows rows = [] for idx in range(batch_size): row = {} for key in ( "answer_correct", "correct", "cumulative_progress", "progress", "operational_signals", "operational", ): value = kwargs.get(key) if isinstance(value, list): row[key] = value[idx] if idx < len(value) else None elif value is not None: row[key] = value rows.append(row) return rows def reward_correctness( completions: list[list[dict[str, str]]], **kwargs: Any, ) -> list[float]: """Binary reward: 1.0 for correct terminal answer, else 0.0.""" metadata_rows = _extract_metadata_rows(completions, **kwargs) rewards: list[float] = [] for row in metadata_rows: is_correct = _coerce_bool(row.get("answer_correct", row.get("correct", False))) rewards.append(1.0 if is_correct else 0.0) return rewards def reward_progress( completions: list[list[dict[str, str]]], **kwargs: Any, ) -> list[float]: """Progress reward normalized to [0.0, 1.0].""" metadata_rows = _extract_metadata_rows(completions, **kwargs) rewards: list[float] = [] for row in metadata_rows: raw = row.get("cumulative_progress", row.get("progress", 0.0)) rewards.append(_clamp(_coerce_float(raw, default=0.0), 0.0, 1.0)) return rewards def reward_operational( completions: list[list[dict[str, str]]], **kwargs: Any, ) -> list[float]: """Operational reward from per-step L1-style rollout signals.""" metadata_rows = _extract_metadata_rows(completions, **kwargs) rewards: list[float] = [] for row in metadata_rows: signals = row.get("operational_signals") if isinstance(signals, list) and signals: score = 0.0 for signal in signals: if not isinstance(signal, dict): continue if _coerce_bool(signal.get("exec_ok", False)): score += 1.0 if _coerce_bool(signal.get("new_info", False)): score += 1.0 if _coerce_bool(signal.get("repeat", False)): score -= 1.0 rewards.append(float(score)) continue fallback = row.get("operational", 0.0) rewards.append(_coerce_float(fallback, default=0.0)) return rewards