#!/usr/bin/env python3
"""
Test BigInt-accumulated ScaledOptum on a ternary MLP.

Architecture:
    Embedding(288, 2048)
      → [Linear(2048→8192) → ReLU → Linear(8192→2048)] × N_LAYERS
      → RMSNorm(2048) → Linear(2048→288)

The design target is a ~150M-parameter model (5 repeats of the FFN block).
The repeat count is controlled by ``N_LAYERS`` below, which is currently 2,
so the linear weights of the model built by this script total roughly 68M.

All linear weights use TernaryScaleTensor (packed ternary T + S from optimizer).

Training: predict next byte on TinyShakespeare.

Key metrics:
    - Loss trend (should decrease if optimizer works)
    - Memory usage (model + optimizer state)
    - Effective bits-per-weight
"""
import os
import sys
import math
import gc

# Make the sibling project modules importable when run as a script.
sys.path.insert(0, os.path.dirname(__file__))

import torch
import torch.nn as nn
import torch.nn.functional as F

from tscale_mini import TernaryScaleTensor, TernaryRMSNorm, _n_groups
from scaled_optum import ScaledOptum

torch.set_float32_matmul_precision('high')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# ─── Config ───
VOCAB = 288
HIDDEN = 2048
FFN_HIDDEN = 8192
N_LAYERS = 2
GROUP_SIZE = 32
THRESHOLD = 0.05

# Module types whose state is stored as packed ternary/integer buffers.
_TERNARY_TYPES = (TernaryScaleTensor, TernaryRMSNorm)


# ─── Model ───
class TernaryMLP(nn.Module):
    """
    Byte-level MLP built from TernaryScaleTensor / TernaryRMSNorm modules.

    Per the original design note, the persistent state is intended to be
    all-integer (packed ternary weights plus integer side buffers), with no
    float32/16 model buffers.
    """

    def __init__(self):
        super().__init__()
        self.embed = TernaryScaleTensor(
            VOCAB, HIDDEN, threshold=THRESHOLD, group_size=GROUP_SIZE)
        self.layers = nn.ModuleList()
        for _ in range(N_LAYERS):
            # NOTE(review): 'norm' is constructed (and therefore counted in
            # the parameter/memory totals and handed to the optimizer) but
            # forward() never applies it. Kept as-is to preserve the reported
            # numbers; confirm whether it should be used or removed.
            self.layers.append(nn.ModuleDict({
                'w1': TernaryScaleTensor(
                    HIDDEN, FFN_HIDDEN,
                    threshold=THRESHOLD, group_size=GROUP_SIZE),
                'w2': TernaryScaleTensor(
                    FFN_HIDDEN, HIDDEN,
                    threshold=THRESHOLD, group_size=GROUP_SIZE),
                'norm': TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE),
            }))
        self.final_norm = TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE)
        self.head = TernaryScaleTensor(
            HIDDEN, VOCAB, threshold=THRESHOLD, group_size=GROUP_SIZE)

    def forward(self, x, targets=None):
        """
        Run the network on a batch of token ids.

        Args:
            x: 2-D integer tensor of token ids (unpacked as ``B, T``).
            targets: optional tensor of next-token ids; when given, the
                cross-entropy loss is computed as well.

        Returns:
            ``logits`` when ``targets`` is None, otherwise
            ``(logits, loss)``.
        """
        _batch, _seq = x.shape  # unpacking doubles as a 2-D shape check

        # The embedding is a ternary linear layer applied to one-hot inputs.
        h = self.embed(F.one_hot(x, num_classes=VOCAB).float())
        for layer in self.layers:
            h = layer['w1'](h)
            h = F.relu(h)
            h = layer['w2'](h)
        h = self.final_norm(h)
        logits = self.head(h)

        if targets is None:
            return logits
        loss = F.cross_entropy(logits.view(-1, VOCAB), targets.view(-1))
        return logits, loss

    def param_counts(self):
        """
        Return ``(total_ternary, total_float)`` parameter counts.

        Ternary modules report their own size via ``total_ternary_params()``;
        for every other module only its directly-owned ``nn.Parameter``
        elements are counted (as float).
        """
        total_ternary = 0
        total_float = 0
        for mod in self.modules():
            if isinstance(mod, _TERNARY_TYPES):
                total_ternary += mod.total_ternary_params()
            else:
                total_float += sum(
                    p.numel() for p in mod.parameters(recurse=False))
        return total_ternary, total_float

    def persistent_memory_mb(self):
        """Sum ``persistent_memory_mb()`` over all ternary modules."""
        return sum(
            mod.persistent_memory_mb()
            for mod in self.modules()
            if isinstance(mod, _TERNARY_TYPES)
        )


