"""Phase 2 benchmark: 6-way optimizer comparison on pure ternary MORPH.

All 6 configs run in parallel on the same GPU.

Configs (all T32 ternary forward):
    1. SignSGD + Config C (group-avg S, no shadow weight, no momentum)
    2. SignSGD + Config E (per-element S=|W|, no shadow weight, no momentum)
    3. Lion + bf16 shadow (bf16 model params, Lion momentum in FP32)
    4. Lion + FP32 shadow (FP32 model params, Lion momentum in FP32)
    5. Adam + bf16 shadow (bf16 model params, Adam m/v in FP32)
    6. Adam + FP32 shadow (FP32 model params, Adam m/v in FP32)

Metrics: loss curve, step time (ms), peak VRAM (MB)

NOTE(review): make_model() in this file selects TScaleType.T64 for the
ConfigE run, so the "all T32" label above should be confirmed.
"""

import os
import sys

# The `arbitor` package is resolved relative to this file (two directories
# up), so the search path has to be extended before the project imports below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import time
import json
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import urllib.request

from arbitor.main import (
    MORPHTernaryModel,
    VOCAB,
    CTX,
    THRESHOLD,
    SPECIAL_VOCAB,
    StickyZoneSTE,
)
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, GROUP_SIZES
from arbitor.optim.sign_sgd import SignSGD

# --- Run configuration -----------------------------------------------------
STEPS = 2500          # training steps executed for every config
WARMUP = 250          # steps of linear LR warmup before the cosine decay
BATCH_SIZE = 64       # sequences sampled per config per step
CTX_LEN = 66          # bytes per sampled training sequence
EVAL_INTERVAL = 250   # a progress line is printed every this many steps
SEED = 42             # base RNG seed
# Directory holding the dataset; falls back to the CWD when this file's
# dirname is empty.
DATA_DIR = os.path.dirname(__file__) or "."
# Names of the benchmark runs. make_model() / make_optimizer() dispatch on
# substrings of these names ("ConfigE", "bf16", "SignSGD", "Lion", "Adam").
CONFIGS = [
    "SignSGD_ConfigC_T32",
    "SignSGD_ConfigE_T32",
    "Lion_bf16_T32",
    "Lion_FP32_T32",
    "Adam_bf16_T32",
    "Adam_FP32_T32",
]

_BYTES_PER_MB = 1024 * 1024


def download_data():
    """Load tinyshakespeare as byte ids and split it 90/10 into train/val.

    The corpus is downloaded into DATA_DIR on first use. The download goes to
    a temporary ``.part`` file and is renamed into place only once complete,
    so an interrupted download is never mistaken for the full corpus later.

    Returns:
        (train, val): two 1-D ``torch.long`` tensors of UTF-8 byte values.
    """
    path = os.path.join(DATA_DIR, "tinyshakespeare.txt")
    if not os.path.exists(path):
        print("Downloading tinyshakespeare...")
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(
                "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
                tmp_path,
            )
            os.replace(tmp_path, path)
        finally:
            # Only present if the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]


