| """ |
| True Ternary Benchmark: Compare training methods on ARBModel. |
| |
| Configs: |
| 1. Adam_FP32 — standard FP32 Adam (full model, float params) |
| 2. SignSGD_Old — SignSGD optimizer (full model, float params) |
| 3. TrueTernary — pure ternary training (0 float params, T flips + E_accum) |
| |
| Metrics: loss curve, step time, peak VRAM, model/optimizer memory, convergence |
| |
| After REFACTOR6 (architecture ternarization), the internal model has 0 trainable |
| float params. Adam_FP32 and SignSGD_Old use the pre-ternarization float weights. |
| TrueTernary uses the post-REFACTOR6 strict ternary-only path. |
| """ |
| import os, sys, time, json, math, gc, argparse |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| from arbitor.main import ARBModel, VOCAB, CTX, LossComponents |
| from arbitor.kernel.ternary_scale import TScaleType |
| from arbitor.kernel.ternary_scale import _triton_ternary_grad_sign, _triton_update_e, _triton_ternary_step |
| from arbitor.optim.sign_sgd import SignSGD |
| from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters |
|
|
# Default benchmark hyper-parameters. The ``__main__`` section of this file
# rebinds STEPS, WARMUP, BATCH, CTX_LEN and CONFIGS from the command line.
STEPS = 50    # optimisation steps run per config
WARMUP = 10   # LR warmup steps; also skipped when averaging step time
BATCH = 8     # sequences per batch
CTX_LEN = 66  # elements per sampled window
SEED = 42     # seed passed to torch.manual_seed for every config

# Training corpus: fetched from DATA_URL and cached next to this file.
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
DATA_PATH = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt")

# Config names run by default, in this order.
CONFIGS = ["Adam_FP32", "SignSGD_Old", "TrueTernary"]
|
|
|
|
class NoTrainableParametersOptimizer:
    """Do-nothing optimizer stand-in for a model with no trainable parameters.

    It offers the same small surface a real optimizer is used through here:
    an (empty) ``param_groups`` list, an (empty) ``state`` mapping, and
    ``zero_grad`` / ``step`` methods that accept anything and do nothing.
    """

    def __init__(self):
        self.param_groups, self.state = [], {}

    def zero_grad(self, *args, **kwargs):
        """Ignore all arguments; there are no gradients to clear."""

    def step(self, *args, **kwargs):
        """Ignore all arguments; there are no parameters to update."""
|
|
|
|
def download_data():
    """Load tinyshakespeare as byte tensors, downloading it on first use.

    The corpus is fetched from ``DATA_URL`` into ``DATA_PATH`` only when that
    file does not exist yet. The text is then read as UTF-8, re-encoded to
    bytes, and stored one byte per ``torch.long`` element.

    Returns:
        ``(train, val)``: the first 90% of the bytes and the remainder.
    """
    if not os.path.exists(DATA_PATH):
        from urllib.request import urlretrieve

        print(" Downloading tinyshakespeare...")
        urlretrieve(DATA_URL, DATA_PATH)

    with open(DATA_PATH, "r", encoding="utf-8") as handle:
        corpus = handle.read()

    encoded = torch.tensor(list(corpus.encode("utf-8")), dtype=torch.long)
    split_at = int(0.9 * len(encoded))
    train_split, val_split = encoded[:split_at], encoded[split_at:]
    return train_split, val_split
|
|
|
|
def get_batch(data, device):
    """Sample a random batch of windows from ``data`` and move it to ``device``.

    ``BATCH`` start offsets are drawn with ``torch.randint`` and a window of
    ``CTX_LEN`` consecutive elements is sliced at each one.

    Returns:
        ``(x, targets)`` where ``x`` stacks the sampled windows and
        ``targets`` is ``x`` with its first three columns dropped.
    """
    upper_bound = len(data) - CTX_LEN - 1
    starts = torch.randint(0, upper_bound, (BATCH,)).tolist()
    windows = [data[start: start + CTX_LEN] for start in starts]
    x = torch.stack(windows)
    targets = x[:, 3:]
    x_dev = x.to(device, non_blocking=True)
    targets_dev = targets.to(device, non_blocking=True)
    return x_dev, targets_dev
