# ARBS / testing / tscale_mini.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908, verified).
"""
tscale_mini — TernaryScaleTensor with integer ("BigInt") correlation tracking.
Key idea from main ternary_scale.py:
  score = Σ (grad_sign × T) per group — correlation of the gradient sign with the weight direction
  - score > 0: grad aligns with T (moving along T increases the loss → need LESS magnitude)
  - score < 0: grad opposes T (moving along T decreases the loss → need MORE magnitude)
Instead of E_accum (int8 threshold flips), use wide int64 accumulators (called "BigInt" here):
  corr_accum -= score                    — int64, never clips, never resets
  step += 1
  mean_corr = corr_accum / (step × gs)   — rational in [-1, +1]
  S = 2^(E + 4 × mean_corr)              — continuous S: int8 base E plus an accumulator-derived adjustment
The accumulators are exact integers. The division corr_accum / (step × gs) is done in
float32 at forward time, so mean_corr carries about 24 bits of mantissa precision.
This gives S continuous fine-tuning per group instead of just 256 discrete 2^E values.
PERSISTENT (all int):
  T_packed     (uint8) — 5 trits/byte
  E            (int8)  — per-group base log2 scale
  corr_accum   (int64) — per-group correlation accumulator
  step_counter (int64) — number of update steps applied
EPHEMERAL (float32, only during forward/backward):
  w_eff = S × T = 2^(E + 4 × mean_corr) × T
"""
import math, torch, torch.nn as nn, torch.nn.functional as F
from math import ceil
# ─── Pack / Unpack (5 trits → 1 byte, base-3) ───
def pack_ternary(w):
    """Pack a tensor's signs into base-3 bytes, five trits per byte.

    Each element is mapped by sign to a trit code: negative -> 0, zero -> 1,
    positive -> 2. NaN elements are encoded as zero. The flattened codes are
    right-padded with code 0 to a multiple of five and combined as
    ``c0 + 3*c1 + 9*c2 + 27*c3 + 81*c4``. The maximum is 242, so it fits in uint8.

    Args:
        w: tensor of any shape; only the sign of each element is kept.

    Returns:
        (packed, shape, pad):
        - ``packed`` is a 1-D uint8 tensor moved to the CPU.
        - ``shape`` is ``w.shape``.
        - ``pad`` is the number of padding trits appended (0-4). It is needed
          by ``unpack_ternary`` to drop them again.
    """
    # Start every element at the "zero" code. Values that are neither < 0 nor
    # > 0 (0 and NaN) then get a defined code instead of whatever
    # uninitialised memory torch.empty_like would have left behind.
    q = torch.ones_like(w, dtype=torch.uint8)
    q[w < 0] = 0
    q[w > 0] = 2
    flat = q.flatten()
    pad = (-flat.numel()) % 5
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    flat = flat.view(-1, 5)
    packed = (flat[:, 0] + flat[:, 1] * 3 + flat[:, 2] * 9
              + flat[:, 3] * 27 + flat[:, 4] * 81).to(torch.uint8)
    return packed.cpu(), w.shape, pad
def unpack_ternary(packed, shape, pad=0):
    """Decode base-3 packed bytes back into trits (inverse of ``pack_ternary``).

    Args:
        packed: 1-D tensor of packed bytes, five trits per element
            (uint8 as produced by ``pack_ternary``).
        shape: target shape of the decoded tensor.
        pad: number of trailing padding trits to drop (as returned by
            ``pack_ternary``).

    Returns:
        int8 tensor of ``shape`` with values in {-1, 0, +1}, on the same
        device as ``packed``.
    """
    digits = []
    rest = packed.to(torch.int16)
    for _ in range(5):
        digits.append(rest % 3)
        # Out-of-place on purpose: if ``packed`` is already int16, ``.to``
        # returns the very same tensor. An in-place ``//=`` would then
        # overwrite the caller's data.
        rest = rest // 3
    codes = torch.stack(digits, dim=1).flatten()
    if pad:
        codes = codes[:-pad]
    # Trit codes 0 / 1 / 2 encode -1 / 0 / +1.
    return (codes.view(shape) - 1).to(torch.int8)
