"""tscale_mini — TernaryScaleTensor with BigInt correlation tracking.

Per group of ``group_size`` weights along the input dimension, each update
step computes

    score = Σ (grad_sign × T)

i.e. the correlation of the gradient sign with the ternary weight direction:

- score > 0: grad aligns with T (direction correct, need LESS magnitude)
- score < 0: grad opposes T (direction wrong, need MORE magnitude)

Instead of int8 threshold flips, the scores go into int64 accumulators:

    corr_accum -= score                    (never clips, never resets)
    step       += 1
    mean_corr   = corr_accum / (step × gs)      in [-1, +1]
    S           = 2^(E + 4 × mean_corr)         continuous per-group scale

The score is *subtracted* so that a positive score (need less magnitude)
lowers ``mean_corr`` and therefore lowers S.

PERSISTENT (all int):
    T_packed     (uint8) — 5 trits/byte
    E            (int8)  — per-group base log2 scale
    corr_accum   (int64) — per-group correlation accumulator
    step_counter (int64) — total steps

EPHEMERAL (float32, only during forward/backward):
    w_eff = S × T
"""

import math
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of base-3 digits stored in one packed byte (3**5 = 243 <= 255).
_TRITS_PER_BYTE = 5

# Gain applied to mean_corr in the exponent: mean_corr in [-1, +1] shifts the
# log2 scale by [-4, +4], i.e. multiplies S by a factor in [1/16, 16].
_CORR_GAIN = 4.0


# ─── Pack / Unpack (5 trit → 1 byte, base-3) ───

