| """Library Drift Engine.
|
|
|
| Manages library version snapshots and triggers version upgrades during
|
| training to create non-stationary verification. In simulation mode it
|
| just tracks the current snapshot index — that index influences
|
| breakage selection and is exposed in observations so the Repair Agent
|
| can adapt.
|
|
|
| Also exposes Chojecki GVU's SNR computation
|
| (https://arxiv.org/abs/2512.02731 Definition 4.4).
|
| """
|
| from __future__ import annotations
|
|
|
| import math
|
| from dataclasses import dataclass, field
|
|
|
# Ordered library-version snapshots, oldest first. Each entry maps a library
# name to the version string that is active while that snapshot is current.
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
    {
        "transformers": "4.36.0",
        "datasets": "2.14.0",
        "trl": "0.7.0",
    },
    {
        "transformers": "4.40.0",
        "datasets": "2.18.0",
        "trl": "0.8.0",
    },
    {
        "transformers": "4.45.0",
        "datasets": "3.0.0",
        "trl": "0.10.0",
    },
    {
        "transformers": "4.50.0",
        "datasets": "3.2.0",
        "trl": "0.12.0",
    },
]
|
|
|
|
|
@dataclass
class LibraryDriftEngine:
    """Track which library-version snapshot is active and advance it over time.

    The engine holds an ordered list of version snapshots (library name ->
    version string) and an index into that list. ``maybe_drift`` moves the
    index forward on a fixed episode cadence and records every move in
    ``drift_history``; ``reset`` returns the engine to the first snapshot.
    """

    # Ordered snapshots. The default copies every inner dict so that editing
    # one engine's snapshots cannot alter the module-level defaults (and,
    # through them, every other engine built from those defaults).
    snapshots: list[dict[str, str]] = field(
        default_factory=lambda: [dict(s) for s in DEFAULT_VERSION_SNAPSHOTS]
    )
    # Position of the active snapshot within ``snapshots``.
    current_index: int = 0
    # One record per drift: {"episode": int, "from": dict, "to": dict}.
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        """Return a copy of the active snapshot.

        Raises:
            IndexError: If ``current_index`` does not address an entry of
                ``snapshots`` (for example when ``snapshots`` is empty).
        """
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        """Advance to the next snapshot when ``episode_num`` hits the cadence.

        A drift happens only when ``episode_num`` is a positive multiple of
        ``drift_every`` and a later snapshot exists. Each drift appends a
        record to ``drift_history`` holding the episode number and copies of
        the snapshots drifted from and to.

        Args:
            episode_num: Episode counter supplied by the caller.
            drift_every: Number of episodes between drifts; must be positive.

        Returns:
            True if the engine moved to the next snapshot, False otherwise.

        Raises:
            ValueError: If ``drift_every`` is not positive.
        """
        # Fail with a clear message instead of a ZeroDivisionError from the
        # modulo below; a zero or negative cadence has no meaning.
        if drift_every <= 0:
            raise ValueError(f"drift_every must be positive, got {drift_every}")
        if episode_num <= 0 or episode_num % drift_every != 0:
            return False
        if self.current_index >= len(self.snapshots) - 1:
            # Already on the last snapshot; nothing left to drift to.
            return False

        previous = self.snapshots[self.current_index]
        self.current_index += 1
        # Store copies so that edits to a history record can never leak back
        # into the live snapshots (mirrors the copy in current_versions).
        self.drift_history.append(
            {
                "episode": episode_num,
                "from": dict(previous),
                "to": dict(self.snapshots[self.current_index]),
            }
        )
        return True

    def reset(self) -> None:
        """Return to the first snapshot and clear the drift history in place."""
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """Compute the signal-to-noise ratio of two reward series.

        SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).
        The variance is the population variance (divides by the number of
        values) and is floored at 1e-8 so a constant series does not divide
        by zero. A series with fewer than two values scores 0.0.

        Args:
            recent_held_out: Rewards reported under ``"snr_verifier"``.
            recent_visible: Rewards reported under ``"snr_generator"``.

        Returns:
            A dict with keys ``"snr_verifier"`` and ``"snr_generator"``.
        """

        def snr(values: list[float]) -> float:
            if len(values) < 2:
                # Variance of fewer than two samples carries no information.
                return 0.0
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            return mean**2 / max(var, 1e-8)

        return {
            "snr_verifier": snr(recent_held_out),
            "snr_generator": snr(recent_visible),
        }
|
|
|