"""
Phase 2 Benchmark: 6-way optimizer comparison on pure ternary MORPH.

All 6 configs share one GPU and are stepped together: every training step runs
each config once, each on its own CUDA stream. Each config's step ends with a
blocking ``loss.item()``, so in practice the configs execute back-to-back
rather than concurrently.

Configs (ternary forward; Config E uses the T64 scale type, the others T32):
1. SignSGD + Config C (group-avg S, no shadow weight, no momentum)
2. SignSGD + Config E (per-element S=|W|, no shadow weight, no momentum)
3. Lion + bf16 shadow (bf16 model params, Lion momentum in FP32)
4. Lion + FP32 shadow (FP32 model params, Lion momentum in FP32)
5. Adam + bf16 shadow (bf16 model params; torch.optim.Adam keeps m/v in the
   parameter dtype, i.e. bf16)
6. Adam + FP32 shadow (FP32 model params, Adam m/v in FP32)

Metrics: loss curve, step time (ms), peak VRAM (MB)
"""
| import os |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import sys |
| import time |
| import json |
| import math |
| import gc |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import bitsandbytes as bnb |
| import urllib.request |
|
|
| from arbitor.main import MORPHTernaryModel, VOCAB, CTX, THRESHOLD, SPECIAL_VOCAB, StickyZoneSTE |
| from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, GROUP_SIZES |
| from arbitor.optim.sign_sgd import SignSGD |
|
|
|
|
# --- Training schedule -------------------------------------------------------
STEPS = 2_500         # optimizer steps run for every config
WARMUP = 250          # linear LR warmup length, in steps (used by get_lr)
BATCH_SIZE = 64       # sequences per batch
CTX_LEN = 66          # bytes per sampled training window
EVAL_INTERVAL = 250   # progress is printed every EVAL_INTERVAL steps
SEED = 42             # base RNG seed

# Dataset cache directory: this script's own directory, or the CWD when
# __file__ has no directory component.
DATA_DIR = os.path.dirname(__file__) or "."

