TheJackBright's picture
Deploy PolyGuard OpenEnv Space
877add7 verified
"""Training metrics aggregation."""
from __future__ import annotations
from dataclasses import dataclass, field
from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS
@dataclass
class TrainingMetrics:
rewards: list[float] = field(default_factory=list)
legality_rate: list[float] = field(default_factory=list)
severe_violation_rate: list[float] = field(default_factory=list)
abstention_rate: list[float] = field(default_factory=list)
episode_lengths: list[int] = field(default_factory=list)
success_rate: list[float] = field(default_factory=list)
burden_delta: list[float] = field(default_factory=list)
safety_delta: list[float] = field(default_factory=list)
dosing_quality: list[float] = field(default_factory=list)
process_fidelity: list[float] = field(default_factory=list)
exploit_detection_count: int = 0
timeout_count: int = 0
failure_visible_count: int = 0
invalid_action_count: list[float] = field(default_factory=list)
primary_reward_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in PRIMARY_REWARD_KEYS})
primary_reward_counts: int = 0
reward_component_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in REQUIRED_REWARD_KEYS})
reward_component_counts: int = 0
def add(
self,
reward: float,
legal: bool,
severe_violation: bool,
abstain: bool,
episode_len: int,
reward_components: dict[str, float] | None = None,
success: bool = False,
burden_delta: float = 0.0,
safety_delta: float = 0.0,
dosing_quality: float = 0.0,
process_fidelity: float = 0.0,
exploit_detected: bool = False,
timeout: bool = False,
failure_visible: bool = False,
invalid_actions: int = 0,
primary_channels: dict[str, float] | None = None,
) -> None:
self.rewards.append(reward)
self.legality_rate.append(1.0 if legal else 0.0)
self.severe_violation_rate.append(1.0 if severe_violation else 0.0)
self.abstention_rate.append(1.0 if abstain else 0.0)
self.episode_lengths.append(episode_len)
self.success_rate.append(1.0 if success else 0.0)
self.burden_delta.append(burden_delta)
self.safety_delta.append(safety_delta)
self.dosing_quality.append(dosing_quality)
self.process_fidelity.append(process_fidelity)
if exploit_detected:
self.exploit_detection_count += 1
if timeout:
self.timeout_count += 1
if failure_visible:
self.failure_visible_count += 1
self.invalid_action_count.append(float(invalid_actions))
if reward_components:
for key in REQUIRED_REWARD_KEYS:
self.reward_component_totals[key] += float(reward_components.get(key, 0.0))
self.reward_component_counts += 1
if primary_channels:
for key in PRIMARY_REWARD_KEYS:
self.primary_reward_totals[key] += float(primary_channels.get(key, 0.0))
self.primary_reward_counts += 1
def summary(self) -> dict[str, float]:
def avg(values: list[float]) -> float:
return sum(values) / len(values) if values else 0.0
return {
"avg_reward": avg(self.rewards),
"legality_rate": avg(self.legality_rate),
"severe_violation_rate": avg(self.severe_violation_rate),
"abstention_rate": avg(self.abstention_rate),
"avg_episode_length": avg([float(x) for x in self.episode_lengths]),
"success_rate": avg(self.success_rate),
"avg_burden_delta": avg(self.burden_delta),
"avg_safety_delta": avg(self.safety_delta),
"avg_dosing_quality": avg(self.dosing_quality),
"avg_process_fidelity": avg(self.process_fidelity),
"exploit_detection_count": float(self.exploit_detection_count),
"timeout_rate": (float(self.timeout_count) / len(self.rewards)) if self.rewards else 0.0,
"failure_visible_rate": (float(self.failure_visible_count) / len(self.rewards)) if self.rewards else 0.0,
"avg_invalid_actions": avg(self.invalid_action_count),
"reward_columns": {
key: (
self.reward_component_totals[key] / self.reward_component_counts
if self.reward_component_counts
else 0.0
)
for key in REQUIRED_REWARD_KEYS
},
"primary_reward_channels": {
key: (
self.primary_reward_totals[key] / self.primary_reward_counts if self.primary_reward_counts else 0.0
)
for key in PRIMARY_REWARD_KEYS
},
}