"""Per-step + final-action reward grader for the multi-step interactive env.

The training script in train/train_grpo.py uses a single-shot reward
(in train_grpo.make_reward_fn) that scores the whole rollout at once. This
file is what the env returns step-by-step when an agent walks it
interactively (e.g. from the HF Space web UI).
"""
from __future__ import annotations

import math
from typing import Any, Dict

# Per-step rewards returned by step_reward.
R_AUDIO_TOOL_USE = 0.05
R_TRANSCRIPT_USE = 0.02
R_BAD_TOOL = -0.10
R_BAD_ARGS = -0.05

# Penalties applied by final_reward.
P_NO_SUBMISSION = -0.30
P_TOO_MANY_CALLS = -0.20
P_PIVOT_NO_AUDIO_WRONG = -0.50

# Weights of the final_reward components in the blended "_total".
W_CORRECTNESS = 0.50
W_PROSODY_GROUNDING = 0.25
W_TOOL_PARSIMONY = 0.15
W_FORMAT_OK = 0.10


def step_reward(tool_used: str, error: str | None) -> float:
    """Reward for the action just taken (per-step delta).

    Episode-final reward (correctness etc.) is added when submit_belief
    fires.

    Args:
        tool_used: Name of the tool the agent just invoked.
        error: Error message produced by the call; ``None`` or an empty
            string means the call succeeded.

    Returns:
        ``R_BAD_ARGS`` if the error message contains ``"args"``,
        ``R_BAD_TOOL`` for any other error, ``R_AUDIO_TOOL_USE`` for a
        successful ``get_prosody_features`` / ``get_pitch_contour`` call,
        ``R_TRANSCRIPT_USE`` for a successful ``get_transcript`` call, and
        ``0.0`` for any other successful call.
    """
    if error:
        return R_BAD_ARGS if "args" in error else R_BAD_TOOL
    if tool_used in ("get_prosody_features", "get_pitch_contour"):
        return R_AUDIO_TOOL_USE
    if tool_used == "get_transcript":
        return R_TRANSCRIPT_USE
    return 0.0


def _clamp_confidence(value: float) -> float:
    """Clamp *value* to [0.0, 1.0], treating NaN as no confidence (0.0)."""
    # NaN compares false against everything, so a bare max/min clamp would
    # let it through as 1.0 (full confidence).
    if math.isnan(value):
        return 0.0
    return max(0.0, min(1.0, value))


def final_reward(
    submitted_label: str | None,
    submitted_confidence: float,
    gold_label: str,
    is_pivot: bool,
    n_audio_calls: int,
    n_total_calls: int,
) -> Dict[str, float]:
    """Reward computed when submit_belief terminates the episode.

    Components (each in [0, 1]):
        correctness        confidence-weighted match against gold:
                           0.5 + 0.5 * conf if the label matches
                           (case-insensitive), 0.5 - 0.5 * conf otherwise;
                           0.0 when nothing was submitted
        prosody_grounding  1.0 if any audio-tool call, 0.4 otherwise
                           (0.0 on Pivot)
        tool_parsimony     0.5 for 0 calls, 1.0 for 1-3 calls,
                           0.6 for 4-5, 0.0 for >5
        format_ok          1.0 if a label was submitted (the label is not
                           validated against any label set here)

    Penalties:
        no submission              -0.30
        too many calls (>5)        -0.20
        pivot + no audio + wrong   -0.50

    Args:
        submitted_label: Label the agent submitted, or ``None`` if it never
            submitted one.
        submitted_confidence: Agent-reported confidence; clamped to
            [0, 1], with NaN treated as 0.0. Ignored when no label was
            submitted.
        gold_label: Reference label, compared case-insensitively.
        is_pivot: Whether the episode is a Pivot item.
        n_audio_calls: Number of audio-tool calls made during the episode.
        n_total_calls: Total number of tool calls made during the episode.

    Returns:
        The four component scores plus ``"_total"`` (weighted component sum
        plus penalties, rounded to 4 decimals) and ``"_penalties"`` (sum of
        the applied penalties).
    """
    components: Dict[str, float] = {
        "correctness": 0.0,
        "prosody_grounding": 0.0,
        "tool_parsimony": 0.0,
        "format_ok": 0.0,
    }
    penalties: Dict[str, float] = {
        "no_submission": 0.0,
        "too_many_calls": 0.0,
        "pivot_no_audio_wrong": 0.0,
    }

    # Computed once; reused for both the correctness score and the pivot
    # penalty below.
    is_correct = (
        submitted_label is not None
        and submitted_label.lower() == gold_label.lower()
    )

    if submitted_label is None:
        penalties["no_submission"] = P_NO_SUBMISSION
    else:
        components["format_ok"] = 1.0
        conf = _clamp_confidence(submitted_confidence)
        # Confident-and-right approaches 1.0, confident-and-wrong approaches
        # 0.0, and zero confidence yields a neutral 0.5 either way.
        if is_correct:
            components["correctness"] = 0.5 + 0.5 * conf
        else:
            components["correctness"] = 0.5 - 0.5 * conf

    if n_audio_calls >= 1:
        components["prosody_grounding"] = 1.0
    elif not is_pivot:
        components["prosody_grounding"] = 0.4

    if n_total_calls == 0:
        components["tool_parsimony"] = 0.5
    elif 1 <= n_total_calls <= 3:
        components["tool_parsimony"] = 1.0
    elif n_total_calls <= 5:
        components["tool_parsimony"] = 0.6
    else:
        components["tool_parsimony"] = 0.0
        penalties["too_many_calls"] = P_TOO_MANY_CALLS

    if (
        is_pivot
        and n_audio_calls == 0
        and submitted_label is not None
        and not is_correct
    ):
        penalties["pivot_no_audio_wrong"] = P_PIVOT_NO_AUDIO_WRONG

    penalty_total = sum(penalties.values())
    weighted = (
        W_CORRECTNESS * components["correctness"]
        + W_PROSODY_GROUNDING * components["prosody_grounding"]
        + W_TOOL_PARSIMONY * components["tool_parsimony"]
        + W_FORMAT_OK * components["format_ok"]
        + penalty_total
    )
    components["_total"] = round(weighted, 4)
    components["_penalties"] = penalty_total
    return components