"""
Test BigInt-accumulated ScaledOptum on a ternary MLP (~68M weights at
N_LAYERS = 2; ~169M at 5 layers).

Architecture:
    Embedding(288, 2048) -> [Repeat: Linear(2048->8192) -> ReLU -> Linear(8192->2048)] x N_LAYERS
    -> RMSNorm(2048) -> Linear(2048->288)

All linear weights use TernaryScaleTensor (packed ternary T + S from the optimizer).
Training: predict the next byte on TinyShakespeare.

Key metrics:
- Loss trend (should decrease if the optimizer works)
- Memory usage (model + optimizer state)
- Effective bits-per-weight
"""
| import os, sys, math, gc |
| sys.path.insert(0, os.path.dirname(__file__)) |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tscale_mini import TernaryScaleTensor, TernaryRMSNorm, _n_groups |
| from scaled_optum import ScaledOptum |
|
|
# Let float32 matmuls use faster reduced-internal-precision kernels (e.g. TF32)
# where the hardware supports them.
torch.set_float32_matmul_precision('high')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
|
|
|
|
| |
|
|
# Vocabulary size. The corpus is read as raw bytes (ids 0-255); the extra
# 32 slots are presumably padding / reserved ids (TODO confirm).
VOCAB = 288
# Model width and the expanded width inside each w1 -> ReLU -> w2 block.
HIDDEN, FFN_HIDDEN = 2048, 8192
# Number of w1 -> ReLU -> w2 blocks.
N_LAYERS = 2
# Forwarded as `group_size` / `threshold` to the ternary modules; presumably the
# number of weights sharing one scale and the ternarization cutoff (confirm in
# tscale_mini).
GROUP_SIZE, THRESHOLD = 32, 0.05
|
|
|
|
| |
|
|
class TernaryMLP(nn.Module):
    """
    Pure ternary MLP with packed weights + ALL-INT persistent state.
    No float32/16 anywhere in model buffers.

    Structure: ``embed`` (a TernaryScaleTensor applied to one-hot byte ids),
    ``N_LAYERS`` blocks of ``w1 -> ReLU -> w2`` with no residual connection,
    ``final_norm`` (TernaryRMSNorm) and an output ``head`` (TernaryScaleTensor).

    NOTE(review): every block also owns a ``'norm'`` TernaryRMSNorm that
    ``forward`` never applies, so it never receives a gradient. It is kept so
    the module tree, parameter counts and the optimizer's module list stay
    unchanged; confirm whether it was meant to be applied or should be removed.
    """

    def __init__(self):
        super().__init__()
        self.embed = TernaryScaleTensor(VOCAB, HIDDEN, threshold=THRESHOLD,
                                        group_size=GROUP_SIZE)
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                'w1': TernaryScaleTensor(HIDDEN, FFN_HIDDEN, threshold=THRESHOLD,
                                         group_size=GROUP_SIZE),
                'w2': TernaryScaleTensor(FFN_HIDDEN, HIDDEN, threshold=THRESHOLD,
                                         group_size=GROUP_SIZE),
                'norm': TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE),
            })
            for _ in range(N_LAYERS)
        )
        self.final_norm = TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE)
        self.head = TernaryScaleTensor(HIDDEN, VOCAB, threshold=THRESHOLD,
                                       group_size=GROUP_SIZE)

    def forward(self, x, targets=None):
        """Return next-byte logits, plus the cross-entropy loss if ``targets`` is given.

        Args:
            x: integer byte ids in ``[0, VOCAB)``; presumably shaped
                (batch, seq) — this method itself does not restrict the rank.
            targets: optional ids aligned element-for-element with ``x``.

        Returns:
            ``logits`` (presumably shaped ``x.shape + (VOCAB,)``), or
            ``(logits, loss)`` when ``targets`` is not None.
        """
        # The embedding is implemented as a ternary linear map over one-hot rows.
        h = self.embed(F.one_hot(x, num_classes=VOCAB).float())
        for layer in self.layers:
            # layer['norm'] is deliberately not applied (see NOTE in class docstring).
            h = layer['w2'](F.relu(layer['w1'](h)))
        logits = self.head(self.final_norm(h))

        if targets is None:
            return logits
        # reshape rather than view: still works if head returns a non-contiguous tensor.
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), targets.reshape(-1))
        return logits, loss

    def param_counts(self):
        """Return ``(total_ternary, total_float)`` parameter counts.

        Ternary modules report their count via ``total_ternary_params()``; any
        float parameters they own directly are not added to ``total_float``.
        Every other module contributes only its directly-owned parameters
        (``recurse=False``) to ``total_float``.
        """
        total_ternary = 0
        total_float = 0
        for mod in self.modules():
            if isinstance(mod, (TernaryScaleTensor, TernaryRMSNorm)):
                total_ternary += mod.total_ternary_params()
            else:
                total_float += sum(p.numel() for p in mod.parameters(recurse=False))
        return total_ternary, total_float

    def persistent_memory_mb(self):
        """Sum of ``persistent_memory_mb()`` over all ternary submodules (0 if none)."""
        return sum(mod.persistent_memory_mb() for mod in self.modules()
                   if isinstance(mod, (TernaryScaleTensor, TernaryRMSNorm)))