|
|
|
|
def get_lr(step, max_lr=1e-4, min_lr=1e-6):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    During the first ``WARMUP`` steps the rate ramps linearly up to
    ``max_lr``. Afterwards it follows a half cosine from ``max_lr`` towards
    ``min_lr`` over the remaining ``STEPS - WARMUP`` steps.
    """
    if step < WARMUP:
        return max_lr * (step + 1) / WARMUP

    decay_span = max(1, STEPS - WARMUP)  # never divide by zero
    progress = (step - WARMUP) / decay_span
    cosine_term = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine_term
|
|
|
|
def cpu_update_memory(model, accum_threshold=3, loss_signal=None):
    """Dense-tensor ternary update that avoids the Triton compilation bug (14s/step).

    Walks every sub-module that exposes ``update_E`` and/or ``ternary_step``
    and has a pending gradient hook, then updates its state with plain tensor
    ops on the device of ``module.T_accum``:

    * ``update_E`` modules: group-wise scores of ``sign(grad) * T`` are
      accumulated into the int8 ``E_accum`` buffer; when an entry reaches the
      module's ``_e_accum_threshold`` (default 4) the matching ``E`` entry
      steps by one and the accumulator is reduced by the threshold.
    * ``ternary_step`` modules: ``sign(grad) * t_step`` is accumulated into
      ``T_accum``; entries beyond ``accum_threshold`` set the ternary weight
      to +1 / -1, ``T_packed`` is repacked and those accumulators reset.

    The consumed hook tensors are always deleted from the module.

    Args:
        model: module tree to update in place.
        accum_threshold: magnitude ``T_accum`` must exceed before a flip.
        loss_signal: optional scalar loss tensor; it scales the accumulation
            step to ``clamp(loss, 0, 32) // 2 + 1``, limited to ``1..4``.
    """
    t_step = 1
    if loss_signal is not None:
        loss_val = float(loss_signal.detach().clamp(min=0, max=32).item())
        t_step = max(1, min(4, int(loss_val // 2) + 1))

    for module in model.modules():
        if not hasattr(module, 'update_E') and not hasattr(module, 'ternary_step'):
            continue
        has_grad = hasattr(module, '_hook_grad_T_sign')
        has_direct = hasattr(module, '_hook_grad_2d') and hasattr(module, '_hook_x_2d')
        if not has_grad and not has_direct:
            continue

        device = module.T_accum.device
        if has_direct:
            # Rebuild the weight gradient from the saved activations and
            # output gradients, keeping only its sign.
            grad_W = torch.matmul(module._hook_grad_2d.float().t(), module._hook_x_2d.float())
            # Move to the module's device like the cached-sign path does.
            grad_sign = grad_W.sign().to(device=device, dtype=torch.int8)
        else:
            grad_sign = module._hook_grad_T_sign.to(device=device)

        if hasattr(module, 'update_E'):
            N, K = tuple(module._T_shape.tolist())
            T_source = module._hook_T if hasattr(module, '_hook_T') else module._get_T()
            T = T_source.to(device=device)
            grad_T = grad_sign.float() * T.float()
            # Pad the input dimension up to a whole number of groups.
            gpr = (K + module.group_size - 1) // module.group_size
            total_in = gpr * module.group_size
            padded = F.pad(grad_T, (0, total_in - K))
            grouped = padded.view(N, gpr, module.group_size)
            group_score = grouped.sum(dim=2)
            delta = -group_score.sign().to(torch.int8).flatten()
            if not hasattr(module, "E_accum"):
                module.register_buffer("E_accum", torch.zeros_like(module.E, dtype=torch.int8))
            e_accum_threshold = int(getattr(module, "_e_accum_threshold", 4))
            # Widen before adding: an int8 + int8 sum wraps around before
            # clamp can saturate it.
            new_accum = torch.clamp(
                module.E_accum.to(torch.int16) + delta.to(torch.int16), -128, 127
            ).to(torch.int8)
            step_up = new_accum >= e_accum_threshold
            step_down = new_accum <= -e_accum_threshold
            e_step = torch.where(
                step_up,
                torch.ones_like(new_accum),
                torch.where(step_down, -torch.ones_like(new_accum), torch.zeros_like(new_accum)),
            )
            module.E = torch.clamp(
                module.E.to(torch.int16) + e_step.to(torch.int16), -128, 127
            ).to(torch.int8)
            module.E_accum = (
                new_accum.to(torch.int16) - e_step.to(torch.int16) * e_accum_threshold
            ).to(torch.int8)

        if hasattr(module, 'ternary_step'):
            # Same widening as above so the accumulator saturates at the int8
            # limits instead of wrapping to the opposite sign.
            widened = module.T_accum.to(torch.int16) + grad_sign.to(torch.int16) * t_step
            module.T_accum = torch.clamp(widened, -128, 127).to(torch.int8)
            fu = module.T_accum > accum_threshold
            fd = module.T_accum < -accum_threshold
            if fu.any() or fd.any():
                # Imported lazily: only needed when a flip forces a repack.
                from arbitor.converters.convert_to_ternary8 import pack_ternary

                T = module._get_T().to(device)
                T[fu] = 1
                T[fd] = -1
                if device.type == "cuda":
                    # Unconditional synchronize raises on CPU-only hosts.
                    torch.cuda.synchronize()
                module.T_packed = pack_ternary(T.cpu())[0].to(device=device)
                module.T_accum = torch.where(fu | fd, torch.zeros_like(module.T_accum), module.T_accum)

        # Drop the consumed hook tensors so they are not reused next step.
        if has_direct:
            del module._hook_grad_2d, module._hook_x_2d
        else:
            del module._hook_grad_T_sign
|
|
|
|
def gpu_signcache_update_memory(model, accum_threshold=3, update_scales=True, loss_signal=None):
    """GPU update that computes one temporary int8 grad_sign per module, then frees it.

    This avoids the very slow per-packed-byte direct reduction path for benchmark
    shapes with large M = batch * sequence. It still keeps persistent model state
    ternary-first: packed T, int8 E, int8 accumulators, no FP master weights.

    For each module with a pending hook:

    1. If raw ``_hook_grad_2d`` / ``_hook_x_2d`` tensors are present they are
       reduced to ``_hook_grad_T_sign`` and deleted.
    2. When ``update_scales`` is true and the module has ``update_E``, the
       Triton E-update kernel runs if ``E`` is on CUDA; otherwise the
       module's own ``update_E`` is called.
    3. If the module has ``ternary_step``, the Triton step kernel runs when
       ``T_packed`` is on CUDA; otherwise the module's own ``ternary_step``
       is called.
    4. Any ``_hook_grad_T_sign`` still attached is deleted.

    Args:
        model: module tree to update in place.
        accum_threshold: ``T_accum`` threshold forwarded to the step kernel
            or to ``module.ternary_step``.
        update_scales: whether to run the ``E`` update for this call.
        loss_signal: optional scalar loss tensor; it scales the accumulation
            step to ``clamp(loss, 0, 32) // 2 + 1``, limited to ``1..4``.
    """
    t_step = 1
    if loss_signal is not None:
        loss_val = float(loss_signal.detach().clamp(min=0, max=32).item())
        t_step = max(1, min(4, int(loss_val // 2) + 1))

    for module in model.modules():
        has_grad = hasattr(module, '_hook_grad_T_sign')
        has_direct = hasattr(module, '_hook_grad_2d') and hasattr(module, '_hook_x_2d')
        if not has_grad and not has_direct:
            continue

        if has_direct:
            n_out, k_in = tuple(module._T_shape.tolist())
            module._hook_grad_T_sign = _triton_ternary_grad_sign(
                module._hook_grad_2d, module._hook_x_2d, n_out, k_in
            )
            del module._hook_grad_2d, module._hook_x_2d

        if update_scales and hasattr(module, 'update_E'):
            e_on_cuda = getattr(module, "E", None) is not None and module.E.is_cuda
            if e_on_cuda and hasattr(module, "_hook_grad_T_sign"):
                n_out, k_in = tuple(module._T_shape.tolist())
                if not hasattr(module, "E_accum"):
                    module.register_buffer("E_accum", torch.zeros_like(module.E, dtype=torch.int8))
                _triton_update_e(
                    module.T_packed.contiguous(),
                    module._hook_grad_T_sign.contiguous(),
                    module.E,
                    module.E_accum,
                    n_out,
                    k_in,
                    module.group_size,
                    int(getattr(module, "_e_accum_threshold", 4)),
                )
            else:
                module.update_E(loss_signal=loss_signal)

        if hasattr(module, 'ternary_step'):
            t_on_cuda = getattr(module, "T_packed", None) is not None and module.T_packed.is_cuda
            if t_on_cuda and hasattr(module, "_hook_grad_T_sign"):
                n_out, k_in = tuple(module._T_shape.tolist())
                _triton_ternary_step(
                    module.T_packed,
                    module._hook_grad_T_sign.contiguous(),
                    module.T_accum,
                    int(n_out * k_in),
                    accum_threshold,
                    t_step,
                )
            else:
                module.ternary_step(accum_threshold=accum_threshold)

        # Free the cached sign on every path. Previously only the Triton
        # ternary-step branch deleted it, so scale-only modules and the
        # fallback branch kept it alive.
        if hasattr(module, '_hook_grad_T_sign'):
            del module._hook_grad_T_sign
|
|
|
|
def build_model(strict_ternary):
    """Construct an ``ARBModel`` with T32 scales for the benchmark.

    When ``strict_ternary`` is true the image, audio, VQ, graph and memory
    module options are all switched off; otherwise they are all switched on.
    MoE is enabled in both cases.
    """
    extras_enabled = not strict_ternary
    options = {
        "tscale_type": TScaleType.T32,
        "enable_image": extras_enabled,
        "enable_audio": extras_enabled,
        "enable_vq": extras_enabled,
        "enable_graph": extras_enabled,
        "enable_memory_modules": extras_enabled,
        "enable_moe": True,
    }
    return ARBModel(**options)
|
|
|
|
def _optimizer_memory_mb(optimizer):
    """Return the MiB used by an optimizer's parameters plus its state tensors."""
    total_bytes = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            total_bytes += p.numel() * p.element_size()
            for v in optimizer.state.get(p, {}).values():
                if isinstance(v, torch.Tensor):
                    total_bytes += v.numel() * v.element_size()
    return total_bytes / (1024 * 1024)


