File size: 2,495 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Library Drift Engine.

Manages library version snapshots and triggers version upgrades during
training to create non-stationary verification. In simulation mode it
just tracks the current snapshot index — that index influences
breakage selection and is exposed in observations so the Repair Agent
can adapt.

Also exposes Chojecki GVU's SNR computation
(https://arxiv.org/abs/2512.02731 Definition 4.4).
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field

# Ordered library-version snapshots, oldest first. Each entry maps a package
# name to the version string pinned at that stage of training.
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
    dict(transformers="4.36.0", datasets="2.14.0", trl="0.7.0"),
    dict(transformers="4.40.0", datasets="2.18.0", trl="0.8.0"),
    dict(transformers="4.45.0", datasets="3.0.0", trl="0.10.0"),
    dict(transformers="4.50.0", datasets="3.2.0", trl="0.12.0"),
]


@dataclass
class LibraryDriftEngine:
    """Track which library-version snapshot is active and advance it over time.

    The engine holds an ordered list of version snapshots (mappings from
    package name to version string) and an index into that list.
    ``maybe_drift`` advances the index on a fixed episode cadence and records
    each transition in ``drift_history``. ``compute_snr`` is a stateless
    helper computing a signal-to-noise ratio for two reward series.
    """

    # Ordered version snapshots. Each instance gets its own copies of the
    # default dicts, so mutating one engine's snapshots cannot leak into the
    # module-level DEFAULT_VERSION_SNAPSHOTS or into other engines.
    snapshots: list[dict[str, str]] = field(
        default_factory=lambda: [dict(s) for s in DEFAULT_VERSION_SNAPSHOTS]
    )
    # Index of the active snapshot in ``snapshots``.
    current_index: int = 0
    # One entry per drift: {"episode": int, "from": dict, "to": dict}.
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        """Return a copy of the active snapshot's package -> version mapping.

        Raises:
            IndexError: if ``current_index`` does not point at a snapshot
                (e.g. ``snapshots`` is empty).
        """
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        """Advance to the next snapshot if ``episode_num`` is a drift episode.

        A drift happens when ``episode_num`` is a positive multiple of
        ``drift_every`` and a later snapshot exists to move to. Once the last
        snapshot is reached, no further drifts occur.

        Args:
            episode_num: current episode counter; episode 0 never drifts.
            drift_every: drift cadence in episodes; must be positive.

        Returns:
            True if the snapshot index was advanced, False otherwise.

        Raises:
            ValueError: if ``drift_every`` is not positive (previously this
                surfaced as an opaque ZeroDivisionError for 0).
        """
        if drift_every <= 0:
            raise ValueError(f"drift_every must be positive, got {drift_every}")
        is_drift_episode = episode_num > 0 and episode_num % drift_every == 0
        has_next = self.current_index < len(self.snapshots) - 1
        if not (is_drift_episode and has_next):
            return False
        prev = self.snapshots[self.current_index]
        self.current_index += 1
        # Store copies so later edits to the history cannot corrupt the
        # snapshot list (and vice versa), matching current_versions().
        self.drift_history.append(
            {
                "episode": episode_num,
                "from": dict(prev),
                "to": dict(self.snapshots[self.current_index]),
            }
        )
        return True

    def reset(self) -> None:
        """Return to the first snapshot and forget all recorded drifts."""
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).

        Uses the population variance (divides by the sample count), clamped
        below at 1e-8, so a constant series yields a large finite value
        instead of dividing by zero. Series with fewer than two values
        score 0.0.

        Returns:
            ``{"snr_verifier": SNR of recent_held_out,
            "snr_generator": SNR of recent_visible}``
        """

        def snr(values: list[float]) -> float:
            count = len(values)
            if count < 2:
                return 0.0
            # math.fsum avoids accumulated rounding error on long series.
            mean = math.fsum(values) / count
            var = math.fsum((v - mean) ** 2 for v in values) / count
            return mean**2 / max(var, 1e-8)

        return {
            "snr_verifier": snr(recent_held_out),
            "snr_generator": snr(recent_visible),
        }