# ─── Helpers ───
def _ternarize(x, threshold=0.05):
return x.sign() * (x.abs() > threshold).to(x.dtype)
def _n_groups(out_dim, in_dim, gs):
return out_dim * ceil(in_dim / gs)
# ─── TernaryScaleTensor with BigInt correlation tracking ───
class TernaryScaleTensor(nn.Module):
    """
    Ternary linear layer whose per-group scale is tuned by an integer
    correlation accumulator.

    Forward:
        mean_corr = corr_accum / (step × gs)     float32, in [-1, +1]
        S         = 2^(E + 4 × mean_corr)        one value per group
        w_eff     = S × T,   y = F.linear(x, w_eff) (+ bias)

    A group is ``group_size`` consecutive input columns of one output row.
    Each row has ceil(in_dim / group_size) groups; the last one may be partial.

    Persistent buffers (all integer):
        T_packed     (uint8) — ternary weights, 5 trits/byte
        E            (int8)  — per-group base log2 scale
        corr_accum   (int64) — per-group running sum of -(sign(grad) × T)
        step_counter (int64) — number of update_corr() calls that consumed a gradient
        bias         (int32, only if bias=True) — zeros; not updated by this class

    Training protocol: forward → backward (one or more) → update_corr().
    """

    def __init__(self, in_dim, out_dim, threshold=0.05, group_size=32, bias=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.group_size = group_size

        # Random ternary init.
        # NOTE(review): the init std is tied to the bias flag (0.1 without
        # bias, 0.02 with bias). The intent is unclear; kept as-is.
        init_std = 0.1 if not bias else 0.02
        w_init = torch.randn(out_dim, in_dim) * init_std
        T_init = _ternarize(w_init, threshold)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # E: base log2 scale, initialised to round(log2(0.5 / sqrt(in_dim)))
        # and clamped to [-8, 0].
        target_S = 0.5 * (in_dim ** -0.5)
        E_init = max(-8, min(0, int(round(math.log2(max(target_S, 2**-8))))))
        n_grp = _n_groups(out_dim, in_dim, group_size)
        self.register_buffer("E", torch.full((n_grp,), E_init, dtype=torch.int8))

        # Integer correlation accumulator (int64, never reset) and step count.
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        if bias:
            self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
        else:
            self.bias = None

    def _get_T(self):
        """Unpack the ternary weights as an int8 tensor of shape (out_dim, in_dim)."""
        return unpack_ternary(self.T_packed,
                              tuple(self._T_shape.tolist()),
                              int(self._T_pad.item()))

    def _group_scales(self):
        """Return the float32 per-group scales S = 2^(E + 4 × mean_corr), shape (n_groups,)."""
        E_adj = self.E.float()
        step = self.step_counter.item()
        if step > 0:
            # |score| <= group_size per step, so mean_corr lies in [-1, +1].
            # The factor 4 lets S move within 2^±4 (×1/16 … ×16) of 2^E.
            mean_corr = self.corr_accum.float() / max(step * self.group_size, 1)
            E_adj = E_adj + mean_corr * 4.0
        return torch.exp2(E_adj)

    def forward(self, x):
        """Return ``F.linear(x, S × T)``, plus the bias if one is registered.

        The effective weight is rebuilt from integer state on every call and
        exposed to autograd as a fresh leaf tensor. A hook on that tensor
        hands the weight gradient to ``_accumulate_grad`` for ``update_corr``.
        """
        out_d, in_d, gs = self.out_dim, self.in_dim, self.group_size
        gpr = ceil(in_d / gs)
        T = self._get_T()  # int8, ephemeral
        # Broadcast each group's scale over its gs columns; drop the padded tail.
        S_exp = self._group_scales().view(out_d, gpr).repeat_interleave(gs, dim=1)[:, :in_d]
        w_eff = (S_exp * T.float()).detach().requires_grad_(True)
        del S_exp, T
        w_eff.register_hook(self._accumulate_grad)

        y = F.linear(x, w_eff)
        if self.bias is not None:
            y = y + self.bias.float()
        return y

    def _accumulate_grad(self, grad_w):
        """Autograd hook: add ``grad_w`` (the dL/dw_eff gradient) to the pending gradient.

        Gradients are summed, like ``Tensor.grad``. This way, gradient
        accumulation over several backward passes, or a layer applied more
        than once in one graph, all contribute. Previously only the last
        gradient survived.

        Stores the running sum in ``_hook_grad_full`` and its int8 sign in
        ``_hook_grad_T_sign``. ``update_corr`` consumes and clears both.
        Returns None, so autograd's gradient is left unchanged.
        """
        grad_w = grad_w.detach()
        prev = getattr(self, "_hook_grad_full", None)
        # Only extend a consistent pending state (both attributes present).
        if prev is not None and hasattr(self, "_hook_grad_T_sign"):
            grad_w = prev + grad_w
        self._hook_grad_full = grad_w
        self._hook_grad_T_sign = grad_w.sign().to(torch.int8)

    @torch.no_grad()
    def update_corr(self):
        """
        Fold the pending weight-gradient signs into ``corr_accum`` (integer-only).

        Per group, score = Σ sign(grad) × T. The method then applies
        ``corr_accum -= score`` and ``step_counter += 1``.

        A positive score means the gradient points along T: growing |w|
        raises the loss, so the group's scale should shrink. Hence the
        subtraction.

        Afterwards the pending gradient is cleared. Does nothing if no
        gradient was captured since the last call.
        """
        grad_sign = getattr(self, "_hook_grad_T_sign", None)
        if grad_sign is None:
            return
        gs = self.group_size
        out_d, in_d = self.out_dim, self.in_dim
        gpr = ceil(in_d / gs)
        T = self._get_T().to(device=grad_sign.device)  # (out_d, in_d) int8

        signed = grad_sign.to(torch.int8) * T  # elementwise product, in {-1, 0, +1}
        pad = gpr * gs - in_d
        if pad > 0:
            # Zero-padded tail columns contribute nothing to the score.
            signed = F.pad(signed, (0, pad))
        # Sum straight into int64: an int16 sum overflows once group_size > 32767.
        score = signed.view(out_d, gpr, gs).sum(dim=2, dtype=torch.int64)
        self.corr_accum -= score.flatten().to(self.corr_accum.device)
        self.step_counter += 1

        del self._hook_grad_T_sign
        if hasattr(self, "_hook_grad_full"):
            del self._hook_grad_full

    def n_groups(self):
        """Number of scale groups (entries of E / corr_accum)."""
        return _n_groups(self.out_dim, self.in_dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary weights, out_dim × in_dim."""
        return self.out_dim * self.in_dim

    def persistent_memory_mb(self):
        """Size of the persistent buffers, including ``bias`` when present, in MiB."""
        bufs = [self.T_packed, self._T_shape, self._T_pad,
                self.E, self.corr_accum, self.step_counter]
        if self.bias is not None:
            bufs.append(self.bias)
        return sum(b.numel() * b.element_size() for b in bufs) / (1024 * 1024)
# ─── TernaryRMSNorm (same scale scheme: 2^E × correlation adjustment) ───
class TernaryRMSNorm(nn.Module):
    """
    RMSNorm with a ternary, per-group-scaled weight vector.

    ``forward`` returns ``F.rms_norm(x.float(), (dim,)) * w``, where
    ``w = S × T`` and ``S = 2^(E + 4 × corr_accum / (step × gs))`` per group.
    This mirrors the scale computation of TernaryScaleTensor.

    No gradient hook is registered and ``update_corr`` is a no-op, so nothing
    in this class changes ``corr_accum`` / ``step_counter``. S therefore
    stays at 2^E unless those buffers are modified externally.
    """

    def __init__(self, dim, group_size=32):
        super().__init__()
        self.dim = dim
        self.group_size = group_size
        # Random ternary init: std 0.02 with threshold 0.01.
        w_init = torch.randn(dim) * 0.02
        T_init = _ternarize(w_init.view(1, dim), 0.01)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
        n_grp = _n_groups(1, dim, group_size)
        # Base scale 2^-4 for every group.
        self.register_buffer("E", torch.full((n_grp,), -4, dtype=torch.int8))
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

    def _get_T(self):
        """Unpack the ternary weights as a flat int8 tensor of length ``dim``."""
        return unpack_ternary(self.T_packed,
                              tuple(self._T_shape.tolist()),
                              int(self._T_pad.item())).flatten()

    def forward(self, x):
        """RMS-normalise ``x`` over its last dimension (in float32) and scale by ``S × T``."""
        T = self._get_T()
        gs = self.group_size
        dim = self.dim
        gpr = ceil(dim / gs)
        E_adj = self.E.float()
        step = self.step_counter.item()
        if step > 0:
            mc = self.corr_accum.float() / max(step * gs, 1)
            E_adj = E_adj + mc * 4.0
        S_p = torch.exp2(E_adj)
        # Broadcast each group's scale over its gs entries; drop the padded tail.
        S_2 = S_p.view(1, gpr).repeat_interleave(gs, dim=1)
        if S_2.shape[1] > dim:
            S_2 = S_2[:, :dim]
        w = S_2.flatten() * T.float()
        norm = F.rms_norm(x.float(), (dim,))
        return norm * w

    def n_groups(self):
        """Number of scale groups (entries of E / corr_accum)."""
        return _n_groups(1, self.dim, self.group_size)

    def total_ternary_params(self):
        """Number of ternary weights (``dim``)."""
        return self.dim

    def update_corr(self):
        """No-op: this module registers no gradient hook, so there is nothing to fold in."""
        pass

    def persistent_memory_mb(self):
        """Size of all persistent buffers in MiB (bytes / 2**20).

        Each buffer is counted at its real element size. corr_accum and
        step_counter are int64 (8 bytes per element).
        """
        bufs = (self.T_packed, self._T_shape, self._T_pad,
                self.E, self.corr_accum, self.step_counter)
        return sum(b.numel() * b.element_size() for b in bufs) / (1024 * 1024)