"""
tscale_mini — TernaryScaleTensor with BigInt-style correlation tracking.

Key idea from the main ternary_scale.py:
  score = Σ (grad_sign × T) per group — correlation of the gradient with the weight direction
    - score > 0: grad aligns with T (direction correct, need LESS magnitude)
    - score < 0: grad opposes T (direction wrong, need MORE magnitude)

Instead of E_accum (int8 threshold flips), use wide-integer ("BigInt") accumulators (int64):
  corr_accum -= score   — never clips, never resets; the score is stored negated so that
                          a positive accumulator means "grow the magnitude"
  step += 1
  mean_corr = corr_accum / (step × gs)   — a rational in [-1, +1]

  S = 2^(E + 4 × mean_corr)   — continuous S from the integer E base plus a
                                 correlation-derived adjustment of at most ±4 octaves

The integer ratio corr_accum / (step × gs) carries the precision: it is exact as a
rational, and is converted to float32 at forward time, so the effective precision of
mean_corr is float32's 24-bit mantissa.
This gives each group a continuously fine-tuned S instead of only the 256 discrete 2^E values.

PERSISTENT (all int):
  T_packed (uint8)      — 5 trits/byte
  E (int8)              — per-group base log2 scale
  corr_accum (int64)    — per-group BigInt-style correlation accumulator
  step_counter (int64)  — total steps

EPHEMERAL (float32, only during forward/backward):
  w_eff = S × T = 2^(E + 4 × mean_corr) × T
"""
| import math, torch, torch.nn as nn, torch.nn.functional as F |
| from math import ceil |
|
|
|
|
| |
|
|
def pack_ternary(w):
    """Pack the sign pattern of ``w`` into base-3 digits, five trits per byte.

    Each element becomes a trit by its sign: negative -> 0, zero -> 1,
    positive -> 2. Elements for which both ``< 0`` and ``> 0`` are false
    (zero and NaN) map to 1. The flattened trits are padded with 0-digits to
    a multiple of 5, and each group of five is encoded as
    ``t0 + 3*t1 + 9*t2 + 27*t3 + 81*t4``.

    Args:
        w: tensor of any shape; only the sign of each element is kept.

    Returns:
        ``(packed, shape, pad)``. ``packed`` is a 1-D uint8 tensor on the CPU,
        ``shape`` is ``w.shape``, and ``pad`` is the number of padding trits
        appended (0-4).
    """
    # Start every trit at 1 (the zero trit) so entries that fail both
    # comparisons below (NaN) are deterministic rather than uninitialized memory.
    q = torch.ones_like(w, dtype=torch.uint8)
    q[w < 0] = 0
    q[w > 0] = 2
    flat = q.flatten()
    pad = (-flat.numel()) % 5
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    digits = flat.view(-1, 5)
    # uint8 arithmetic cannot overflow: the largest sum is 2 * (1+3+9+27+81) = 242.
    packed = (digits[:, 0] + digits[:, 1] * 3 + digits[:, 2] * 9
              + digits[:, 3] * 27 + digits[:, 4] * 81).to(torch.uint8)
    return packed.cpu(), w.shape, pad
|
|
|
|
def unpack_ternary(packed, shape, pad=0):
    """Inverse of ``pack_ternary``: decode base-3 bytes back into trits in {-1, 0, +1}.

    Args:
        packed: 1-D integer tensor of packed bytes (five trits each).
        shape: target shape of the decoded tensor.
        pad: number of padding trits to drop from the end before reshaping.

    Returns:
        int8 tensor of ``shape`` with values -1, 0 or +1.

    The input tensor is never modified.
    """
    # Non-in-place division: if `packed` is already int16, `.to` returns the same
    # tensor, and an in-place `//=` would corrupt the caller's buffer.
    p = packed.to(torch.int16)
    digits = []
    for _ in range(5):
        digits.append(p % 3)
        p = p // 3
    out = torch.stack(digits, dim=1).flatten()
    if pad:
        out = out[:-pad]
    # Digits 0/1/2 map to trits -1/0/+1.
    return out.view(shape).to(torch.int8) - 1
