File size: 6,728 Bytes
11c11f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fd9d22
 
 
 
 
 
6a7521a
 
 
 
5fd9d22
 
 
 
 
 
6a7521a
 
 
5fd9d22
 
 
 
 
 
 
 
 
 
11c11f8
5fd9d22
11c11f8
 
 
6a7521a
 
 
 
 
 
11c11f8
6a7521a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import annotations

import torch
import torch.nn as nn


class GrowLengthScheduler:
    """Curriculum scheduler that grows the training sequence length in stages.

    ``stages`` is an iterable of ``(seq_len, fraction)`` pairs; each stage
    occupies ``fraction`` (normalized over all stages) of ``total_steps``
    and trains at sequence length ``seq_len``.
    """

    def __init__(self, stages, total_steps):
        """Precompute cumulative step boundaries for each stage.

        Args:
            stages: iterable of ``(seq_len, fraction)`` pairs.
            total_steps: total number of training steps to divide up.

        Raises:
            ValueError: if ``stages`` is empty.  (FIX: previously an empty
                schedule was accepted here and only crashed later with an
                opaque ``IndexError`` inside ``get_seq_len``.)
        """
        if not stages:
            raise ValueError("stages must contain at least one (seq_len, fraction) pair")
        # Normalize fractions so they need not sum to 1; guard against 0.
        total_frac = sum(frac for _, frac in stages) or 1.0
        cumulative = 0
        self._boundaries = []
        for seq_len, frac in stages:
            cumulative += int(total_steps * frac / total_frac)
            self._boundaries.append((cumulative, int(seq_len)))

    def get_seq_len(self, step: int) -> int:
        """Return the sequence length to use at training ``step``."""
        for boundary, seq_len in self._boundaries:
            if step < boundary:
                return seq_len
        # Past the last boundary (integer truncation of the per-stage step
        # counts can leave a small gap before total_steps): use final stage.
        return self._boundaries[-1][1]


def apply_reservoir_freezing(model) -> int:
    """Re-initialize selected projection weights as fixed random reservoirs.

    Matching projections are overwritten with a ternary {-1, 0, +1} matrix,
    rescaled so the spectral norm is at most 1, and excluded from training.
    Returns the total number of weight elements frozen.
    """
    total_frozen = 0
    for _, mod in model.named_modules():
        candidates = []
        if hasattr(mod, "a_proj") and hasattr(mod, "b_proj"):
            candidates += ["a_proj", "b_proj"]
        # NOTE(review): only fgate (not igate) and only alpha_proj (not
        # eta_proj) get frozen — the paired hasattr checks just identify the
        # module type. Presumably intentional; confirm with the author.
        if hasattr(mod, "fgate") and hasattr(mod, "igate"):
            candidates.append("fgate")
        if hasattr(mod, "alpha_proj") and hasattr(mod, "eta_proj"):
            candidates.append("alpha_proj")
        for name in candidates:
            layer = getattr(mod, name, None)
            w = getattr(layer, "weight", None) if layer is not None else None
            if not isinstance(w, nn.Parameter):
                continue
            with torch.no_grad():
                # Ternary init, then normalize by the (clamped) spectral
                # norm for echo-state-style stability.
                w.data = torch.randint(-1, 2, w.shape, dtype=w.dtype, device=w.device)
                spec = torch.linalg.matrix_norm(w.data.float(), ord=2).clamp(min=1.0)
                w.data.div_(spec)
            w.requires_grad = False
            total_frozen += w.numel()
    return total_frozen


