Spaces:
Sleeping
Sleeping
| """Per-step + final-action reward grader for the multi-step interactive env. | |
| The training script in train/train_grpo.py uses a single-shot reward (in | |
| train_grpo.make_reward_fn) that scores the whole rollout at once. This | |
| file is what the env returns step-by-step when an agent walks it | |
| interactively (e.g. from the HF Space web UI). | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
# Per-step reward deltas handed out by step_reward().
# Small bonuses for a successful call to an audio or transcript tool.
R_AUDIO_TOOL_USE: float = 0.05
R_TRANSCRIPT_USE: float = 0.02
# Penalties for a step that came back with an error: the generic case, and the
# milder case where the error message mentions "args".
R_BAD_TOOL: float = -0.10
R_BAD_ARGS: float = -0.05
def step_reward(tool_used: str, error: str | None) -> float:
    """Return the per-step reward delta for the action that was just taken.

    Only the individual step is scored here; the episode-final reward
    (correctness etc.) is added separately when submit_belief fires.

    Args:
        tool_used: Name of the tool the agent invoked on this step.
        error: Error message produced by the step, or None/empty on success.

    Returns:
        R_BAD_ARGS if the step errored and the message contains "args",
        R_BAD_TOOL for any other error, R_AUDIO_TOOL_USE for a successful
        prosody/pitch call, R_TRANSCRIPT_USE for a successful transcript
        call, and 0.0 for every other successful tool.
    """
    if error:
        # A failed step is always penalised; messages mentioning "args" get
        # the milder bad-arguments penalty instead of the bad-tool one.
        penalty = R_BAD_TOOL
        if "args" in error:
            penalty = R_BAD_ARGS
        return penalty

    if tool_used == "get_transcript":
        return R_TRANSCRIPT_USE

    audio_tools = ("get_prosody_features", "get_pitch_contour")
    return R_AUDIO_TOOL_USE if tool_used in audio_tools else 0.0
def final_reward(
    submitted_label: str | None,
    submitted_confidence: float,
    gold_label: str,
    is_pivot: bool,
    n_audio_calls: int,
    n_total_calls: int,
) -> Dict[str, float]:
    """Reward computed when submit_belief terminates the episode.

    Components:
        correctness        confidence-weighted, case-insensitive match against
                           gold: 0.5 + 0.5 * conf when correct, 0.5 - 0.5 * conf
                           when wrong, 0.0 when nothing was submitted
        prosody_grounding  1.0 if any audio-tool call, 0.4 otherwise (0.0 on pivot)
        tool_parsimony     0.5 for 0 calls, 1.0 for 1-3, 0.6 for 4-5, 0.0 for >5
        format_ok          1.0 if a label was submitted (i.e. it is not None)

    Penalties:
        no submission              -0.30
        too many calls (>5)        -0.20
        pivot + no audio + wrong   -0.50

    Args:
        submitted_label: Label the agent submitted, or None if it never did.
        submitted_confidence: Agent's confidence; clamped to [0, 1], with NaN
            treated as 0.0 (no confidence).
        gold_label: Reference label for the episode.
        is_pivot: Whether the episode is flagged as a pivot item.
        n_audio_calls: Number of audio-tool calls made during the episode.
        n_total_calls: Total number of tool calls made during the episode.

    Returns:
        A dict with the four component scores plus "_total" (the
        0.50/0.25/0.15/0.10 weighted sum of the components plus all penalties,
        rounded to 4 decimals) and "_penalties" (the sum of the penalties).
    """
    components: Dict[str, float] = {
        "correctness": 0.0,
        "prosody_grounding": 0.0,
        "tool_parsimony": 0.0,
        "format_ok": 0.0,
    }
    penalties: Dict[str, float] = {
        "no_submission": 0.0,
        "too_many_calls": 0.0,
        "pivot_no_audio_wrong": 0.0,
    }

    # Computed once and reused for both the correctness score and the pivot
    # penalty. Short-circuits so gold_label is not touched without a label.
    has_submission = submitted_label is not None
    correct = has_submission and submitted_label.lower() == gold_label.lower()

    if not has_submission:
        penalties["no_submission"] = -0.30
    else:
        components["format_ok"] = 1.0
        # Clamp to [0, 1]. `not conf > 0.0` is also true for NaN, which a
        # max(0.0, min(1.0, conf)) clamp would silently turn into 1.0.
        conf = submitted_confidence
        if not conf > 0.0:
            conf = 0.0
        elif conf > 1.0:
            conf = 1.0
        components["correctness"] = (0.5 + 0.5 * conf) if correct else (0.5 - 0.5 * conf)

    if n_audio_calls >= 1:
        components["prosody_grounding"] = 1.0
    elif not is_pivot:
        components["prosody_grounding"] = 0.4

    if n_total_calls == 0:
        components["tool_parsimony"] = 0.5
    elif 1 <= n_total_calls <= 3:
        components["tool_parsimony"] = 1.0
    elif n_total_calls <= 5:
        components["tool_parsimony"] = 0.6
    else:
        components["tool_parsimony"] = 0.0
        penalties["too_many_calls"] = -0.20

    # Extra penalty for answering a pivot item wrongly without using any
    # audio tool.
    if is_pivot and n_audio_calls == 0 and has_submission and not correct:
        penalties["pivot_no_audio_wrong"] = -0.50

    penalty_total = sum(penalties.values())
    weighted = (
        0.50 * components["correctness"]
        + 0.25 * components["prosody_grounding"]
        + 0.15 * components["tool_parsimony"]
        + 0.10 * components["format_ok"]
        + penalty_total
    )
    components["_total"] = round(weighted, 4)
    components["_penalties"] = penalty_total
    return components