|
|
|
|
| |
|
|
| def _ternarize(x, threshold=0.05): |
| return x.sign() * (x.abs() > threshold).to(x.dtype) |
|
|
|
|
| def _n_groups(out_dim, in_dim, gs): |
| return out_dim * ceil(in_dim / gs) |
|
|
|
|
| |
|
|
class TernaryScaleTensor(nn.Module):
    """
    Ternary linear layer whose per-group scale S is tuned by an int64 correlation accumulator.

    Forward (per group of ``group_size`` consecutive input columns in a row):
        mean_corr = corr_accum / (step * group_size)   -- in [-1, +1], evaluated in float32
        S         = 2 ** (E + 4 * mean_corr)
        w_eff     = S * T ;  y = x @ w_eff.T (+ bias)

    Persistent buffers:
        T_packed (uint8)         -- ternary weights, 5 trits/byte (pack_ternary format)
        _T_shape, _T_pad (int64) -- metadata needed to unpack T_packed
        E (int8)                 -- per-group base log2 scale
        corr_accum (int64)       -- per-group running sum of -(sign(grad) * T)
        step_counter (int64)     -- number of update_corr() calls that consumed a gradient
        bias (int32, optional)   -- added to the output as float; never updated here

    Training protocol: when autograd is enabled, forward() registers a hook that
    records sign(dL/dw_eff). Call update_corr() after backward to fold it into
    corr_accum. Each hook call overwrites the previous capture, so only the most
    recent backward pass is counted.
    """

    def __init__(self, in_dim, out_dim, threshold=0.05, group_size=32, bias=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.group_size = group_size

        # Random ternary initialisation; the std is keyed on `bias` as originally chosen.
        init_std = 0.1 if not bias else 0.02
        w_init = torch.randn(out_dim, in_dim) * init_std
        T_init = _ternarize(w_init, threshold)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Base exponent: round(log2(0.5 / sqrt(in_dim))), clamped to [-8, 0].
        target_S = 0.5 * (in_dim ** -0.5)
        E_init = max(-8, min(0, int(round(math.log2(max(target_S, 2**-8))))))
        n_grp = _n_groups(out_dim, in_dim, group_size)
        self.register_buffer("E", torch.full((n_grp,), E_init, dtype=torch.int8))

        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        if bias:
            self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
        else:
            self.bias = None

    def _get_T(self):
        """Unpack T_packed into an int8 (out_dim, in_dim) tensor of {-1, 0, +1}."""
        return unpack_ternary(self.T_packed,
                              tuple(self._T_shape.tolist()),
                              int(self._T_pad.item()))

    def forward(self, x):
        """Return ``F.linear(x, w_eff)`` (+ bias) with ``w_eff = 2^(E + 4*mean_corr) * T``.

        When autograd is enabled, ``w_eff`` is made a fresh leaf and a hook stores
        ``sign(dL/dw_eff)`` as int8 in ``self._hook_grad_T_sign`` for update_corr().
        """
        T = self._get_T()
        out_d, in_d = self.out_dim, self.in_dim
        gs = self.group_size
        gpr = ceil(in_d / gs)  # groups per row

        E_adj = self.E.float()
        step = self.step_counter.item()
        if step > 0:
            # Each step adds at most gs (in magnitude) per group, so mean_corr is in [-1, +1].
            mean_corr = self.corr_accum.float() / max(step * gs, 1)
            # Full correlation shifts the scale by up to 4 octaves around 2^E.
            E_adj = E_adj + mean_corr * 4.0
        S_per = torch.exp2(E_adj)

        # Spread each group's scale over its gs columns and trim the ragged tail group.
        S_exp = S_per.view(out_d, gpr).repeat_interleave(gs, dim=1)[:, :in_d]
        w_eff = S_exp * T.float()
        del S_exp, S_per, E_adj, T

        if torch.is_grad_enabled():
            # A grad-requiring leaf lets us observe dL/dw_eff without keeping a
            # graph back to the integer buffers.
            w_eff = w_eff.detach().requires_grad_(True)

            def _capture(grad_w):
                """Store sign(grad_w) as int8 for update_corr()."""
                self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
                # NOTE(review): this keeps a full float32 (out, in) gradient alive until
                # update_corr() deletes it; it is not read in this class. Drop it if no
                # external consumer needs it.
                self._hook_grad_full = grad_w.detach()

            w_eff.register_hook(_capture)

        y = F.linear(x, w_eff)
        if self.bias is not None:
            y = y + self.bias.float()
        return y

    @torch.no_grad()
    def update_corr(self):
        """
        Pure-integer update: fold the captured gradient sign into corr_accum.

        Computes score = sum(sign(grad) * T) per group, subtracts it from
        corr_accum, increments step_counter, and then discards the captured
        tensors. Does nothing if no gradient was captured since the last call.
        """
        if not hasattr(self, '_hook_grad_T_sign'):
            return
        gs = self.group_size
        out_d, in_d = self.out_dim, self.in_dim
        gpr = ceil(in_d / gs)

        grad_sign = self._hook_grad_T_sign
        T = self._get_T().to(device=grad_sign.device)
        # Entries are in {-1, 0, +1}: +1 where the gradient points along T.
        signed = grad_sign.to(torch.int8) * T.to(torch.int8)

        pad = gpr * gs - in_d
        if pad > 0:
            # Zero-pad the ragged last group so each row splits evenly into (gpr, gs).
            signed = F.pad(signed, (0, pad))

        # Sum in int64: an int16 sum overflows once group_size exceeds 32767.
        score = signed.view(out_d, gpr, gs).sum(dim=2, dtype=torch.int64)

        # Stored negated: positive corr_accum means the gradient opposes T, so S grows.
        self.corr_accum -= score.flatten().to(self.corr_accum.device)
        self.step_counter += 1

        del self._hook_grad_T_sign
        if hasattr(self, '_hook_grad_full'):
            del self._hook_grad_full

    def n_groups(self):
        """Number of scale groups (out_dim * ceil(in_dim / group_size))."""
        return _n_groups(self.out_dim, self.in_dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary weights (out_dim * in_dim)."""
        return self.out_dim * self.in_dim

    def persistent_memory_mb(self):
        """Size of all persistent buffers, including bias when present, in MiB."""
        bufs = [self.T_packed, self._T_shape, self._T_pad,
                self.E, self.corr_accum, self.step_counter]
        if self.bias is not None:
            bufs.append(self.bias)
        return sum(b.numel() * b.element_size() for b in bufs) / (1024 * 1024)
|
|
|
|
| |
|
|
class TernaryRMSNorm(nn.Module):
    """
    RMSNorm with a ternary per-element gain: y = rms_norm(x) * (S * T).

    The gain uses the same representation as TernaryScaleTensor (packed trits,
    int8 base exponents E, int64 corr_accum / step_counter), laid out as a single
    row of ceil(dim / group_size) groups. update_corr() is a no-op, so within this
    class corr_accum and step_counter are never advanced and S stays at 2^E
    (E is initialised to -4).
    """

    def __init__(self, dim, group_size=32):
        super().__init__()
        self.dim = dim
        self.group_size = group_size
        w_init = torch.randn(dim) * 0.02
        T_init = _ternarize(w_init.view(1, dim), 0.01)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
        n_grp = _n_groups(1, dim, group_size)
        self.register_buffer("E", torch.full((n_grp,), -4, dtype=torch.int8))
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

    def _get_T(self):
        """Unpack the stored trits into a flat int8 tensor of length ``dim``."""
        return unpack_ternary(self.T_packed,
                              tuple(self._T_shape.tolist()),
                              int(self._T_pad.item())).flatten()

    def forward(self, x):
        """Return ``F.rms_norm(x.float(), (dim,)) * w`` with ``w = 2^(E + 4*mean_corr) * T``.

        The last dimension of ``x`` must equal ``dim``; the result is float32.
        """
        T = self._get_T()
        gs = self.group_size
        dim = self.dim
        gpr = ceil(dim / gs)  # groups in the single row
        E_adj = self.E.float()
        step = self.step_counter.item()
        if step > 0:
            mean_corr = self.corr_accum.float() / max(step * gs, 1)
            E_adj = E_adj + mean_corr * 4.0
        # Spread each group's scale over its columns and trim the ragged tail group.
        S = torch.exp2(E_adj).view(1, gpr).repeat_interleave(gs, dim=1)[:, :dim]
        w = S.flatten() * T.float()
        norm = F.rms_norm(x.float(), (dim,))
        return norm * w

    def n_groups(self):
        """Number of scale groups (ceil(dim / group_size))."""
        return _n_groups(1, self.dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary gain entries (dim)."""
        return self.dim

    def update_corr(self):
        """No-op: this module registers no gradient hook, so there is nothing to fold in."""
        pass

    def persistent_memory_mb(self):
        """Size of all persistent buffers in MiB, counted in bytes like TernaryScaleTensor."""
        bufs = [self.T_packed, self._T_shape, self._T_pad,
                self.E, self.corr_accum, self.step_counter]
        return sum(b.numel() * b.element_size() for b in bufs) / (1024 * 1024)
|
|