# ARBS / testing / test_bigint_ternary.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
#!/usr/bin/env python3
"""
Test BigInt-accumulated ScaledOptum on a ternary MLP (~68M params with the
configured N_LAYERS = 2; ≈169M at 5 layers).
Architecture:
Embedding(288, 2048) → [Repeat: Linear(2048→8192) → ReLU → Linear(8192→2048)] × N_LAYERS
→ RMSNorm(2048) → Linear(2048→288)
All linear weights use TernaryScaleTensor (packed ternary T + S from optimizer).
Training: predict next byte on TinyShakespeare.
Key metrics:
- Loss trend (should decrease if the optimizer works)
- Memory usage (model + optimizer state)
- Effective bits-per-weight
"""
import os, sys, math, gc
sys.path.insert(0, os.path.dirname(__file__))
import torch
import torch.nn as nn
import torch.nn.functional as F
from tscale_mini import TernaryScaleTensor, TernaryRMSNorm, _n_groups
from scaled_optum import ScaledOptum
# Let float32 matmuls use faster reduced-precision internal kernels
# (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision('high')

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# ─── Config ───
VOCAB = 288        # byte vocabulary; 288 = 9 * GROUP_SIZE, presumably padded from 256 — confirm
HIDDEN = 2048      # model width
FFN_HIDDEN = 8192  # inner width of each w1 -> ReLU -> w2 block
N_LAYERS = 2       # number of w1/w2 blocks
GROUP_SIZE = 32    # group_size handed to every ternary module
THRESHOLD = 0.05   # threshold= handed to every TernaryScaleTensor
# ─── Model ───
class TernaryMLP(nn.Module):
    """
    Pure ternary MLP with packed weights + ALL-INT persistent state.
    No float32/16 anywhere in model buffers.

    Structure: ternary embedding of one-hot byte ids -> ``N_LAYERS`` x
    (``w1`` -> ReLU -> ``w2``), with no residual connections -> final
    ``TernaryRMSNorm`` -> ternary head producing ``VOCAB`` logits.
    """

    def __init__(self):
        super().__init__()
        self.embed = TernaryScaleTensor(VOCAB, HIDDEN, threshold=THRESHOLD,
                                        group_size=GROUP_SIZE)
        self.layers = nn.ModuleList()
        for _ in range(N_LAYERS):
            layer = nn.ModuleDict({
                'w1': TernaryScaleTensor(HIDDEN, FFN_HIDDEN, threshold=THRESHOLD,
                                         group_size=GROUP_SIZE),
                'w2': TernaryScaleTensor(FFN_HIDDEN, HIDDEN, threshold=THRESHOLD,
                                         group_size=GROUP_SIZE),
                # NOTE(review): 'norm' is never applied in forward(), yet it is
                # counted by param_counts()/persistent_memory_mb() and picked up
                # by any model.modules() scan. Kept so state_dict keys and module
                # order stay unchanged — confirm whether it should run before w1
                # or be removed.
                'norm': TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE),
            })
            self.layers.append(layer)
        self.final_norm = TernaryRMSNorm(HIDDEN, group_size=GROUP_SIZE)
        self.head = TernaryScaleTensor(HIDDEN, VOCAB, threshold=THRESHOLD,
                                       group_size=GROUP_SIZE)

    def forward(self, x, targets=None):
        """Compute next-byte logits, and the loss when ``targets`` is given.

        Args:
            x: integer byte ids in ``[0, VOCAB)``, shape ``(B, T)``.
            targets: optional ids with the same shape as ``x``.

        Returns:
            ``logits`` when ``targets`` is None, otherwise ``(logits, loss)``
            where ``loss`` is the mean cross-entropy over all ``B * T``
            positions.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"expected x of shape (B, T), got {tuple(x.shape)}")
        # The ternary embedding is applied to one-hot rows, like any other
        # ternary linear layer.
        h = self.embed(F.one_hot(x, num_classes=VOCAB).float())
        for layer in self.layers:
            h = F.relu(layer['w1'](h))
            h = layer['w2'](h)
        h = self.final_norm(h)
        logits = self.head(h)
        if targets is None:
            return logits
        # reshape (not view) so a non-contiguous head output cannot raise.
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), targets.reshape(-1))
        return logits, loss

    def param_counts(self):
        """Return ``(ternary_params, float_params)`` for the whole model.

        Ternary modules report their own count via ``total_ternary_params()``.
        For every other module, only the ``nn.Parameter``s it owns directly
        (``recurse=False``) are counted as float parameters.
        """
        total_ternary = 0
        total_float = 0
        for _, mod in self.named_modules():
            if isinstance(mod, (TernaryScaleTensor, TernaryRMSNorm)):
                total_ternary += mod.total_ternary_params()
            else:
                for p in mod.parameters(recurse=False):
                    total_float += p.numel()
        return total_ternary, total_float

    def persistent_memory_mb(self):
        """Return the summed ``persistent_memory_mb()`` of all ternary modules."""
        total = 0
        for mod in self.modules():
            if isinstance(mod, (TernaryScaleTensor, TernaryRMSNorm)):
                total += mod.persistent_memory_mb()
        return total
# ─── Data (TinyShakespeare) ───
def load_data(path="/tmp/tinyshakespeare.txt",
              url="https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"):
    """Return the bytes of the file at ``path`` as a 1-D ``torch.long`` tensor.

    If ``path`` does not exist, it is first downloaded from ``url``. The
    download is written to ``path + ".part"`` and renamed into place only once
    it completes, so an interrupted download never leaves a truncated file
    that later runs would silently train on.

    Args:
        path: local cache location of the corpus.
        url: where to fetch the corpus from when ``path`` is missing.

    Returns:
        Tensor of byte values (0..255), one element per byte of the file.
    """
    if not os.path.exists(path):
        import urllib.request
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Drop the partial download, then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    with open(path, 'rb') as f:
        data = f.read()
    # frombuffer decodes in C instead of building a Python list of ~1M ints.
    # bytearray supplies the writable buffer it expects; an empty buffer is
    # special-cased because frombuffer rejects zero-length input.
    buf = bytearray(data)
    if not buf:
        return torch.empty(0, dtype=torch.long)
    return torch.frombuffer(buf, dtype=torch.uint8).to(torch.long)
def get_batch(data, bs, ctx, device=device):
    """Sample ``bs`` random next-token training windows of length ``ctx``.

    Args:
        data: 1-D tensor of token ids.
        bs: number of windows (batch size).
        ctx: window length.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``, each of shape ``(bs, ctx)``, where ``y[b, t]`` is the
        token that follows ``x[b, t]`` in ``data``.

    Raises:
        ValueError: if ``data`` holds fewer than ``ctx + 1`` tokens.
    """
    # A window starting at i reads data[i : i + ctx + 1], so valid starts are
    # 0 .. len(data) - ctx - 1 inclusive; randint's upper bound is exclusive.
    n_starts = len(data) - ctx
    if n_starts <= 0:
        raise ValueError(
            f"need at least ctx + 1 = {ctx + 1} tokens, got {len(data)}")
    ix = torch.randint(0, n_starts, (bs,), device='cpu')
    x = torch.stack([data[i:i + ctx] for i in ix])
    y = torch.stack([data[i + 1:i + ctx + 1] for i in ix])
    return x.to(device), y.to(device)
# ─── Test ───
@torch.no_grad()
def compute_loss(model, data, bs=4, ctx=256):
    """Return the loss of ``model`` on one random ``(bs, ctx)`` batch of ``data``.

    Runs without autograd in eval mode. Afterwards it restores whatever
    train/eval mode the model was in, instead of leaving it in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        x, y = get_batch(data, bs, ctx)
        _, loss = model(x, targets=y)
        return loss.item()
    finally:
        model.train(was_training)