def get_lr(step, max_lr=3e-4, min_lr=1e-5):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    For ``step < WARMUP`` the rate ramps linearly up to ``max_lr``; afterwards
    it follows a half cosine from ``max_lr`` down to ``min_lr`` at ``STEPS``.
    """
    if step < WARMUP:
        return max_lr * (step + 1) / WARMUP
    progress = (step - WARMUP) / (STEPS - WARMUP)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def count_optimizer_memory_mb(optimizer):
    """Return the MiB held by an optimizer's parameters and tensor state.

    Sums ``numel * element_size`` over every parameter in every param group
    plus every tensor stored in that parameter's optimizer state.

    NOTE(review): the parameters themselves are included, so adding this to a
    separately computed model size counts the parameters twice. Confirm that
    this is the intended accounting.
    """
    total = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            total += p.numel() * p.element_size()
            # .get() avoids creating an empty state entry as a side effect.
            state = optimizer.state.get(p, {})
            total += sum(
                buf.numel() * buf.element_size()
                for buf in state.values()
                if isinstance(buf, torch.Tensor)
            )
    return total / _BYTES_PER_MB


def make_model(config_name, device):
    """Build a MORPHTernaryModel for ``config_name`` and move it to ``device``.

    "ConfigE" in the name selects ``TScaleType.T64`` (otherwise ``T32``);
    "bf16" in the name casts the model to bfloat16 (otherwise float32).
    """
    if "ConfigE" in config_name:
        tscale_type = TScaleType.T64
    else:
        tscale_type = TScaleType.T32
    model = MORPHTernaryModel(tscale_type=tscale_type)
    if "bf16" in config_name:
        model = model.to(torch.bfloat16)
    else:
        model = model.to(torch.float32)
    return model.to(device)


def make_optimizer(config_name, model_params, lr=3e-4, weight_decay=0.01):
    """Create the optimizer named in ``config_name`` (SignSGD, Lion or Adam).

    Raises:
        ValueError: if the name contains none of the known optimizer tags.
    """
    if "SignSGD" in config_name:
        return SignSGD(model_params, lr=lr, weight_decay=weight_decay)
    if "Lion" in config_name:
        return bnb.optim.Lion(model_params, lr=lr, weight_decay=weight_decay)
    if "Adam" in config_name:
        return torch.optim.Adam(model_params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"Unknown config: {config_name}")


def _train_step(model, optimizer, x_cpu, targets_cpu, device, use_autocast):
    """Run one forward/backward/update and return the detached loss tensor.

    The host-to-device copies are issued here so that, when the caller has a
    CUDA stream active, they are ordered on the same stream as the compute.
    The loss is returned as a tensor (not ``.item()``) so the caller decides
    when to block on the device.
    """
    x = x_cpu.to(device, non_blocking=True)
    targets = targets_cpu.to(device, non_blocking=True)
    optimizer.zero_grad()
    if use_autocast:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            _, loss = model(x, targets=targets)
    else:
        _, loss = model(x, targets=targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.detach()


def run_parallel_benchmark(configs, train_data, device):
    """Train one model per config side by side and collect loss/time/memory.

    Each step, every config gets its own random batch and runs a full training
    step; on CUDA each config has a dedicated stream so their work can
    overlap. Loss values are read back only after the whole step has been
    synchronized, so one config's read-back does not stall the launch of the
    next config.

    Args:
        configs: sequence of config names (see CONFIGS).
        train_data: 1-D CPU tensor of token ids to sample sequences from.
        device: "cuda", "cpu", or any value accepted by ``torch.device``.

    Returns:
        A list with one result dict per config (summary metrics plus the full
        ``loss_history``).

    NOTE(review): every config is initialised with a different seed
    (SEED + i) and draws its own batches, so runs differ in init and data as
    well as in optimizer. Confirm this is intended for the comparison.
    """
    on_cuda = torch.device(device).type == "cuda"

    torch.manual_seed(SEED)
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    gc.collect()

    models = []
    optimizers = []
    streams = []
    loss_histories = [[] for _ in configs]
    per_config_step_ms = [[] for _ in configs]

    print(f"\nInitializing {len(configs)} models on {device}...")
    for i, cfg in enumerate(configs):
        torch.manual_seed(SEED + i)
        m = make_model(cfg, device)
        o = make_optimizer(cfg, m.parameters())
        s = torch.cuda.Stream(device) if on_cuda else None
        models.append(m)
        optimizers.append(o)
        streams.append(s)
        n = sum(p.numel() for p in m.parameters())
        dtype = "bf16" if "bf16" in cfg else "FP32"
        tscale = "ConfigE(T64)" if "ConfigE" in cfg else "ConfigC(T32)"
        print(f" [{i}] {cfg:<22} params={n:,} dtype={dtype} tscale={tscale}")

    if on_cuda:
        # Parameters were initialised on the default stream; make sure that
        # work is finished before the per-config streams start using them.
        torch.cuda.synchronize(device)
        total_vram_start = torch.cuda.memory_allocated(device) / _BYTES_PER_MB
    else:
        total_vram_start = 0.0
    opt_mems = [count_optimizer_memory_mb(o) for o in optimizers]
    model_mems = [
        sum(p.numel() * p.element_size() for p in m.parameters()) / _BYTES_PER_MB
        for m in models
    ]
    print(f" VRAM after init: {total_vram_start:.0f} MB")
    for i, cfg in enumerate(configs):
        print(f" {cfg:<22} model={model_mems[i]:.1f}MB opt={opt_mems[i]:.1f}MB")

    print(f"\nRunning {STEPS} steps (all configs parallel per step)...")
    t_total_start = time.perf_counter()
    for step in range(STEPS):
        lr = get_lr(step)
        for o in optimizers:
            for pg in o.param_groups:
                pg["lr"] = lr

        pending_losses = []
        step_t0 = time.perf_counter()
        for cfg, model, optimizer, stream in zip(configs, models, optimizers, streams):
            ix = torch.randint(0, len(train_data) - CTX_LEN - 1, (BATCH_SIZE,))
            x = torch.stack([train_data[j : j + CTX_LEN] for j in ix])
            # Targets are the input with its first 3 positions dropped; the
            # alignment with the logits is defined by the model's forward.
            targets = x[:, 3:]
            use_autocast = on_cuda and "bf16" in cfg
            if stream is None:
                loss = _train_step(model, optimizer, x, targets, device, use_autocast)
            else:
                with torch.cuda.stream(stream):
                    loss = _train_step(model, optimizer, x, targets, device, use_autocast)
            pending_losses.append(loss)

        if on_cuda:
            torch.cuda.synchronize(device)
        # Safe to read now: every stream has finished this step.
        step_losses = [loss.item() for loss in pending_losses]
        step_t1 = time.perf_counter()

        wall_ms = (step_t1 - step_t0) * 1000
        # Configs share the step, so each is charged an equal share of it.
        per_config_step_ms_flat = wall_ms / len(configs)
        for i, step_loss in enumerate(step_losses):
            loss_histories[i].append(step_loss)
            per_config_step_ms[i].append(per_config_step_ms_flat)

        if step % EVAL_INTERVAL == 0 or step == STEPS - 1:
            peak_vram = (
                torch.cuda.max_memory_allocated(device) / _BYTES_PER_MB if on_cuda else 0.0
            )
            losses_str = " ".join(f"{l:.4f}" for l in step_losses)
            print(
                f" step {step:>5d}/{STEPS} | wall={wall_ms:.0f}ms | "
                f"vram={peak_vram:.0f}MB | losses: {losses_str}"
            )

    t_total_end = time.perf_counter()
    total_seconds = t_total_end - t_total_start
    if on_cuda:
        torch.cuda.synchronize(device)
        peak_vram = torch.cuda.max_memory_allocated(device) / _BYTES_PER_MB
    else:
        peak_vram = 0.0

    results = []
    for i, cfg in enumerate(configs):
        history = loss_histories[i]
        final_100 = history[-100:]
        final_avg = sum(final_100) / len(final_100)
        avg_ms = sum(per_config_step_ms[i]) / len(per_config_step_ms[i])
        opt_mem = count_optimizer_memory_mb(optimizers[i])
        results.append({
            "config": cfg,
            "n_params": sum(p.numel() for p in models[i].parameters()),
            "model_mem_mb": round(model_mems[i], 2),
            "optimizer_mem_mb": round(opt_mem, 2),
            "peak_vram_mb": round(peak_vram, 1),
            "final_loss_avg100": round(final_avg, 4),
            "min_loss": round(min(history), 4),
            "loss_1000": round(history[min(999, STEPS - 1)], 4),
            "loss_2500": round(history[min(2499, STEPS - 1)], 4),
            # Legacy key name: this is simply the loss of the last step.
            "loss_5000": round(history[-1], 4),
            "avg_step_ms": round(avg_ms, 2),
            "loss_history": history,
        })

    print(f"\n Total wall time: {total_seconds:.1f}s ({total_seconds/60:.1f}min)")
    print(f" Per-config effective: {total_seconds/len(configs):.1f}s")
    print(f" Peak VRAM: {peak_vram:.0f} MB (all {len(configs)} models)")

    del models, optimizers, streams
    gc.collect()
    if on_cuda:
        torch.cuda.empty_cache()
    return results


def print_summary_table(results):
    """Print the comparison table, memory breakdown and loss ratios.

    Ratios are relative to the first entry of ``results``. When that baseline
    loss is not positive, the table shows 0 and the ratio section is replaced
    by a notice rather than dividing by zero. Empty ``results`` prints only a
    notice.
    """
    if not results:
        print("\n No benchmark results to summarize.")
        return

    print(f"\n{'='*100}")
    print(f" BENCHMARK SUMMARY — {STEPS} steps, T32 ternary forward, all parallel")
    print(f"{'='*100}")
    header = (
        f"{'Config':<22} {'FinalLoss':>10} {'MinLoss':>10} "
        f"{'Loss@1k':>10} {'Loss@2.5k':>10} {'Step(ms)':>10} "
        f"{'OptMem(MB)':>10} {'vsSignC':>8}"
    )
    print(header)
    print("-" * 100)

    baseline = results[0]["final_loss_avg100"]
    has_baseline = baseline > 0
    for r in results:
        ratio = r["final_loss_avg100"] / baseline if has_baseline else 0
        row = (
            f"{r['config']:<22} {r['final_loss_avg100']:>10.4f} {r['min_loss']:>10.4f} "
            f"{r['loss_1000']:>10.4f} {r['loss_2500']:>10.4f} {r['avg_step_ms']:>10.1f} "
            f"{r['optimizer_mem_mb']:>10.2f} {ratio:>7.3f}x"
        )
        print(row)

    print(f"\n Peak VRAM (all {len(results)} combined): {results[0]['peak_vram_mb']:.0f} MB")

    print("\n--- Optimizer memory comparison ---")
    for r in results:
        print(f" {r['config']:<22} model={r['model_mem_mb']:.1f}MB opt={r['optimizer_mem_mb']:.1f}MB total={r['model_mem_mb']+r['optimizer_mem_mb']:.1f}MB")

    print("\n--- Loss ratio vs SignSGD ConfigC baseline ---")
    if not has_baseline:
        print(" (baseline loss is not positive; ratios are undefined)")
        return
    for r in results[1:]:
        ratio = r["final_loss_avg100"] / baseline
        print(f" {r['config']:<22} {ratio:.4f}x ({'better' if ratio < 1.0 else 'worse'})")


def main():
    """Run the benchmark on the available device and save a JSON summary."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Steps: {STEPS} | Warmup: {WARMUP} | Batch: {BATCH_SIZE} | CTX: {CTX_LEN}")
    print(f"Configs: {len(CONFIGS)} (all parallel)")

    train_data, val_data = download_data()
    print(f"Train: {len(train_data):,} bytes | Val: {len(val_data):,} bytes")

    results = run_parallel_benchmark(CONFIGS, train_data, device)
    print_summary_table(results)

    # abspath keeps "parent of the script directory" correct even when
    # __file__ is a bare relative filename.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(
        os.path.dirname(script_dir),
        "results", "benchmark", "benchmark_phase2_results.json",
    )
    # Create the output directory so a finished run is never lost to a
    # missing folder.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    save_results = {
        r["config"]: {k: v for k, v in r.items() if k != "loss_history"}
        for r in results
    }
    with open(out_path, "w") as f:
        json.dump(save_results, f, indent=2)
    print(f"\nResults saved to {out_path}")


if __name__ == "__main__":
    main()