# ─── Data (TinyShakespeare) ───
def load_data(path="/tmp/tinyshakespeare.txt"):
    """
    Load TinyShakespeare as a 1-D ``torch.long`` tensor of byte values.

    The file is downloaded to ``path`` on first use. The download goes to a
    temporary ``<path>.part`` file and is moved into place only once complete,
    so an interrupted download cannot leave a truncated dataset behind that
    later runs would silently reuse.
    """
    if not os.path.exists(path):
        import urllib.request
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        tmp_path = f"{path}.part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # Only still present if the download or the move failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    with open(path, 'rb') as f:
        raw = f.read()

    if not raw:
        # torch.frombuffer rejects empty buffers; keep the empty-file result.
        return torch.empty(0, dtype=torch.long)
    # Convert the bytes in one native pass instead of building a Python list
    # with one int object per byte. bytearray() gives frombuffer a writable
    # buffer, and the dtype conversion copies, so the result owns its memory.
    return torch.frombuffer(bytearray(raw), dtype=torch.uint8).to(torch.long)


def get_batch(data, bs, ctx, device=device):
    """
    Sample ``bs`` random windows of length ``ctx`` from ``data``.

    Returns:
        ``(x, y)`` moved to ``device``, where ``y`` is ``x`` shifted one
        position to the right (next-byte targets).
    """
    # Plain Python ints make cheaper slice bounds than 0-d tensors.
    starts = torch.randint(
        0, len(data) - ctx - 1, (bs,), device='cpu').tolist()
    x = torch.stack([data[s:s + ctx] for s in starts])
    y = torch.stack([data[s + 1:s + ctx + 1] for s in starts])
    return x.to(device), y.to(device)


# ─── Test ───
@torch.no_grad()
def compute_loss(model, data, bs=4, ctx=256):
    """Return the loss on one random batch; leaves the model in eval mode."""
    model.eval()
    x, y = get_batch(data, bs, ctx)
    _, loss = model(x, targets=y)
    return loss.item()


def train_step(model, opt, data, bs=2, ctx=128):
    """Run one forward/backward/optimizer step and return the scalar loss."""
    model.train()
    x, y = get_batch(data, bs, ctx)
    _, loss = model(x, targets=y)
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
    return loss.item()


def _buffer_bytes(modules, name, *, required=True):
    """
    Total byte size of the tensor attribute ``name`` across ``modules``.

    With ``required=False`` modules lacking the attribute contribute 0 bytes
    instead of raising AttributeError.
    """
    total = 0
    for mod in modules:
        buf = getattr(mod, name) if required else getattr(mod, name, None)
        if buf is not None:
            total += buf.numel() * buf.element_size()
    return total


def _eval_accuracy(model, data, bs=1, ctx=128):
    """
    Next-byte accuracy on one random batch, then restore train mode.

    Runs under ``torch.no_grad()`` so the evaluation pass does not build an
    autograd graph in the middle of training.
    """
    model.eval()
    with torch.no_grad():
        x, y = get_batch(data, bs, ctx)
        logits, _ = model(x, targets=y)
        acc = (logits.argmax(-1) == y).float().mean().item()
    model.train()
    return acc