def train_step(model, opt, data, bs=2, ctx=128):
    """Take one optimizer step on a random ``(bs, ctx)`` batch; return its loss.

    The model is put in train mode first. Gradients are cleared (set to
    ``None``) after ``opt.step()``.
    """
    model.train()
    inputs, labels = get_batch(data, bs, ctx)
    _, batch_loss = model(inputs, targets=labels)
    batch_loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
    return batch_loss.item()
def _buffer_bytes(modules, name, *, optional=False):
    """Return the total size in bytes of tensor attribute ``name`` over ``modules``.

    With ``optional=True``, a module lacking the attribute (or holding
    ``None``) contributes 0 bytes instead of raising ``AttributeError``.
    """
    total = 0
    for mod in modules:
        buf = getattr(mod, name, None) if optional else getattr(mod, name)
        if buf is not None:
            total += buf.numel() * buf.element_size()
    return total


@torch.no_grad()
def _val_accuracy(model, data, ctx=128):
    """Return top-1 next-byte accuracy on one random ``(1, ctx)`` window of ``data``.

    Evaluates in eval mode without building an autograd graph, then switches
    the model back to train mode.
    """
    model.eval()
    x, y = get_batch(data, 1, ctx)
    logits, _ = model(x, targets=y)
    model.train()
    return (logits.argmax(-1) == y).float().mean().item()


