File size: 4,851 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Training metrics aggregation."""

from __future__ import annotations

from dataclasses import dataclass, field

from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS


@dataclass
class TrainingMetrics:
    """Accumulate per-sample training statistics and reduce them to averages.

    Each call to :meth:`add` appends one value to every per-sample list and
    bumps the integer counters.  :meth:`summary` turns the accumulated state
    into a dict of scalar means/rates plus two nested per-key average dicts.
    """

    # Per-sample series: every `add` call appends exactly one value to each.
    rewards: list[float] = field(default_factory=list)
    # The *_rate lists hold 0.0/1.0 indicators, so their mean is a rate.
    legality_rate: list[float] = field(default_factory=list)
    severe_violation_rate: list[float] = field(default_factory=list)
    abstention_rate: list[float] = field(default_factory=list)
    episode_lengths: list[int] = field(default_factory=list)
    success_rate: list[float] = field(default_factory=list)
    burden_delta: list[float] = field(default_factory=list)
    safety_delta: list[float] = field(default_factory=list)
    dosing_quality: list[float] = field(default_factory=list)
    process_fidelity: list[float] = field(default_factory=list)
    # Event counters: incremented once per `add` call whose flag is True.
    exploit_detection_count: int = 0
    timeout_count: int = 0
    failure_visible_count: int = 0
    # Despite the name, this is a per-sample series of invalid-action counts.
    invalid_action_count: list[float] = field(default_factory=list)
    # Running per-key sums over PRIMARY_REWARD_KEYS, and the number of `add`
    # calls that contributed to them (only calls with a truthy
    # `primary_channels`).
    primary_reward_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in PRIMARY_REWARD_KEYS})
    primary_reward_counts: int = 0
    # Running per-key sums over REQUIRED_REWARD_KEYS, and the number of `add`
    # calls that contributed to them (only calls with a truthy
    # `reward_components`).
    reward_component_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in REQUIRED_REWARD_KEYS})
    reward_component_counts: int = 0

    def add(
        self,
        reward: float,
        legal: bool,
        severe_violation: bool,
        abstain: bool,
        episode_len: int,
        reward_components: dict[str, float] | None = None,
        success: bool = False,
        burden_delta: float = 0.0,
        safety_delta: float = 0.0,
        dosing_quality: float = 0.0,
        process_fidelity: float = 0.0,
        exploit_detected: bool = False,
        timeout: bool = False,
        failure_visible: bool = False,
        invalid_actions: int = 0,
        primary_channels: dict[str, float] | None = None,
    ) -> None:
        """Record one sample (presumably one episode, judging by ``episode_len``).

        ``legal``, ``severe_violation``, ``abstain`` and ``success`` are stored
        as 0.0/1.0 indicators.  ``exploit_detected``, ``timeout`` and
        ``failure_visible`` each increment their counter when True.  The
        remaining scalars are appended to their series as given; see the
        ``invalid_actions`` note below for the one conversion.

        Args:
            reward_components: Optional per-component rewards.  Only keys in
                ``REQUIRED_REWARD_KEYS`` are accumulated; a missing key counts
                as 0.0 and extra keys are ignored.  ``None`` or an empty
                mapping is skipped entirely and does not advance
                ``reward_component_counts``.
            invalid_actions: Stored as a float in ``invalid_action_count``.
            primary_channels: Same handling as ``reward_components``, but
                keyed by ``PRIMARY_REWARD_KEYS`` and counted in
                ``primary_reward_counts``.
        """
        self.rewards.append(reward)
        self.legality_rate.append(1.0 if legal else 0.0)
        self.severe_violation_rate.append(1.0 if severe_violation else 0.0)
        self.abstention_rate.append(1.0 if abstain else 0.0)
        self.episode_lengths.append(episode_len)
        self.success_rate.append(1.0 if success else 0.0)
        self.burden_delta.append(burden_delta)
        self.safety_delta.append(safety_delta)
        self.dosing_quality.append(dosing_quality)
        self.process_fidelity.append(process_fidelity)
        if exploit_detected:
            self.exploit_detection_count += 1
        if timeout:
            self.timeout_count += 1
        if failure_visible:
            self.failure_visible_count += 1
        self.invalid_action_count.append(float(invalid_actions))
        # A falsy mapping (None or {}) is skipped, so it does not contribute
        # to the denominator of the per-key averages reported by `summary`.
        if reward_components:
            for key in REQUIRED_REWARD_KEYS:
                self.reward_component_totals[key] += float(reward_components.get(key, 0.0))
            self.reward_component_counts += 1
        if primary_channels:
            for key in PRIMARY_REWARD_KEYS:
                self.primary_reward_totals[key] += float(primary_channels.get(key, 0.0))
            self.primary_reward_counts += 1

    def summary(self) -> dict[str, float | dict[str, float]]:
        """Return averaged metrics over everything recorded so far.

        Every mean over an empty series is reported as ``0.0``.  Event
        counters (timeouts, visible failures) are converted into rates over
        the total number of samples.  ``exploit_detection_count`` is reported
        as a raw count, not a rate.

        ``reward_columns`` and ``primary_reward_channels`` are nested dicts.
        Their means are taken only over the ``add`` calls that supplied the
        corresponding mapping, so their denominator can be smaller than the
        number of samples.
        """

        def avg(values: list[float]) -> float:
            return sum(values) / len(values) if values else 0.0

        sample_count = len(self.rewards)

        def rate(count: int) -> float:
            # Convert an event counter into a fraction of all recorded samples.
            return count / sample_count if sample_count else 0.0

        return {
            "avg_reward": avg(self.rewards),
            "legality_rate": avg(self.legality_rate),
            "severe_violation_rate": avg(self.severe_violation_rate),
            "abstention_rate": avg(self.abstention_rate),
            "avg_episode_length": avg([float(x) for x in self.episode_lengths]),
            "success_rate": avg(self.success_rate),
            "avg_burden_delta": avg(self.burden_delta),
            "avg_safety_delta": avg(self.safety_delta),
            "avg_dosing_quality": avg(self.dosing_quality),
            "avg_process_fidelity": avg(self.process_fidelity),
            "exploit_detection_count": float(self.exploit_detection_count),
            "timeout_rate": rate(self.timeout_count),
            "failure_visible_rate": rate(self.failure_visible_count),
            "avg_invalid_actions": avg(self.invalid_action_count),
            "reward_columns": {
                key: (
                    self.reward_component_totals[key] / self.reward_component_counts
                    if self.reward_component_counts
                    else 0.0
                )
                for key in REQUIRED_REWARD_KEYS
            },
            "primary_reward_channels": {
                key: (
                    self.primary_reward_totals[key] / self.primary_reward_counts if self.primary_reward_counts else 0.0
                )
                for key in PRIMARY_REWARD_KEYS
            },
        }