def main():
    """Build the model, report its memory footprint, and train it."""
    print("Building TernaryMLP...")
    model = TernaryMLP().to(device)

    # Collect all ternary modules once; reused for stats and the optimizer.
    ternary_modules = [
        mod for mod in model.modules() if isinstance(mod, _TERNARY_TYPES)]

    total_ternary, total_float = model.param_counts()
    total_params = total_ternary + total_float
    persistent_mb = model.persistent_memory_mb()

    # Breakdown of persistent int storage (bytes).
    t_b = _buffer_bytes(ternary_modules, 'T_packed')
    e_b = _buffer_bytes(ternary_modules, 'E')
    a_b = _buffer_bytes(ternary_modules, 'corr_accum')
    # step_counter is optional: count its real size, and nothing if absent.
    sc_b = _buffer_bytes(ternary_modules, 'step_counter', required=False)
    bpw = (t_b * 8 + e_b * 8 + a_b * 8 + sc_b * 8) / max(1, total_params)

    print(f"\n Total params: {total_params:,}")
    print(f" Ternary params: {total_ternary:,} ({total_ternary/max(1,total_params)*100:.1f}%)")
    print(f" Float params: {total_float:,}")
    print(f" Persistent buffers: {persistent_mb:.2f} MB (ALL INTEGER)")
    print(f" T_packed: {t_b/1e6:.2f} MB ({t_b*8/max(1,total_ternary):.2f} bpw)")
    print(f" E (int8): {e_b/1e6:.2f} MB")
    print(f" corr_accum (int64):{a_b/1e6:.2f} MB")
    print(f" step_counter: {sc_b/1e6:.2f} MB")
    print(f" Effective bpw: {bpw:.2f}")
    print(f" Float params (bias): {sum(p.numel()*p.element_size() for p in model.parameters())/1e6:.1f} MB")

    # Optimizer: pure integer, no float state.
    # The dummy parameter is a placeholder for the optimizer's parameter
    # list (presumably because the base optimizer rejects an empty list —
    # TODO confirm); the real state lives on the ternary modules.
    dummy = nn.Parameter(torch.zeros(1))
    opt = ScaledOptum([dummy], lr=0.3, default_group_size=GROUP_SIZE)
    opt.add_ternary_modules(ternary_modules)

    n_mods = len(opt.param_groups[0].get('ternary_modules', []))
    print(f" Optimizer state: 0 bytes (pure integer, stored on modules)")
    print(f" Ternary modules: {n_mods}")

    # Data: 90/10 train/validation split.
    data = load_data()
    split = int(0.9 * len(data))
    train_data = data[:split]
    val_data = data[split:]
    print(f" Train data: {len(train_data):,} bytes")
    print(f" Val data: {len(val_data):,} bytes")

    # NOTE: the ternary tensors (T_packed etc.) are buffers, not
    # nn.Parameters, so nn.Module.parameters() does not expose them. They
    # reach the optimizer through add_ternary_modules() above; per the
    # original author's note they are updated via hooks, not nn.Parameter.

    # Training
    N_STEPS = 5000
    print(f"\nTraining for {N_STEPS} steps...")
    # NOTE(review): the 'bpw' column here is loss / ln 2, i.e. bits per
    # predicted byte, not the bits-per-weight figure printed above; and the
    # 'S_range' column shows the min/max of the first layer's E buffer.
    print(f"{'step':>6s} {'loss':>8s} {'bpw':>8s} {'acc%':>6s} {'S_range':>10s} {'VRAM':>6s}")
    print("-" * 60)

    loss = float('nan')  # keeps the final report valid if N_STEPS is 0
    for step in range(N_STEPS):
        loss = train_step(model, opt, train_data, bs=2, ctx=128)

        if step % 200 == 0 or step == N_STEPS - 1:
            # Accuracy on a single validation window.
            acc = _eval_accuracy(model, val_data, 1, 128)

            # Range of the E buffer for the first layer's w1.
            e_vals = model.layers[0]['w1'].E
            e_min, e_max = e_vals.min().item(), e_vals.max().item()

            bpw_val = loss / math.log(2)
            if torch.cuda.is_available():
                vram = torch.cuda.max_memory_allocated() / 1e6
                # Reset so each report covers only the interval since the last.
                torch.cuda.reset_peak_memory_stats()
            else:
                vram = 0

            print(f"{step:6d} {loss:8.4f} {bpw_val:8.3f} {acc*100:5.1f}% "
                  f"2^{e_min:+3d}–2^{e_max:+3d} {vram:5.0f}MB")

    print("\nDone.")
    print(f"Final loss: {loss:.4f}")


if __name__ == '__main__':
    main()