def pack_ternary(w):
    """Pack the signs of ``w`` into base-3 bytes, five trits per byte.

    Each element is encoded as 0 (negative), 1 (zero) or 2 (positive).
    Elements that are neither negative nor positive (including NaN) are
    encoded as zero.

    Args:
        w: tensor of any shape; only the sign of each element is used.

    Returns:
        ``(packed, shape, pad)`` where ``packed`` is a 1-D uint8 CPU tensor,
        ``shape`` is ``w.shape`` and ``pad`` is the number of filler trits
        appended to reach a multiple of five.
    """
    # Start from the "zero" code so every element has a defined value.
    codes = torch.ones_like(w, dtype=torch.uint8)
    codes[w < 0] = 0
    codes[w > 0] = 2

    flat = codes.flatten()
    pad = (-flat.numel()) % _TRITS_PER_BYTE
    if pad:
        filler = torch.zeros(pad, dtype=torch.uint8, device=flat.device)
        flat = torch.cat([flat, filler])
    # Explicit row count: view(-1, 5) is ambiguous for an empty tensor.
    flat = flat.view(flat.numel() // _TRITS_PER_BYTE, _TRITS_PER_BYTE)

    place_values = torch.tensor(
        [1, 3, 9, 27, 81], dtype=torch.int16, device=flat.device
    )
    # Largest possible byte is 2 * (1 + 3 + 9 + 27 + 81) = 242, fits in uint8.
    packed = (flat.to(torch.int16) * place_values).sum(dim=1).to(torch.uint8)
    return packed.cpu(), w.shape, pad


def unpack_ternary(packed, shape, pad=0):
    """Inverse of :func:`pack_ternary`.

    Args:
        packed: 1-D tensor of base-3 packed bytes.
        shape: shape of the tensor to reconstruct.
        pad: number of filler trits to drop from the end.

    Returns:
        int8 tensor of ``shape`` with values in {-1, 0, +1}. ``packed`` is
        left unmodified.
    """
    remainder = packed.to(torch.int16)
    digits = []
    for _ in range(_TRITS_PER_BYTE):
        digits.append(remainder % 3)
        # Out-of-place division: ``.to`` may return ``packed`` itself when it
        # is already int16, and the caller's tensor must not be mutated.
        remainder = torch.div(remainder, 3, rounding_mode="floor")

    trits = torch.stack(digits, dim=1).flatten()
    trits = trits[: trits.numel() - pad]
    # Codes {0, 1, 2} map to values {-1, 0, +1}.
    return (trits.reshape(shape) - 1).to(torch.int8)


# ─── Helpers ───

def _ternarize(x, threshold=0.05):
    """Return sign(x) where ``|x| > threshold`` and 0 elsewhere."""
    return x.sign() * (x.abs() > threshold).to(x.dtype)


def _n_groups(out_dim, in_dim, gs):
    """Number of scale groups for an (out_dim, in_dim) matrix with group size gs."""
    return out_dim * ceil(in_dim / gs)


def _group_scale(E, corr_accum, step_counter, group_size):
    """Return the per-group float32 scale ``2^(E + _CORR_GAIN × mean_corr)``.

    Before the first update step (``step_counter == 0``) the scale is 2^E.
    """
    exponent = E.float()
    step = int(step_counter.item())
    if step > 0:
        mean_corr = corr_accum.float() / max(step * group_size, 1)
        exponent = exponent + mean_corr * _CORR_GAIN
    return torch.exp2(exponent)


def _expand_groups(scale, rows, cols, group_size):
    """Broadcast per-group scales to a (rows, cols) matrix.

    The last group of each row is truncated when ``cols`` is not a multiple
    of ``group_size``.
    """
    groups_per_row = ceil(cols / group_size)
    expanded = scale.view(rows, groups_per_row).repeat_interleave(group_size, dim=1)
    return expanded[:, :cols]


def _buffers_mb(buffers):
    """Total size of ``buffers`` in MiB (numel × element size)."""
    return sum(b.numel() * b.element_size() for b in buffers) / (1024 * 1024)


# ─── TernaryScaleTensor with BigInt correlation tracking ───

class TernaryScaleTensor(nn.Module):
    """Ternary linear layer with BigInt correlation tracking for S.

    Forward:
        S = 2^(E + 4 × mean_corr), w_eff = S × T
        where mean_corr = corr_accum / (step × gs), converted to float32
        only for the duration of the forward pass.

    Persistent buffers:
        T_packed     (uint8) — 5 trits/byte
        E            (int8)  — per-group base log2 scale
        corr_accum   (int64) — per-group running -Σ grad_sign × T
        step_counter (int64)
        bias         (int32) — only when ``bias=True``; initialised to zero
    """

    def __init__(self, in_dim, out_dim, threshold=0.05, group_size=32, bias=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.group_size = group_size

        init_std = 0.1 if not bias else 0.02
        w_init = torch.randn(out_dim, in_dim) * init_std
        T_init = _ternarize(w_init, threshold)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer(
            "_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long)
        )
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # E: base log2 scale, chosen so 2^E ≈ 0.5 / sqrt(in_dim), kept in [-8, 0].
        target_S = 0.5 * (in_dim ** -0.5)
        E_init = max(-8, min(0, int(round(math.log2(max(target_S, 2 ** -8))))))
        n_grp = _n_groups(out_dim, in_dim, group_size)
        self.register_buffer("E", torch.full((n_grp,), E_init, dtype=torch.int8))

        # BigInt correlation accumulator (int64, never resets).
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        if bias:
            self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
        else:
            self.bias = None

    def _get_T(self):
        """Unpack the ternary weights as an (out_dim, in_dim) int8 tensor."""
        return unpack_ternary(
            self.T_packed,
            tuple(self._T_shape.tolist()),
            int(self._T_pad.item()),
        )

    def forward(self, x):
        """Apply the linear map with ``w_eff = S × T``.

        When autograd is enabled, a hook on the effective weight stores the
        weight gradient on the module as ``_hook_grad_T_sign`` (int8 sign)
        and ``_hook_grad_full`` (float gradient) for :meth:`update_corr`.
        """
        scale = _group_scale(
            self.E, self.corr_accum, self.step_counter, self.group_size
        )
        scale_full = _expand_groups(
            scale, self.out_dim, self.in_dim, self.group_size
        )
        w_eff = scale_full * self._get_T().float()
        del scale, scale_full

        # With autograd disabled no backward pass can reach the hook, so the
        # weight is used as a plain tensor.
        if torch.is_grad_enabled():
            w_eff = w_eff.detach().requires_grad_(True)

            def _capture(grad_w):
                """Store the (out_dim, in_dim) weight gradient and its int8 sign."""
                self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
                self._hook_grad_full = grad_w.detach()

            w_eff.register_hook(_capture)

        y = F.linear(x, w_eff)
        if self.bias is not None:
            y = y + self.bias.float()
        return y

    @torch.no_grad()
    def update_corr(self):
        """Fold the captured gradient sign into ``corr_accum``.

        Integer-only update. Does nothing if no gradient has been captured
        since the last call. Otherwise subtracts the per-group
        Σ grad_sign × T from ``corr_accum``, increments ``step_counter`` and
        discards the captured gradient tensors.
        """
        if not hasattr(self, '_hook_grad_T_sign'):
            return

        gs = self.group_size
        out_d, in_d = self.out_dim, self.in_dim
        gpr = ceil(in_d / gs)

        grad_sign = self._hook_grad_T_sign  # (out_d, in_d) int8
        T = self._get_T().to(device=grad_sign.device)  # (out_d, in_d) int8

        # Both factors are in {-1, 0, +1}, so the product is as well.
        signed = grad_sign * T

        # Zero-pad the last group of each row up to the group boundary.
        pad = gpr * gs - in_d
        if pad > 0:
            signed = F.pad(signed, (0, pad))

        # Sum directly in int64 so the group size cannot overflow the score.
        score = signed.view(out_d, gpr, gs).sum(dim=2, dtype=torch.int64)

        # NOTE: subtract because score > 0 means grad aligns with T (direction
        # correct), so we need LESS magnitude (decrease corr → decrease S).
        self.corr_accum -= score.flatten()
        self.step_counter += 1

        # Drop the captured gradients.
        del self._hook_grad_T_sign
        if hasattr(self, '_hook_grad_full'):
            del self._hook_grad_full

    def n_groups(self):
        """Number of scale groups in this layer."""
        return _n_groups(self.out_dim, self.in_dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary weights (out_dim × in_dim)."""
        return self.out_dim * self.in_dim

    def persistent_memory_mb(self):
        """Size in MiB of all persistent buffers, including bias when present."""
        buffers = [
            self.T_packed,
            self._T_shape,
            self._T_pad,
            self.E,
            self.corr_accum,
            self.step_counter,
        ]
        if self.bias is not None:
            buffers.append(self.bias)
        return _buffers_mb(buffers)


# ─── TernaryRMSNorm (same approach: 2^E × BigInt correlation) ───

class TernaryRMSNorm(nn.Module):
    """RMS normalisation with a ternary, group-scaled elementwise weight.

    The weight is ``S × T`` with the same per-group scale formula as
    :class:`TernaryScaleTensor`. :meth:`update_corr` is a no-op here, so this
    class itself never changes ``corr_accum`` or ``step_counter``.
    """

    def __init__(self, dim, group_size=32):
        super().__init__()
        self.dim = dim
        self.group_size = group_size

        w_init = torch.randn(dim) * 0.02
        T_init = _ternarize(w_init.view(1, dim), 0.01)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        n_grp = _n_groups(1, dim, group_size)
        self.register_buffer("E", torch.full((n_grp,), -4, dtype=torch.int8))
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

    def _get_T(self):
        """Unpack the ternary weights as a flat (dim,) int8 tensor."""
        return unpack_ternary(
            self.T_packed,
            tuple(self._T_shape.tolist()),
            int(self._T_pad.item()),
        ).flatten()

    def forward(self, x):
        """Return ``rms_norm(x) × (S × T)`` computed in float32."""
        scale = _group_scale(
            self.E, self.corr_accum, self.step_counter, self.group_size
        )
        scale_full = _expand_groups(scale, 1, self.dim, self.group_size)
        w = scale_full.flatten() * self._get_T().float()
        norm = F.rms_norm(x.float(), (self.dim,))
        return norm * w

    def n_groups(self):
        """Number of scale groups in this layer."""
        return _n_groups(1, self.dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary weights (dim)."""
        return self.dim

    def update_corr(self):
        """No-op: RMSNorm doesn't capture gradient hooks."""
        pass

    def persistent_memory_mb(self):
        """Size in MiB of all persistent buffers (numel × element size)."""
        return _buffers_mb([
            self.T_packed,
            self._T_shape,
            self._T_pad,
            self.E,
            self.corr_accum,
            self.step_counter,
        ])