"""Training metrics aggregation.""" from __future__ import annotations from dataclasses import dataclass, field from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS @dataclass class TrainingMetrics: rewards: list[float] = field(default_factory=list) legality_rate: list[float] = field(default_factory=list) severe_violation_rate: list[float] = field(default_factory=list) abstention_rate: list[float] = field(default_factory=list) episode_lengths: list[int] = field(default_factory=list) success_rate: list[float] = field(default_factory=list) burden_delta: list[float] = field(default_factory=list) safety_delta: list[float] = field(default_factory=list) dosing_quality: list[float] = field(default_factory=list) process_fidelity: list[float] = field(default_factory=list) exploit_detection_count: int = 0 timeout_count: int = 0 failure_visible_count: int = 0 invalid_action_count: list[float] = field(default_factory=list) primary_reward_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in PRIMARY_REWARD_KEYS}) primary_reward_counts: int = 0 reward_component_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in REQUIRED_REWARD_KEYS}) reward_component_counts: int = 0 def add( self, reward: float, legal: bool, severe_violation: bool, abstain: bool, episode_len: int, reward_components: dict[str, float] | None = None, success: bool = False, burden_delta: float = 0.0, safety_delta: float = 0.0, dosing_quality: float = 0.0, process_fidelity: float = 0.0, exploit_detected: bool = False, timeout: bool = False, failure_visible: bool = False, invalid_actions: int = 0, primary_channels: dict[str, float] | None = None, ) -> None: self.rewards.append(reward) self.legality_rate.append(1.0 if legal else 0.0) self.severe_violation_rate.append(1.0 if severe_violation else 0.0) self.abstention_rate.append(1.0 if abstain else 0.0) self.episode_lengths.append(episode_len) self.success_rate.append(1.0 if success else 0.0) self.burden_delta.append(burden_delta) self.safety_delta.append(safety_delta) self.dosing_quality.append(dosing_quality) self.process_fidelity.append(process_fidelity) if exploit_detected: self.exploit_detection_count += 1 if timeout: self.timeout_count += 1 if failure_visible: self.failure_visible_count += 1 self.invalid_action_count.append(float(invalid_actions)) if reward_components: for key in REQUIRED_REWARD_KEYS: self.reward_component_totals[key] += float(reward_components.get(key, 0.0)) self.reward_component_counts += 1 if primary_channels: for key in PRIMARY_REWARD_KEYS: self.primary_reward_totals[key] += float(primary_channels.get(key, 0.0)) self.primary_reward_counts += 1 def summary(self) -> dict[str, float]: def avg(values: list[float]) -> float: return sum(values) / len(values) if values else 0.0 return { "avg_reward": avg(self.rewards), "legality_rate": avg(self.legality_rate), "severe_violation_rate": avg(self.severe_violation_rate), "abstention_rate": avg(self.abstention_rate), "avg_episode_length": avg([float(x) for x in self.episode_lengths]), "success_rate": avg(self.success_rate), "avg_burden_delta": avg(self.burden_delta), "avg_safety_delta": avg(self.safety_delta), "avg_dosing_quality": avg(self.dosing_quality), "avg_process_fidelity": avg(self.process_fidelity), "exploit_detection_count": float(self.exploit_detection_count), "timeout_rate": (float(self.timeout_count) / len(self.rewards)) if self.rewards else 0.0, "failure_visible_rate": (float(self.failure_visible_count) / len(self.rewards)) if self.rewards else 0.0, "avg_invalid_actions": avg(self.invalid_action_count), "reward_columns": { key: ( self.reward_component_totals[key] / self.reward_component_counts if self.reward_component_counts else 0.0 ) for key in REQUIRED_REWARD_KEYS }, "primary_reward_channels": { key: ( self.primary_reward_totals[key] / self.primary_reward_counts if self.primary_reward_counts else 0.0 ) for key in PRIMARY_REWARD_KEYS }, }