def run_config(
    name,
    device,
    base_state=None,
    strict_true_ternary=True,
    update_backend="gpu",
    scale_update_interval=4,
    accum_threshold=3,
    print_every=1,
):
    """Train one benchmark config for ``STEPS`` steps and return its metrics.

    Batches are sampled from the module-level ``train_data`` tensor, which the
    ``__main__`` section assigns before calling this function.

    Args:
        name: config name. ``"TrueTernary"`` in the name enables the ternary
            state updates, ``"BF16"`` enables bf16 autocast and an 8-bit Adam,
            ``"Adam_FP32"`` selects Adam, and ``"SignSGD"`` / ``"TrueTernary"``
            select SignSGD.
        device: device string or ``torch.device`` to run on.
        base_state: optional CPU state dict loaded (non-strictly) into
            full-model configs.
        strict_true_ternary: when true, TrueTernary configs use the
            strict-ternary model with frozen float parameters; when false
            they use the full model like the other configs.
        update_backend: ``"gpu"``, ``"gpu-signcache"``, ``"dense-fallback"``
            or ``"none"``; selects how TrueTernary state is updated.
        scale_update_interval: update ``E`` every N steps; 0 disables it.
        accum_threshold: ``T_accum`` threshold for ternary sign flips.
        print_every: print a progress line every N steps.

    Returns:
        A JSON-serialisable dict with parameter counts, memory figures, loss
        statistics, timing and the full loss history.

    Raises:
        ValueError: for an unknown config name or update backend.
    """
    on_cuda = torch.device(device).type == "cuda"
    amp_device = "cuda" if on_cuda else "cpu"

    torch.manual_seed(SEED)
    if on_cuda:
        # CUDA bookkeeping raises on a CPU device, so it is guarded throughout.
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    gc.collect()

    is_true_ternary = "TrueTernary" in name
    is_signsgd = "SignSGD" in name or is_true_ternary
    use_bf16 = "BF16" in name

    # Honour the caller's flag; it used to be ignored, so every TrueTernary
    # run was strict regardless of --no-strict-true-ternary.
    strict_model = is_true_ternary and strict_true_ternary

    if strict_model:
        model = build_model(strict_ternary=True).to(device)
        freeze_float_parameters(model)
    elif base_state is not None:
        model = build_model(strict_ternary=False).to(device)
        model.load_state_dict(base_state, strict=False)
        for param_name, p in model.named_parameters():
            if param_name.split('.')[0] in ('vit', 'image_sequencer', 'audio_sequencer'):
                p.requires_grad = False
    else:
        model = build_model(strict_ternary=False).to(device)

    opt_params = trainable_parameters(model)
    if use_bf16:
        import bitsandbytes as bnb
        print(f" Creating Adam8bit optimizer...", flush=True)
        optimizer = bnb.optim.Adam8bit(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    elif name == "Adam_FP32":
        print(f" Creating Adam FP32 optimizer...", flush=True)
        optimizer = torch.optim.Adam(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    elif is_signsgd:
        print(f" Creating SignSGD optimizer...", flush=True)
        optimizer = SignSGD(opt_params, lr=0.001, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    else:
        raise ValueError(f"Unknown config: {name}")

    n_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Ternary bookkeeping: element counts of the packed-T, E and E_accum buffers.
    ternary_bytes = sum(buf.numel() for buf_name, buf in model.named_buffers() if 'T_packed' in buf_name)
    e_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E'))
    e_accum_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E_accum'))
    ternary_p_unique = ternary_bytes * 5  # five trits per packed element
    e_count = e_bytes

    mib = 1024 * 1024
    model_mem = sum(p.numel() * p.element_size() for p in model.parameters()) / mib
    # Measured before any step: optimizers that allocate state lazily report
    # only parameter bytes here. The returned figure is re-measured after
    # training, once state tensors exist.
    opt_mem = _optimizer_memory_mb(optimizer)
    buf_mem = sum(b.numel() * b.element_size() for n, b in model.named_buffers()) / mib

    print(f"\n [{name}]", flush=True)
    print(f" Params: {n_params:,} total, {trainable:,} trainable", flush=True)
    print(f" Model mode: {'strict ternary text-only' if strict_model else 'full multimodal'}")
    print(format_audit(audit_model(model), limit=5), flush=True)
    print(f" Ternary: ~{ternary_p_unique/1e6:.1f}M packed trits, {e_count:,} int8 E values, {e_accum_bytes:,} int8 E_accum values")
    print(f" Model weights: {model_mem:.1f}MB | Buffers: {buf_mem:.1f}MB | Optimizer: {opt_mem:.1f}MB")
    print(f" Compiling warmup...", end=" ", flush=True)

    # One untimed forward pass so one-off compilation cost is not measured.
    x_warm, t_warm = get_batch(train_data, device)
    with torch.no_grad():
        with torch.autocast(amp_device, dtype=torch.bfloat16, enabled=use_bf16):
            _ = model(x_warm, targets=t_warm)
    if on_cuda:
        torch.cuda.synchronize()
    print(f"done.", flush=True)
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    loss_history = []
    step_times = []
    tokens_per_step = BATCH * (CTX_LEN - 3)

    for step in range(STEPS):
        lr = get_lr(step)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        x, targets = get_batch(train_data, device)
        t0 = time.perf_counter()

        optimizer.zero_grad()
        with torch.autocast(amp_device, dtype=torch.bfloat16, enabled=use_bf16):
            _, losses, _, _ = model(x, targets=targets)

        losses.total.backward()
        if opt_params:
            torch.nn.utils.clip_grad_norm_(opt_params, 1.0)
        optimizer.step()

        if is_true_ternary:
            update_scales = scale_update_interval > 0 and step % scale_update_interval == 0
            if update_backend == "gpu":
                model._ternary_update_memory(
                    accum_threshold=accum_threshold,
                    update_scales=update_scales,
                    loss_signal=losses.total,
                )
            elif update_backend == "gpu-signcache":
                gpu_signcache_update_memory(
                    model,
                    accum_threshold=accum_threshold,
                    update_scales=update_scales,
                    loss_signal=losses.total,
                )
            elif update_backend == "dense-fallback":
                if update_scales:
                    cpu_update_memory(model, accum_threshold=accum_threshold, loss_signal=losses.total)
                else:
                    model._ternary_update_memory(
                        accum_threshold=accum_threshold,
                        update_scales=False,
                        loss_signal=losses.total,
                    )
            elif update_backend != "none":
                raise ValueError(f"Unknown update backend: {update_backend}")

        if on_cuda:
            # Wait for queued kernels so the timer covers the whole step.
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        loss = losses.total.item()
        loss_history.append(loss)
        step_ms = (t1 - t0) * 1000
        step_times.append(step_ms)

        if step % print_every == 0 or step == STEPS - 1:
            if on_cuda:
                peak = torch.cuda.max_memory_allocated(device) / mib
                allocated = torch.cuda.memory_allocated(device) / mib
                reserved = torch.cuda.memory_reserved(device) / mib
            else:
                peak = allocated = reserved = 0.0
            toks_sec = tokens_per_step / (step_ms / 1000)
            print(
                f" step {step:>4d}/{STEPS} | loss={loss:.4f} | {step_ms:.0f}ms | "
                f"{toks_sec:.0f} tok/s | alloc={allocated:.0f}MB reserved={reserved:.0f}MB peak={peak:.0f}MB",
                flush=True,
            )

    # Re-measure now that lazily created optimizer state (e.g. Adam moments)
    # exists; this is the value reported in the results.
    opt_mem = _optimizer_memory_mb(optimizer)

    final_window = loss_history[-20:]
    final_avg = sum(final_window) / len(final_window) if final_window else float("nan")
    min_loss = min(loss_history) if loss_history else float("nan")
    # Skip warmup steps when averaging, but fall back to every step when the
    # run is no longer than the warmup (avoids dividing by zero).
    timed = step_times[WARMUP:] or step_times
    avg_step_ms = sum(timed) / len(timed) if timed else 0.0
    avg_toks_sec = tokens_per_step / (avg_step_ms / 1000) if avg_step_ms > 0 else 0.0
    peak_vram = torch.cuda.max_memory_allocated(device) / mib if on_cuda else 0.0

    del model, optimizer
    gc.collect()
    if on_cuda:
        torch.cuda.empty_cache()

    return {
        "config": name,
        "n_params": n_params,
        "trainable_params": trainable,
        "model_mem_mb": round(model_mem, 1),
        "optimizer_mem_mb": round(opt_mem, 1),
        "buffer_mem_mb": round(buf_mem, 1),
        "peak_vram_mb": round(peak_vram, 1),
        "final_loss_avg20": round(final_avg, 4),
        "min_loss": round(min_loss, 4),
        "avg_step_ms": round(avg_step_ms, 1),
        "avg_toks_sec": round(avg_toks_sec, 1),
        "loss_history": [round(l, 4) for l in loss_history],
    }
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, run every requested config,
    # print a comparison table and save the results as JSON.
    parser = argparse.ArgumentParser(description="Benchmark full or strict true-ternary MORPH configs.")
    parser.add_argument("--steps", type=int, default=STEPS)
    parser.add_argument("--warmup", type=int, default=WARMUP)
    parser.add_argument("--batch", type=int, default=BATCH)
    parser.add_argument("--ctx", type=int, default=CTX_LEN)
    parser.add_argument("--configs", type=str, default=",".join(CONFIGS),
                        help="Comma-separated configs: Adam_FP32,SignSGD_Old,TrueTernary")
    parser.add_argument("--strict-true-ternary", action=argparse.BooleanOptionalAction, default=True,
                        help="Run TrueTernary as text-only strict ternary with frozen float params.")
    parser.add_argument("--update-backend", choices=["gpu", "gpu-signcache", "dense-fallback", "none"], default="gpu-signcache",
                        help="TrueTernary state update implementation.")
    parser.add_argument("--scale-update-interval", type=int, default=4,
                        help="Update int8 E every N TrueTernary steps. 0 disables E updates.")
    parser.add_argument("--accum-threshold", type=int, default=3,
                        help="T_accum threshold for ternary sign flips.")
    parser.add_argument("--print-every", type=int, default=1)
    parser.add_argument("--reuse-base", action=argparse.BooleanOptionalAction, default=False,
                        help="Create one full base model on CPU and load it into full-model configs.")
    args = parser.parse_args()

    # Rebind the module-level defaults; the helper functions read these globals.
    STEPS = args.steps
    WARMUP = args.warmup
    BATCH = args.batch
    CTX_LEN = args.ctx
    CONFIGS = [item.strip() for item in args.configs.split(",") if item.strip()]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

    print("\nDownloading data...")
    # Module scope: a plain assignment already makes these globals, which is
    # how run_config finds train_data.
    train_data, val_data = download_data()
    print(f" Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes")
    print(f" Batch={BATCH}, CTX={CTX_LEN}, Steps={STEPS}, Warmup={WARMUP}")

    results = []
    t_all_0 = time.perf_counter()

    base_state = None
    if args.reuse_base and any(cfg != "TrueTernary" or not args.strict_true_ternary for cfg in CONFIGS):
        # Build the full model once on CPU and keep a detached copy of its
        # state so each full-model config starts from the same weights.
        print(f"\nCreating base model (CPU state reuse)...", flush=True)
        base_model = build_model(strict_ternary=False)
        base_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()}
        del base_model
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        print(" Done.", flush=True)

    for cfg in CONFIGS:
        r = run_config(
            cfg,
            device,
            base_state=base_state,
            strict_true_ternary=args.strict_true_ternary,
            update_backend=args.update_backend,
            scale_update_interval=args.scale_update_interval,
            accum_threshold=args.accum_threshold,
            print_every=max(1, args.print_every),
        )
        results.append(r)

        # Release the finished config's memory before building the next one.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    t_all = time.perf_counter() - t_all_0

    # Summary table.
    print(f"\n{'='*90}")
    print(f" BENCHMARK RESULTS — {STEPS} steps, {BATCH}x{CTX_LEN} batch")
    print(f"{'='*90}")
    print(f" {'Config':<20} {'Loss(avg20)':<12} {'Loss(min)':<10} {'Step(ms)':<10} {'tok/s':<10} {'PeakMB':<8} {'ModelMB':<8} {'OptMB':<8}")
    print(f" {'-'*86}")
    for r in results:
        print(f" {r['config']:<20} {r['final_loss_avg20']:<12} {r['min_loss']:<10} {r['avg_step_ms']:<10} {r['avg_toks_sec']:<10} {r['peak_vram_mb']:<8} {r['model_mem_mb']:<8} {r['optimizer_mem_mb']:<8}")

    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or NaN when the denominator is 0."""
        return numerator / denominator if denominator else float("nan")

    # Ratios against the Adam_FP32 baseline, when it was part of the run.
    baseline = next((r for r in results if r['config'] == 'Adam_FP32'), None)
    if baseline:
        print(f"\n {'─'*86}")
        print(f" {'Relative to Adam_FP32':<50}")
        print(f" {'─'*86}")
        for r in results:
            if r['config'] == 'Adam_FP32':
                continue
            # A zero baseline (e.g. peak VRAM on a CPU-only run) yields NaN
            # instead of raising ZeroDivisionError.
            loss_ratio = _safe_ratio(r['final_loss_avg20'], baseline['final_loss_avg20'])
            # NOTE: this is baseline throughput over config throughput, so a
            # value above 1 means the config is slower than the baseline.
            speed_ratio = baseline['avg_toks_sec'] / r['avg_toks_sec'] if r['avg_toks_sec'] > 0 else float('inf')
            vram_ratio = _safe_ratio(r['peak_vram_mb'], baseline['peak_vram_mb'])
            print(f" {r['config']:<20} loss={loss_ratio:.2f}x speed={speed_ratio:.2f}x vram={vram_ratio:.2f}x")

    # Persist the results.
    out = {
        "config": "True Ternary vs Baselines",
        "steps": STEPS,
        "batch": BATCH,
        "context": CTX_LEN,
        "total_time_s": round(t_all, 1),
        "results": results,
    }
    path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "benchmark", "benchmark_results.json")
    # Create the output directory first so a finished benchmark is not lost
    # to FileNotFoundError.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\n Results saved to {path}")
    print(f" Total benchmark time: {t_all:.0f}s ({t_all/60:.1f}min)")
|
|