|
|
|
|
| |
|
|
def load_data(path="/tmp/tinyshakespeare.txt",
              url="https://raw.githubusercontent.com/karpathy/char-rnn/"
                  "master/data/tinyshakespeare/input.txt"):
    """Load a corpus file as a 1-D ``torch.long`` tensor of byte values (0-255).

    If *path* does not exist, *url* is downloaded first. The download goes to a
    sibling ``.part`` file that is moved into place only after it completes, so
    an interrupted download never leaves a truncated cache file that later runs
    would silently reuse.

    Args:
        path: local cache file to read (and to create if missing).
        url: source to fetch when *path* is missing.

    Returns:
        Tensor with one int64 element per byte of the file.
    """
    if not os.path.exists(path):
        import urllib.request
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # Only still present if the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    with open(path, 'rb') as f:
        data = f.read()
    if not data:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.long)
    # frombuffer avoids materialising a Python list of ~1M ints; bytearray gives
    # it a writable buffer, and .long() copies into an independent int64 tensor.
    return torch.frombuffer(bytearray(data), dtype=torch.uint8).long()
|
|
|
|
def get_batch(data, bs, ctx, device=device):
    """Sample ``bs`` random contiguous windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids.
        bs: number of windows (batch size).
        ctx: window length.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``, each of shape ``(bs, ctx)``, where ``y`` is ``x`` shifted one
        position ahead (next-token targets).

    Raises:
        ValueError: if ``data`` holds fewer than ``ctx + 1`` tokens.
    """
    # A window needs ctx + 1 tokens (inputs plus shifted targets), so valid
    # start offsets are 0 .. len(data) - ctx - 1; randint's upper bound is exclusive.
    n_starts = len(data) - ctx
    if n_starts < 1:
        raise ValueError(
            f"get_batch needs at least ctx + 1 = {ctx + 1} tokens, got {len(data)}")
    ix = torch.randint(0, n_starts, (bs,), device='cpu')
    # (bs, ctx) grid of absolute positions: row b is ix[b], ix[b] + 1, ..., ix[b] + ctx - 1.
    idx = ix.to(data.device).unsqueeze(1) + torch.arange(ctx, device=data.device)
    return data[idx].to(device), data[idx + 1].to(device)
|
|
|
|
| |
|
|
@torch.no_grad()
def compute_loss(model, data, bs=4, ctx=256):
    """Return *model*'s loss (as a float) on one random batch from *data*.

    Runs without gradient tracking and in eval mode. The model's previous
    train/eval mode is restored afterwards, even if the forward pass raises,
    so calling this mid-training does not leave the model stuck in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        x, y = get_batch(data, bs, ctx)
        _, loss = model(x, targets=y)
        return loss.item()
    finally:
        model.train(was_training)
|
|
|
|
def train_step(model, opt, data, bs=2, ctx=128):
    """Perform one optimisation step on a freshly sampled batch.

    Order of operations: switch to train mode, forward, backward,
    ``opt.step()``, then ``opt.zero_grad(set_to_none=True)``.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    inputs, next_tokens = get_batch(data, bs, ctx)
    _, batch_loss = model(inputs, targets=next_tokens)
    batch_loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
    return batch_loss.item()
|
|
|
|
def _ternary_modules(model):
    """Return all TernaryScaleTensor / TernaryRMSNorm submodules of *model*, in module order."""
    return [m for m in model.modules()
            if isinstance(m, (TernaryScaleTensor, TernaryRMSNorm))]


def _tensor_bytes(t):
    """Number of bytes occupied by the elements of tensor *t*."""
    return t.numel() * t.element_size()


