File size: 3,372 Bytes
9f43137
225e725
9f43137
 
 
 
225e725
 
 
 
 
 
9f43137
 
 
 
225e725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f43137
 
 
 
 
 
 
 
 
 
 
 
225e725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Per-step + final-action reward grader for the multi-step interactive env.

The training script in train/train_grpo.py uses a single-shot reward (in
train_grpo.make_reward_fn) that scores the whole rollout at once. This
file is what the env returns step-by-step when an agent walks it
interactively (e.g. from the HF Space web UI).
"""
from __future__ import annotations

import math
from typing import Any, Dict


# Per-step shaping rewards returned by step_reward().
R_AUDIO_TOOL_USE = 0.05   # successful get_prosody_features / get_pitch_contour call
R_TRANSCRIPT_USE = 0.02   # successful get_transcript call
R_BAD_TOOL = -0.10        # action reported an error that does not mention "args"
R_BAD_ARGS = -0.05        # action reported an error that mentions "args"


def step_reward(tool_used: str, error: str | None) -> float:
    """Return the per-step reward delta for the action just taken.

    An action that reported an error (truthy ``error``) is penalised:
    R_BAD_ARGS when the error text contains "args", R_BAD_TOOL otherwise.
    Without an error, an audio-tool call earns R_AUDIO_TOOL_USE, a
    get_transcript call earns R_TRANSCRIPT_USE, and anything else earns 0.0.

    The episode-final reward (correctness etc.) is added separately when
    submit_belief fires.
    """
    if error:
        # Plain substring match: any error text mentioning "args" counts as
        # bad arguments; every other error counts as a bad tool.
        return R_BAD_ARGS if "args" in error else R_BAD_TOOL

    audio_tools = ("get_prosody_features", "get_pitch_contour")
    reward = 0.0
    if tool_used in audio_tools:
        reward = R_AUDIO_TOOL_USE
    elif tool_used == "get_transcript":
        reward = R_TRANSCRIPT_USE
    return reward


def final_reward(
    submitted_label: str | None,
    submitted_confidence: float,
    gold_label: str,
    is_pivot: bool,
    n_audio_calls: int,
    n_total_calls: int,
) -> Dict[str, float]:
    """Reward computed when submit_belief terminates the episode.

    Components (each in [0, 1]):
        correctness        confidence-weighted match against gold:
                           0.5 + 0.5*conf if correct, 0.5 - 0.5*conf if wrong,
                           0.0 if nothing was submitted
        prosody_grounding  1.0 if any audio-tool call, otherwise 0.4
                           (0.0 on Pivot)
        tool_parsimony     0.5 for 0 calls, 1.0 for 1-3, 0.6 for 4-5,
                           0.0 for >5
        format_ok          1.0 if a label (any non-None value) was submitted

    Penalties (added directly to the weighted sum):
        no submission              -0.30
        too many calls (>5)        -0.20
        pivot + no audio + wrong   -0.50

    Total:
        0.50*correctness + 0.25*prosody_grounding + 0.15*tool_parsimony
        + 0.10*format_ok + sum(penalties)

    Args:
        submitted_label: Label the agent submitted, or None if it never did.
        submitted_confidence: Agent's confidence in its label, clamped to
            [0, 1]. NaN is treated as 0.0 (no commitment). Only read when a
            label was submitted.
        gold_label: Reference label; compared case-insensitively.
        is_pivot: Whether the episode is a Pivot item; this removes the 0.4
            grounding fallback and enables the no-audio-wrong penalty.
        n_audio_calls: Number of audio-tool calls made in the episode.
        n_total_calls: Total number of tool calls made in the episode.

    Returns:
        The four component scores plus "_total" (weighted sum rounded to 4
        decimals) and "_penalties" (sum of all penalties).
    """
    components: Dict[str, float] = {
        "correctness": 0.0,
        "prosody_grounding": 0.0,
        "tool_parsimony": 0.0,
        "format_ok": 0.0,
    }
    penalties: Dict[str, float] = {
        "no_submission": 0.0,
        "too_many_calls": 0.0,
        "pivot_no_audio_wrong": 0.0,
    }

    correct = False
    if submitted_label is None:
        penalties["no_submission"] = -0.30
    else:
        components["format_ok"] = 1.0
        correct = submitted_label.lower() == gold_label.lower()
        conf = submitted_confidence
        if math.isnan(conf):
            # max(0.0, min(1.0, nan)) evaluates to 1.0, which would score a
            # malformed confidence as full confidence; treat it as none.
            conf = 0.0
        conf = max(0.0, min(1.0, conf))
        components["correctness"] = (0.5 + 0.5 * conf) if correct else (0.5 - 0.5 * conf)

    if n_audio_calls >= 1:
        components["prosody_grounding"] = 1.0
    elif not is_pivot:
        components["prosody_grounding"] = 0.4

    if n_total_calls == 0:
        components["tool_parsimony"] = 0.5
    elif 1 <= n_total_calls <= 3:
        components["tool_parsimony"] = 1.0
    elif n_total_calls <= 5:
        components["tool_parsimony"] = 0.6
    else:
        components["tool_parsimony"] = 0.0
        penalties["too_many_calls"] = -0.20

    # A wrong answer on a Pivot item without consulting any audio tool is
    # penalised on top of the lost correctness.
    if is_pivot and n_audio_calls == 0 and submitted_label is not None and not correct:
        penalties["pivot_no_audio_wrong"] = -0.50

    penalty_total = sum(penalties.values())
    weighted = (
        0.50 * components["correctness"]
        + 0.25 * components["prosody_grounding"]
        + 0.15 * components["tool_parsimony"]
        + 0.10 * components["format_ok"]
        + penalty_total
    )
    components["_total"] = round(weighted, 4)
    components["_penalties"] = penalty_total
    return components