class SeedReplayMeZO:
    """MeZO-style zeroth-order optimizer that replays perturbations from a seed.

    Gradients are estimated with a two-point finite difference along a random
    Rademacher (+/-1) direction ``z``.  Because ``z`` is fully determined by an
    integer seed, it is never stored: the same seed regenerates it for the
    restore-and-update pass (the MeZO "seed replay" trick), so memory cost is
    O(1) beyond the parameters (plus momentum buffers when enabled).

    Args:
        model: module whose trainable parameters are optimized.
        lr: learning rate.
        eps: finite-difference perturbation scale.
        weight_decay: decoupled weight-decay factor (0 disables it).
        momentum: momentum factor for the projected-gradient buffer
            (0 disables momentum).
    """

    def __init__(self, model, *, lr=1e-4, eps=1e-3, weight_decay=0.0, momentum=0.9):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.mom = float(momentum)
        # Collect trainable params, deduplicating tied/shared ones by identity.
        self._params = []
        seen = set()
        for _, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._params.append(param)
                seen.add(id(param))
        self._momentum = (
            [torch.zeros_like(p.data) for p in self._params] if self.mom > 0 else None
        )

    def _rademacher(self, gen, seed: int, index: int, param) -> torch.Tensor:
        """Regenerate the +/-1 direction for parameter ``index`` from ``seed``.

        FIX: the direction is sampled on CPU and then moved to the parameter's
        device — ``bernoulli_`` with a CPU generator raises on a CUDA tensor,
        so the original ``empty_like``-on-device version broke on GPU.
        """
        gen.manual_seed((seed + index * 999983) & 0x7FFFFFFFFFFFFFFF)
        z = torch.empty(param.shape, dtype=torch.float32, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(device=param.device, dtype=param.dtype)

    def _perturb_inplace(self, seed: int, scale: float) -> None:
        """Add ``scale * z_i`` to every trainable parameter, in place."""
        gen = torch.Generator(device="cpu")
        for i, param in enumerate(self._params):
            param.data.add_(self._rademacher(gen, seed, i, param), alpha=scale)

    def _update_inplace(self, seed: int, projected_grad: float) -> None:
        """Undo the ``-eps`` probe and apply the SGD(+momentum) update."""
        gen = torch.Generator(device="cpu")
        for i, param in enumerate(self._params):
            z = self._rademacher(gen, seed, i, param)
            # Parameters currently sit at (w - eps*z); this restores w.
            param.data.add_(z, alpha=self.eps)
            if self._momentum is not None:
                buf = self._momentum[i]
                buf.mul_(self.mom).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled (AdamW-style) weight decay.
                param.data.mul_(1 - self.lr * self.wd)

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one zeroth-order step; returns the mean of the two probe losses.

        Args:
            loss_fn: callable mapping ``batch`` to a scalar loss tensor.
            batch: opaque batch object passed through to ``loss_fn``.
        """
        seed = int(torch.randint(0, 2**31, (1,)).item())
        self._perturb_inplace(seed, +self.eps)
        loss_pos = float(loss_fn(batch).item())
        # Jump from w + eps*z to w - eps*z in a single in-place pass.
        self._perturb_inplace(seed, -2.0 * self.eps)
        loss_neg = float(loss_fn(batch).item())
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update_inplace(seed, projected_grad)
        return 0.5 * (loss_pos + loss_neg)


class ProgressiveUnfreezer:
    """Unfreeze ``model.layers`` from the top down as training progresses.

    Training is split into ``n_stages`` equal windows; each new stage makes
    one more block of ``len(model.layers) // n_stages`` layers trainable,
    starting from the last (output-side) layers.  Construction immediately
    applies the stage-0 state, so only the top block starts trainable.

    NOTE(review): when the layer count is not divisible by ``n_stages``, the
    bottom ``n % n_stages`` layers never unfreeze — confirm that is intended.
    """

    def __init__(self, model, total_steps, n_stages=4):
        self._layers = model.layers
        self._n = len(self._layers)
        self._total = total_steps
        self._stages = n_stages
        self._block = max(1, self._n // n_stages)
        # Index of the first trainable layer; start "all frozen" so the
        # initial update(0) call below applies the stage-0 configuration.
        self._current = self._n
        self.update(0)

    def update(self, step: int) -> int:
        """Sync requires_grad flags for ``step``; returns the first trainable index."""
        stage = min(step * self._stages // max(1, self._total), self._stages - 1)
        target = self._n - (stage + 1) * self._block
        if target < 0:
            target = 0
        if target == self._current:
            return self._current
        self._current = target
        for idx, layer in enumerate(self._layers):
            trainable = idx >= target
            for p in layer.parameters():
                p.requires_grad = trainable
        return self._current


class ProgressiveLoopScheduler:
    """Ramp up Parcae loop depth over the course of training.

    With STE+AdamW (rather than MeZO), multi-loop training is affordable,
    but deep loops early in training are unstable.  The schedule therefore
    holds loops=1 for the first 50% of steps (stabilize the weights with a
    single pass), loops=2 for the next 30% (learn to iterate), and loops=3
    for the final 20% (deep refinement).

    FIX: the earlier 1→2→3 schedule at 20%/60%/100% switched too early —
    with 5000 steps, loops=2 kicked in at step 1000 while loss was still ~10.
    """

    def __init__(self, total_steps: int, max_loops: int = 3):
        self._total = total_steps
        self._max_loops = max_loops
        # (progress cutoff, loop count); final cutoff > 1.0 so the last
        # entry covers step == total_steps exactly.
        self._schedule = [
            (0.50, 1),
            (0.80, 2),
            (1.01, min(3, max_loops)),
        ]

    def get_loops(self, step: int) -> int:
        """Return the loop depth for training ``step``."""
        progress = step / max(1, self._total)
        return next(
            (depth for cutoff, depth in self._schedule if progress < cutoff),
            self._schedule[-1][1],
        )


def patch_training_loops(model, num_loops=1) -> None:
    """Apply the initial loop configuration to ``model``.

    Sets the loop controller's default/min/max loop counts (when present)
    and raises the evolution-modulation interval (when present).  Use
    ``ProgressiveLoopScheduler`` to vary the depth during training.
    """
    if hasattr(model, "loop_controller"):
        controller = model.loop_controller
        controller.loop_default = num_loops
        controller.loop_min = 1
        controller.loop_max = max(num_loops, 3)
    # FIX: evolution modulation is very expensive on CPU (HDC projections,
    # Hamming-distance queries over 50K entries, episodic retrieval).  With
    # evo_every_n_layers=4 and 28 layers that is 7 calls per forward; raising
    # the interval to at least 28 makes evolution fire once per full pass
    # (at layer 0 only), which still lets memory modulate the input embedding.
    if hasattr(model, "evo_every_n_layers"):
        model.evo_every_n_layers = max(model.evo_every_n_layers, 28)