def _print_memory_report(model, ternary_modules):
    """Print parameter counts and the per-buffer memory of the ternary modules."""
    total_ternary, total_float = model.param_counts()
    total_params = total_ternary + total_float
    persistent_mb = model.persistent_memory_mb()

    t_b = sum(_tensor_bytes(m.T_packed) for m in ternary_modules)
    e_b = sum(_tensor_bytes(m.E) for m in ternary_modules)
    a_b = sum(_tensor_bytes(m.corr_accum) for m in ternary_modules)
    # Only modules that actually own a step counter contribute, at its real width.
    sc_b = sum(_tensor_bytes(m.step_counter) for m in ternary_modules
               if hasattr(m, 'step_counter'))

    bpw = (t_b + e_b + a_b + sc_b) * 8 / max(1, total_params)
    print(f"\n Total params: {total_params:,}")
    print(f" Ternary params: {total_ternary:,} ({total_ternary/max(1,total_params)*100:.1f}%)")
    print(f" Float params: {total_float:,}")
    print(f" Persistent buffers: {persistent_mb:.2f} MB (ALL INTEGER)")
    print(f" T_packed: {t_b/1e6:.2f} MB ({t_b*8/max(1, total_ternary):.2f} bpw)")
    print(f" E (int8): {e_b/1e6:.2f} MB")
    print(f" corr_accum (int64):{a_b/1e6:.2f} MB")
    print(f" step_counter: {sc_b/1e6:.2f} MB")
    print(f" Effective bpw: {bpw:.2f}")
    print(f" Float params (bias): {sum(p.numel()*p.element_size() for p in model.parameters())/1e6:.1f} MB")


def _build_optimizer(ternary_modules):
    """Create a ScaledOptum that updates *ternary_modules*, print a summary, return it."""
    # ScaledOptum is handed a throwaway float parameter; presumably it subclasses
    # torch.optim.Optimizer, which rejects an empty parameter list (confirm).
    dummy = nn.Parameter(torch.zeros(1))
    opt = ScaledOptum([dummy], lr=0.3, default_group_size=GROUP_SIZE)
    opt.add_ternary_modules(ternary_modules)
    n_mods = len(opt.param_groups[0].get('ternary_modules', []))
    print(" Optimizer state: 0 bytes (pure integer, stored on modules)")
    print(f" Ternary modules: {n_mods}")
    return opt


@torch.no_grad()
def _val_accuracy(model, val_data, ctx=128):
    """Next-byte accuracy of *model* on one random validation window.

    Runs under ``torch.no_grad()`` so no autograd graph is built for this
    evaluation-only forward pass, and always leaves the model in train mode,
    which is what the training loop expects.
    """
    model.eval()
    try:
        x, y = get_batch(val_data, 1, ctx)
        logits, _ = model(x, targets=y)
        return (logits.argmax(-1) == y).float().mean().item()
    finally:
        model.train()


def _peak_vram_mb():
    """Peak CUDA memory (MB) since the previous reset, then reset the counter; 0 without CUDA."""
    if not torch.cuda.is_available():
        return 0
    vram = torch.cuda.max_memory_allocated() / 1e6
    torch.cuda.reset_peak_memory_stats()
    return vram


def main():
    """Build the model and optimizer, report memory usage, then train on TinyShakespeare.

    Every 200 steps (and on the final step) prints the training loss, that loss
    in bits per byte, accuracy on one validation window, the exponent range of
    the first layer's ``w1.E`` and peak VRAM.
    """
    print("Building TernaryMLP...")
    model = TernaryMLP().to(device)
    ternary_modules = _ternary_modules(model)
    _print_memory_report(model, ternary_modules)
    opt = _build_optimizer(ternary_modules)

    data = load_data()
    split = int(0.9 * len(data))
    train_data = data[:split]
    val_data = data[split:]
    print(f" Train data: {len(train_data):,} bytes")
    print(f" Val data: {len(val_data):,} bytes")

    N_STEPS = 5000
    print(f"\nTraining for {N_STEPS} steps...")
    # 'bpb' = training cross-entropy converted from nats to bits per predicted byte
    # (not bits-per-weight, which is the memory metric printed above).
    print(f"{'step':>6s} {'loss':>8s} {'bpb':>8s} {'acc%':>6s} {'S_range':>10s} {'VRAM':>6s}")
    print("-" * 60)

    loss = float('nan')
    for step in range(N_STEPS):
        loss = train_step(model, opt, train_data, bs=2, ctx=128)
        if step % 200 == 0 or step == N_STEPS - 1:
            acc = _val_accuracy(model, val_data, ctx=128)
            e_vals = model.layers[0]['w1'].E
            e_min, e_max = e_vals.min().item(), e_vals.max().item()
            bpb = loss / math.log(2)
            vram = _peak_vram_mb()
            # ASCII '->' keeps the line readable regardless of console encoding.
            print(f"{step:6d} {loss:8.4f} {bpb:8.3f} {acc*100:5.1f}% "
                  f"2^{e_min:+3d}->2^{e_max:+3d} {vram:5.0f}MB")

    print("\nDone.")
    print(f"Final loss: {loss:.4f}")
|
|
|
|
if __name__ == '__main__':
    # Train only when run as a script, never on import.
    main()
|
|