# Benchmark configurations, in report order. The summary table treats the
# first entry as the baseline for its loss ratios.
CONFIGS = [
    "SignSGD_ConfigC_T32",
    "SignSGD_ConfigE_T32",
    "Lion_bf16_T32",
    "Lion_FP32_T32",
    "Adam_bf16_T32",
    "Adam_FP32_T32",
]
|
|
|
|
def download_data(data_dir=None, train_frac=0.9):
    """Load tinyshakespeare as byte values and split it into train/val tensors.

    The text is cached as ``tinyshakespeare.txt`` in ``data_dir`` and downloaded
    on first use. The download is written to a ``.part`` file and only renamed
    into place once it has completed, so an interrupted download is never
    mistaken for a valid cached copy on the next run.

    Args:
        data_dir: Cache directory; ``None`` means the module's ``DATA_DIR``.
        train_frac: Fraction of the bytes (taken from the start) used for training.

    Returns:
        ``(train, val)``: 1-D ``torch.long`` tensors holding the UTF-8 byte
        values (0-255) of the text.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    path = os.path.join(data_dir, "tinyshakespeare.txt")
    if not os.path.exists(path):
        print("Downloading tinyshakespeare...")
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(
                "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
                tmp_path,
            )
            # Atomic rename: `path` exists only if the download finished.
            os.replace(tmp_path, path)
        finally:
            # Drop a partial download if urlretrieve/os.replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(train_frac * len(byte_data))
    return byte_data[:n], byte_data[n:]
|
|
|
|
def get_lr(step, max_lr=3e-4, min_lr=1e-5, warmup=None, total_steps=None):
    """Learning rate for ``step``: linear warmup, then cosine decay to ``min_lr``.

    Args:
        step: 0-based optimizer step.
        max_lr: Peak LR reached at the end of warmup.
        min_lr: Floor LR reached at ``total_steps`` and held afterwards.
        warmup: Warmup length in steps; ``None`` means the module's ``WARMUP``.
        total_steps: Schedule length; ``None`` means the module's ``STEPS``.

    Returns:
        The learning rate as a float. During warmup it ramps linearly, reaching
        ``max_lr`` at ``step == warmup - 1``. Afterwards it follows a half-cosine
        from ``max_lr`` to ``min_lr``. Progress is clamped to [0, 1], so the LR
        never climbs back up past ``total_steps``, and a zero-length decay phase
        does not divide by zero.
    """
    if warmup is None:
        warmup = WARMUP
    if total_steps is None:
        total_steps = STEPS
    if step < warmup:
        return max_lr * (step + 1) / warmup
    decay_steps = total_steps - warmup
    if decay_steps <= 0:
        progress = 1.0
    else:
        progress = min(1.0, (step - warmup) / decay_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
|
|
|
def count_optimizer_memory_mb(optimizer):
    """Return the size, in MiB, of the tensors held in ``optimizer.state``.

    Only optimizer state (e.g. momentum / moment buffers) is counted. The
    parameters themselves are model memory and are reported separately;
    counting them here too would double-count them in ``model + opt`` totals.
    Non-tensor state entries are ignored. Optimizers that allocate their state
    lazily (such as ``torch.optim.Adam``) report 0 until their first ``step()``.
    """
    total_bytes = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            for buf in optimizer.state.get(p, {}).values():
                if isinstance(buf, torch.Tensor):
                    total_bytes += buf.numel() * buf.element_size()
    return total_bytes / (1024 * 1024)
|
|
|
|
def make_model(config_name, device):
    """Build the MORPH ternary model described by ``config_name`` on ``device``.

    Scale type: ``TScaleType.T64`` when the name contains ``"ConfigE"``,
    otherwise ``TScaleType.T32``.
    Parameter dtype: bfloat16 when the name contains ``"bf16"``, otherwise
    float32. The dtype cast happens before the move to ``device``.
    """
    scale_type = TScaleType.T64 if "ConfigE" in config_name else TScaleType.T32
    param_dtype = torch.bfloat16 if "bf16" in config_name else torch.float32
    net = MORPHTernaryModel(tscale_type=scale_type)
    return net.to(param_dtype).to(device)
|
|
|
|
def make_optimizer(config_name, model_params, lr=3e-4, weight_decay=0.01):
    """Create the optimizer selected by ``config_name`` over ``model_params``.

    The first matching substring wins, checked in this order:
    ``"SignSGD"`` -> ``SignSGD``, ``"Lion"`` -> ``bnb.optim.Lion``,
    ``"Adam"`` -> ``torch.optim.Adam``. Every optimizer receives the same
    ``lr`` and ``weight_decay``.

    NOTE(review): ``torch.optim.Adam`` applies ``weight_decay`` as an L2 term
    added to the gradient (AdamW is the decoupled variant). It also allocates
    its moment buffers in the parameter dtype, so for ``bf16`` models the Adam
    state is bf16. Keep both in mind when comparing against Lion.

    Raises:
        ValueError: if ``config_name`` names none of the known optimizers.
    """
    hyper = dict(lr=lr, weight_decay=weight_decay)
    if "SignSGD" in config_name:
        return SignSGD(model_params, **hyper)
    if "Lion" in config_name:
        return bnb.optim.Lion(model_params, **hyper)
    if "Adam" in config_name:
        return torch.optim.Adam(model_params, **hyper)
    raise ValueError(f"Unknown config: {config_name}")
|
|
|
|
def _peak_vram_mb(device, use_cuda):
    """Peak CUDA memory allocated on ``device`` in MiB; 0.0 when not on CUDA."""
    if not use_cuda:
        return 0.0
    return torch.cuda.max_memory_allocated(device) / (1024 * 1024)


def _sample_batch(train_data, device):
    """Draw BATCH_SIZE random windows of CTX_LEN tokens from ``train_data``.

    Returns ``(x, targets)`` moved to ``device``; ``targets`` is ``x[:, 3:]``.
    NOTE(review): the 3-position offset presumably matches how
    MORPHTernaryModel aligns predictions with inputs — confirm against the model.
    """
    starts = torch.randint(0, len(train_data) - CTX_LEN - 1, (BATCH_SIZE,))
    x = torch.stack([train_data[j : j + CTX_LEN] for j in starts])
    targets = x[:, 3:]
    return x.to(device, non_blocking=True), targets.to(device, non_blocking=True)


def _train_step(model, optimizer, x, targets, bf16_autocast):
    """Run one zero_grad/forward/backward/clip/update step; return the loss as a float.

    The trailing ``loss.item()`` blocks until this step's queued work (including
    ``optimizer.step()``) has finished. That is what makes host-side timing
    around this call measure the step itself.
    """
    optimizer.zero_grad()
    if bf16_autocast:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            _, loss = model(x, targets=targets)
    else:
        _, loss = model(x, targets=targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()


def run_parallel_benchmark(configs, train_data, device):
    """Train every config in ``configs`` side by side for STEPS steps and collect metrics.

    All models share ``device``. Each step draws ONE batch that every config
    trains on, and every model is initialised from the same seed, so loss
    differences reflect the optimizer/dtype rather than data order or init.
    On CUDA each config is launched on its own stream. Because ``_train_step``
    ends with a blocking ``loss.item()``, configs run back-to-back within a
    step, which lets each config's step time be measured individually.

    Args:
        configs: Config names understood by ``make_model`` / ``make_optimizer``.
        train_data: 1-D tensor of token ids to sample batches from.
        device: Torch device spec, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.

    Returns:
        One dict per config (same order as ``configs``) with memory, loss and
        per-config timing metrics, plus the full per-step ``loss_history``.
    """
    use_cuda = torch.device(device).type == "cuda"
    mib = 1024 * 1024

    torch.manual_seed(SEED)
    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    gc.collect()

    models, optimizers, streams = [], [], []
    loss_histories = [[] for _ in configs]
    per_config_step_ms = [[] for _ in configs]

    print(f"\nInitializing {len(configs)} models on {device}...")
    for i, cfg in enumerate(configs):
        # Same seed for every config: models with the same architecture start
        # from identical weights, so only the optimizer / dtype differs.
        torch.manual_seed(SEED)
        m = make_model(cfg, device)
        o = make_optimizer(cfg, m.parameters())
        s = torch.cuda.Stream(device) if use_cuda else None
        models.append(m)
        optimizers.append(o)
        streams.append(s)
        n = sum(p.numel() for p in m.parameters())
        dtype = "bf16" if "bf16" in cfg else "FP32"
        tscale = "ConfigE(T64)" if "ConfigE" in cfg else "ConfigC(T32)"
        print(f" [{i}] {cfg:<22} params={n:,} dtype={dtype} tscale={tscale}")

    vram_after_init = torch.cuda.memory_allocated(device) / mib if use_cuda else 0.0
    # Optimizers that create their state lazily (e.g. torch Adam) have not
    # allocated it yet, so these init-time figures can under-report.
    opt_mems = [count_optimizer_memory_mb(o) for o in optimizers]
    model_mems = [
        sum(p.numel() * p.element_size() for p in m.parameters()) / mib
        for m in models
    ]
    print(f" VRAM after init: {vram_after_init:.0f} MB")
    for cfg, model_mb, opt_mb in zip(configs, model_mems, opt_mems):
        print(f" {cfg:<22} model={model_mb:.1f}MB opt={opt_mb:.1f}MB")

    print(f"\nRunning {STEPS} steps (all configs parallel per step)...")
    t_total_start = time.perf_counter()

    for step in range(STEPS):
        lr = get_lr(step)
        for o in optimizers:
            for pg in o.param_groups:
                pg["lr"] = lr

        step_t0 = time.perf_counter()
        # One batch per step, shared by every config.
        x, targets = _sample_batch(train_data, device)
        if use_cuda:
            # The batch copy (and, on step 0, the model weights) was queued on
            # the current stream; finish it before the side streams read it,
            # and keep it out of the first config's timing.
            torch.cuda.synchronize(device)

        step_losses = []
        for i, (model, optimizer, stream) in enumerate(zip(models, optimizers, streams)):
            cfg_t0 = time.perf_counter()
            if stream is None:
                loss_val = _train_step(model, optimizer, x, targets, bf16_autocast=False)
            else:
                with torch.cuda.stream(stream):
                    loss_val = _train_step(
                        model, optimizer, x, targets,
                        bf16_autocast="bf16" in configs[i],
                    )
            # _train_step blocked on loss.item(), so this is config i's own time.
            per_config_step_ms[i].append((time.perf_counter() - cfg_t0) * 1000)
            step_losses.append(loss_val)
            loss_histories[i].append(loss_val)

        if use_cuda:
            torch.cuda.synchronize(device)
        wall_ms = (time.perf_counter() - step_t0) * 1000

        if step % EVAL_INTERVAL == 0 or step == STEPS - 1:
            peak_vram = _peak_vram_mb(device, use_cuda)
            losses_str = " ".join(f"{loss_val:.4f}" for loss_val in step_losses)
            print(
                f" step {step:>5d}/{STEPS} | wall={wall_ms:.0f}ms | "
                f"vram={peak_vram:.0f}MB | losses: {losses_str}"
            )

    if use_cuda:
        torch.cuda.synchronize(device)
    total_seconds = time.perf_counter() - t_total_start
    peak_vram = _peak_vram_mb(device, use_cuda)

    results = []
    for i, cfg in enumerate(configs):
        history = loss_histories[i]
        final_100 = history[-100:]
        step_ms = per_config_step_ms[i]
        results.append({
            "config": cfg,
            "n_params": sum(p.numel() for p in models[i].parameters()),
            "model_mem_mb": round(model_mems[i], 2),
            "optimizer_mem_mb": round(count_optimizer_memory_mb(optimizers[i]), 2),
            "peak_vram_mb": round(peak_vram, 1),
            "final_loss_avg100": round(sum(final_100) / len(final_100), 4),
            "min_loss": round(min(history), 4),
            "loss_1000": round(history[min(999, STEPS - 1)], 4),
            "loss_2500": round(history[min(2499, STEPS - 1)], 4),
            # Legacy key name kept for output compatibility: loss at the final step.
            "loss_5000": round(history[-1], 4),
            "avg_step_ms": round(sum(step_ms) / len(step_ms), 2),
            "loss_history": history,
        })

    print(f"\n Total wall time: {total_seconds:.1f}s ({total_seconds/60:.1f}min)")
    print(f" Per-config effective: {total_seconds/len(configs):.1f}s")
    print(f" Peak VRAM: {peak_vram:.0f} MB (all {len(configs)} models)")

    del models, optimizers, streams
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()

    return results
|
|
|
|
def print_summary_table(results):
    """Print a comparison table for the ``results`` list from ``run_parallel_benchmark``.

    ``results[0]`` is the baseline for every loss ratio (with the default
    CONFIGS that is the SignSGD Config C run). A ratio is reported as ``nan``
    whenever the baseline loss is not positive. Prints a short notice and
    returns if ``results`` is empty.
    """
    if not results:
        print("\nNo benchmark results to summarize.")
        return

    baseline_name = results[0]["config"]
    baseline = results[0]["final_loss_avg100"]

    def ratio_vs_baseline(r):
        # Same guard everywhere: nan rather than 0 or a ZeroDivisionError.
        return r["final_loss_avg100"] / baseline if baseline > 0 else float("nan")

    print(f"\n{'='*100}")
    print(f" BENCHMARK SUMMARY — {STEPS} steps, T32 ternary forward, all parallel")
    print(f"{'='*100}")
    header = (
        f"{'Config':<22} {'FinalLoss':>10} {'MinLoss':>10} "
        f"{'Loss@1k':>10} {'Loss@2.5k':>10} {'Step(ms)':>10} "
        f"{'OptMem(MB)':>10} {'vsSignC':>8}"
    )
    print(header)
    print("-" * 100)

    for r in results:
        row = (
            f"{r['config']:<22} {r['final_loss_avg100']:>10.4f} {r['min_loss']:>10.4f} "
            f"{r['loss_1000']:>10.4f} {r['loss_2500']:>10.4f} {r['avg_step_ms']:>10.1f} "
            f"{r['optimizer_mem_mb']:>10.2f} {ratio_vs_baseline(r):>7.3f}x"
        )
        print(row)

    # peak_vram_mb is one device-wide figure shared by every result.
    print(f"\n Peak VRAM (all {len(results)} combined): {results[0]['peak_vram_mb']:.0f} MB")

    print("\n--- Optimizer memory comparison ---")
    for r in results:
        total_mb = r["model_mem_mb"] + r["optimizer_mem_mb"]
        print(
            f" {r['config']:<22} model={r['model_mem_mb']:.1f}MB "
            f"opt={r['optimizer_mem_mb']:.1f}MB total={total_mb:.1f}MB"
        )

    print(f"\n--- Loss ratio vs {baseline_name} baseline ---")
    for r in results[1:]:
        ratio = ratio_vs_baseline(r)
        if math.isnan(ratio):
            verdict = "n/a"
        else:
            verdict = "better" if ratio < 1.0 else "worse"
        print(f" {r['config']:<22} {ratio:.4f}x ({verdict})")
|
|
|
|
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Steps: {STEPS} | Warmup: {WARMUP} | Batch: {BATCH_SIZE} | CTX: {CTX_LEN}")
    print(f"Configs: {len(CONFIGS)} (all parallel)")

    train_data, val_data = download_data()
    print(f"Train: {len(train_data):,} bytes | Val: {len(val_data):,} bytes")

    results = run_parallel_benchmark(CONFIGS, train_data, device)
    print_summary_table(results)

    # Results go to <parent of this script's directory>/results/benchmark/.
    # Create that directory if needed so a fresh checkout doesn't fail at the
    # very end of a long run.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(os.path.dirname(script_dir), "results", "benchmark")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "benchmark_phase2_results.json")

    # The per-step loss_history is left out of the saved JSON.
    save_results = {
        r["config"]: {k: v for k, v in r.items() if k != "loss_history"}
        for r in results
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(save_results, f, indent=2)
    print(f"\nResults saved to {out_path}")
|
|