Spaces:
Running
Running
| """Training metrics aggregation.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS | |
@dataclass
class TrainingMetrics:
    """Accumulates per-episode training metrics and produces aggregate summaries.

    Call :meth:`add` once per finished episode, then :meth:`summary` to get
    averaged/rate statistics. The ``@dataclass`` decorator is required here:
    every attribute default uses ``field(default_factory=...)``, which only
    works on a dataclass (otherwise the class attributes would be raw
    ``Field`` objects and ``add()`` would fail).
    """

    # Per-episode series — one entry appended per add() call.
    rewards: list[float] = field(default_factory=list)
    legality_rate: list[float] = field(default_factory=list)
    severe_violation_rate: list[float] = field(default_factory=list)
    abstention_rate: list[float] = field(default_factory=list)
    episode_lengths: list[int] = field(default_factory=list)
    success_rate: list[float] = field(default_factory=list)
    burden_delta: list[float] = field(default_factory=list)
    safety_delta: list[float] = field(default_factory=list)
    dosing_quality: list[float] = field(default_factory=list)
    process_fidelity: list[float] = field(default_factory=list)
    # Event counters — incremented when the corresponding flag is set.
    exploit_detection_count: int = 0
    timeout_count: int = 0
    failure_visible_count: int = 0
    invalid_action_count: list[float] = field(default_factory=list)
    # Running totals keyed by the project-wide reward key constants.
    # NOTE(review): assumes PRIMARY_REWARD_KEYS / REQUIRED_REWARD_KEYS are
    # iterables of string keys — confirm against app.common.constants.
    primary_reward_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in PRIMARY_REWARD_KEYS})
    primary_reward_counts: int = 0
    reward_component_totals: dict[str, float] = field(default_factory=lambda: {k: 0.0 for k in REQUIRED_REWARD_KEYS})
    reward_component_counts: int = 0

    def add(
        self,
        reward: float,
        legal: bool,
        severe_violation: bool,
        abstain: bool,
        episode_len: int,
        reward_components: dict[str, float] | None = None,
        success: bool = False,
        burden_delta: float = 0.0,
        safety_delta: float = 0.0,
        dosing_quality: float = 0.0,
        process_fidelity: float = 0.0,
        exploit_detected: bool = False,
        timeout: bool = False,
        failure_visible: bool = False,
        invalid_actions: int = 0,
        primary_channels: dict[str, float] | None = None,
    ) -> None:
        """Record the outcome of one episode.

        Boolean flags are stored as 0.0/1.0 samples so rates can later be
        computed as plain averages. ``reward_components`` and
        ``primary_channels`` are optional dicts; missing keys contribute 0.0
        to their running totals.
        """
        self.rewards.append(reward)
        self.legality_rate.append(1.0 if legal else 0.0)
        self.severe_violation_rate.append(1.0 if severe_violation else 0.0)
        self.abstention_rate.append(1.0 if abstain else 0.0)
        self.episode_lengths.append(episode_len)
        self.success_rate.append(1.0 if success else 0.0)
        self.burden_delta.append(burden_delta)
        self.safety_delta.append(safety_delta)
        self.dosing_quality.append(dosing_quality)
        self.process_fidelity.append(process_fidelity)
        if exploit_detected:
            self.exploit_detection_count += 1
        if timeout:
            self.timeout_count += 1
        if failure_visible:
            self.failure_visible_count += 1
        self.invalid_action_count.append(float(invalid_actions))
        # Only a non-empty dict counts toward the component averages; an
        # empty dict is treated the same as None (truthiness check).
        if reward_components:
            for key in REQUIRED_REWARD_KEYS:
                self.reward_component_totals[key] += float(reward_components.get(key, 0.0))
            self.reward_component_counts += 1
        if primary_channels:
            for key in PRIMARY_REWARD_KEYS:
                self.primary_reward_totals[key] += float(primary_channels.get(key, 0.0))
            self.primary_reward_counts += 1

    def summary(self) -> dict[str, float]:
        """Return aggregate statistics over everything recorded so far.

        All averages/rates are 0.0 when no episodes have been added.
        Despite the declared return type, the ``reward_columns`` and
        ``primary_reward_channels`` entries are nested dicts of per-key means
        (averaged over episodes that supplied the corresponding dict).
        """

        def avg(values: list[float]) -> float:
            # Mean with an explicit empty-list guard to avoid ZeroDivisionError.
            return sum(values) / len(values) if values else 0.0

        return {
            "avg_reward": avg(self.rewards),
            "legality_rate": avg(self.legality_rate),
            "severe_violation_rate": avg(self.severe_violation_rate),
            "abstention_rate": avg(self.abstention_rate),
            "avg_episode_length": avg([float(x) for x in self.episode_lengths]),
            "success_rate": avg(self.success_rate),
            "avg_burden_delta": avg(self.burden_delta),
            "avg_safety_delta": avg(self.safety_delta),
            "avg_dosing_quality": avg(self.dosing_quality),
            "avg_process_fidelity": avg(self.process_fidelity),
            "exploit_detection_count": float(self.exploit_detection_count),
            # Rates are relative to total episodes (len(self.rewards)).
            "timeout_rate": (float(self.timeout_count) / len(self.rewards)) if self.rewards else 0.0,
            "failure_visible_rate": (float(self.failure_visible_count) / len(self.rewards)) if self.rewards else 0.0,
            "avg_invalid_actions": avg(self.invalid_action_count),
            "reward_columns": {
                key: (
                    self.reward_component_totals[key] / self.reward_component_counts
                    if self.reward_component_counts
                    else 0.0
                )
                for key in REQUIRED_REWARD_KEYS
            },
            "primary_reward_channels": {
                key: (
                    self.primary_reward_totals[key] / self.primary_reward_counts if self.primary_reward_counts else 0.0
                )
                for key in PRIMARY_REWARD_KEYS
            },
        }