def main():
    """Build the ternary MLP, report its storage footprint, then train it.

    Prints parameter counts and a per-buffer breakdown of persistent integer
    storage. It then registers every ternary module with ``ScaledOptum`` and
    runs ``N_STEPS`` training steps on TinyShakespeare.

    Every 200 steps, and at the last step, it logs:
    - the training loss, and that loss in bits (loss / ln 2);
    - validation accuracy on one random window;
    - the min/max of the first layer's ``E`` buffer;
    - peak CUDA memory since the previous log line.
    """
    print("Building TernaryMLP...")
    model = TernaryMLP().to(device)
    total_ternary, total_float = model.param_counts()
    total_params = total_ternary + total_float
    persistent_mb = model.persistent_memory_mb()

    # Every module whose weights live in packed ternary / integer buffers.
    ternary_modules = [mod for mod in model.modules()
                       if isinstance(mod, (TernaryScaleTensor, TernaryRMSNorm))]

    # Breakdown of persistent int storage, in bytes.
    t_b = _buffer_bytes(ternary_modules, 'T_packed')
    e_b = _buffer_bytes(ternary_modules, 'E')
    a_b = _buffer_bytes(ternary_modules, 'corr_accum')
    # step_counter may be absent: count its real element size rather than
    # charging a flat 8 bytes to every module whether or not it has one.
    sc_b = _buffer_bytes(ternary_modules, 'step_counter', optional=True)
    bpw = (t_b + e_b + a_b + sc_b) * 8 / max(1, total_params)

    print(f"\n Total params: {total_params:,}")
    print(f" Ternary params: {total_ternary:,} ({total_ternary/max(1,total_params)*100:.1f}%)")
    print(f" Float params: {total_float:,}")
    print(f" Persistent buffers: {persistent_mb:.2f} MB (ALL INTEGER)")
    print(f" T_packed: {t_b/1e6:.2f} MB ({t_b*8/max(1, total_ternary):.2f} bpw)")
    print(f" E (int8): {e_b/1e6:.2f} MB")
    print(f" corr_accum (int64):{a_b/1e6:.2f} MB")
    print(f" step_counter: {sc_b/1e6:.2f} MB")
    print(f" Effective bpw: {bpw:.2f}")
    print(f" Float params (bias): {sum(p.numel()*p.element_size() for p in model.parameters())/1e6:.1f} MB")

    # T_packed etc. are buffers, not nn.Parameters, so model.parameters()
    # does not expose them. ScaledOptum is therefore built around a dummy
    # parameter, and the ternary modules are registered explicitly.
    dummy = nn.Parameter(torch.zeros(1))
    opt = ScaledOptum([dummy], lr=0.3, default_group_size=GROUP_SIZE)
    opt.add_ternary_modules(ternary_modules)
    n_mods = len(opt.param_groups[0].get('ternary_modules', []))
    print(f" Optimizer state: 0 bytes (pure integer, stored on modules)")
    print(f" Ternary modules: {n_mods}")

    # Data: 90/10 train/validation split of the raw byte stream.
    data = load_data()
    split = int(0.9 * len(data))
    train_data = data[:split]
    val_data = data[split:]
    print(f" Train data: {len(train_data):,} bytes")
    print(f" Val data: {len(val_data):,} bytes")

    # Training
    N_STEPS = 5000
    LOG_EVERY = 200
    print(f"\nTraining for {N_STEPS} steps...")
    print(f"{'step':>6s} {'loss':>8s} {'bpw':>8s} {'acc%':>6s} {'S_range':>10s} {'VRAM':>6s}")
    print("-" * 60)
    has_cuda = torch.cuda.is_available()
    for step in range(N_STEPS):
        loss = train_step(model, opt, train_data, bs=2, ctx=128)
        if step % LOG_EVERY != 0 and step != N_STEPS - 1:
            continue
        acc = _val_accuracy(model, val_data, ctx=128)
        # Range of the first layer's E buffer, printed below as 2^E.
        e_vals = model.layers[0]['w1'].E
        e_min, e_max = e_vals.min().item(), e_vals.max().item()
        bpw_val = loss / math.log(2)  # nats -> bits per byte
        vram = torch.cuda.max_memory_allocated() / 1e6 if has_cuda else 0
        if has_cuda:
            torch.cuda.reset_peak_memory_stats()
        print(f"{step:6d} {loss:8.4f} {bpw_val:8.3f} {acc*100:5.1f}% "
              f"2^{e_min:+3d}–2^{e_max:+3d} {vram:5.0f}MB")
    print("\nDone.")
    print(f"Final loss: {loss:.4f}")


if __name__ == '__main__':
    main()