""" True Ternary Benchmark: Compare training methods on ARBModel. Configs: 1. Adam_FP32 — standard FP32 Adam (full model, float params) 2. SignSGD_Old — SignSGD optimizer (full model, float params) 3. TrueTernary — pure ternary training (0 float params, T flips + E_accum) Metrics: loss curve, step time, peak VRAM, model/optimizer memory, convergence After REFACTOR6 (architecture ternarization), the internal model has 0 trainable float params. Adam_FP32 and SignSGD_Old use the pre-ternarization float weights. TrueTernary uses the post-REFACTOR6 strict ternary-only path. """ import os, sys, time, json, math, gc, argparse import torch import torch.nn as nn import torch.nn.functional as F sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from arbitor.main import ARBModel, VOCAB, CTX, LossComponents from arbitor.kernel.ternary_scale import TScaleType from arbitor.kernel.ternary_scale import _triton_ternary_grad_sign, _triton_update_e, _triton_ternary_step from arbitor.optim.sign_sgd import SignSGD from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters STEPS = 50 WARMUP = 10 BATCH = 8 CTX_LEN = 66 SEED = 42 DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" DATA_PATH = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt") CONFIGS = [ "Adam_FP32", "SignSGD_Old", "TrueTernary", ] class NoTrainableParametersOptimizer: def __init__(self): self.param_groups = [] self.state = {} def zero_grad(self, *args, **kwargs): return None def step(self, *args, **kwargs): return None def download_data(): if not os.path.exists(DATA_PATH): import urllib.request print(" Downloading tinyshakespeare...") urllib.request.urlretrieve(DATA_URL, DATA_PATH) with open(DATA_PATH, "r", encoding="utf-8") as f: text = f.read() byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long) n = int(0.9 * len(byte_data)) return byte_data[:n], byte_data[n:] def get_batch(data, device): ix = torch.randint(0, len(data) - CTX_LEN - 1, (BATCH,)) x = torch.stack([data[i: i + CTX_LEN] for i in ix]) targets = x[:, 3:] return x.to(device, non_blocking=True), targets.to(device, non_blocking=True) def get_lr(step, max_lr=1e-4, min_lr=1e-6): if step < WARMUP: return max_lr * (step + 1) / WARMUP progress = (step - WARMUP) / max(1, STEPS - WARMUP) return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) def cpu_update_memory(model, accum_threshold=3, loss_signal=None): """CPU-based update that avoids the Triton compilation bug (14s/step).""" import torch.nn.functional as F from arbitor.converters.convert_to_ternary8 import pack_ternary t_step = 1 if loss_signal is not None: loss_val = float(loss_signal.detach().clamp(min=0, max=32).item()) t_step = max(1, min(4, int(loss_val // 2) + 1)) for module in model.modules(): if not hasattr(module, 'update_E') and not hasattr(module, 'ternary_step'): continue has_grad = hasattr(module, '_hook_grad_T_sign') has_direct = hasattr(module, '_hook_grad_2d') and hasattr(module, '_hook_x_2d') if not has_grad and not has_direct: continue device = module.T_accum.device N, K = tuple(module._T_shape.tolist()) if has_direct: grads = module._hook_grad_2d xs = module._hook_x_2d grad_W = torch.matmul(grads.float().t(), xs.float()) grad_sign = grad_W.sign().to(torch.int8) else: grad_sign = module._hook_grad_T_sign.to(device=device) # --- update_E (CPU fixed-point residual path) --- if hasattr(module, 'update_E'): T_source = module._get_T() if not hasattr(module, '_hook_T') else module._hook_T T = T_source.to(device=device) grad_T = grad_sign.float() * T.float() gpr = (K + module.group_size - 1) // module.group_size total_in = gpr * module.group_size padded = F.pad(grad_T, (0, total_in - K)) grouped = padded.view(N, gpr, module.group_size) group_score = grouped.sum(dim=2) delta = -group_score.sign().to(torch.int8).flatten() if not hasattr(module, "E_accum"): module.register_buffer("E_accum", torch.zeros_like(module.E, dtype=torch.int8)) e_accum_threshold = int(getattr(module, "_e_accum_threshold", 4)) new_accum = torch.clamp(module.E_accum + delta, -128, 127).to(torch.int8) step_up = new_accum >= e_accum_threshold step_down = new_accum <= -e_accum_threshold e_step = torch.where(step_up, torch.ones_like(new_accum), torch.where(step_down, -torch.ones_like(new_accum), torch.zeros_like(new_accum))) module.E = torch.clamp(module.E.to(torch.int16) + e_step.to(torch.int16), -128, 127).to(torch.int8) module.E_accum = (new_accum.to(torch.int16) - e_step.to(torch.int16) * e_accum_threshold).to(torch.int8) # --- ternary_step (CPU T flip) --- if hasattr(module, 'ternary_step'): module.T_accum = torch.clamp(module.T_accum + grad_sign.to(device) * t_step, -128, 127).to(torch.int8) fu = module.T_accum > accum_threshold fd = module.T_accum < -accum_threshold if fu.any() or fd.any(): T = module._get_T().to(device) T[fu] = torch.tensor(1, dtype=T.dtype, device=device) T[fd] = torch.tensor(-1, dtype=T.dtype, device=device) torch.cuda.synchronize() module.T_packed = pack_ternary(T.cpu())[0].to(device=device) module.T_accum = torch.where(fu | fd, torch.zeros_like(module.T_accum), module.T_accum) # Clean up hooks if has_direct: del module._hook_grad_2d, module._hook_x_2d else: del module._hook_grad_T_sign def gpu_signcache_update_memory(model, accum_threshold=3, update_scales=True, loss_signal=None): """GPU update that computes one temporary int8 grad_sign per module, then frees it. This avoids the very slow per-packed-byte direct reduction path for benchmark shapes with large M = batch * sequence. It still keeps persistent model state ternary-first: packed T, int8 E, int8 accumulators, no FP master weights. """ t_step = 1 if loss_signal is not None: loss_val = float(loss_signal.detach().clamp(min=0, max=32).item()) t_step = max(1, min(4, int(loss_val // 2) + 1)) for module in model.modules(): has_grad = hasattr(module, '_hook_grad_T_sign') has_direct = hasattr(module, '_hook_grad_2d') and hasattr(module, '_hook_x_2d') if not has_grad and not has_direct: continue if has_direct: n_out, k_in = tuple(module._T_shape.tolist()) grad_sign = _triton_ternary_grad_sign(module._hook_grad_2d, module._hook_x_2d, n_out, k_in) module._hook_grad_T_sign = grad_sign del module._hook_grad_2d, module._hook_x_2d if update_scales and hasattr(module, 'update_E'): if getattr(module, "E", None) is not None and module.E.is_cuda and hasattr(module, "_hook_grad_T_sign"): n_out, k_in = tuple(module._T_shape.tolist()) if not hasattr(module, "E_accum"): module.register_buffer("E_accum", torch.zeros_like(module.E, dtype=torch.int8)) _triton_update_e( module.T_packed.contiguous(), module._hook_grad_T_sign.contiguous(), module.E, module.E_accum, n_out, k_in, module.group_size, int(getattr(module, "_e_accum_threshold", 4)), ) else: module.update_E(loss_signal=loss_signal) if hasattr(module, 'ternary_step'): if getattr(module, "T_packed", None) is not None and module.T_packed.is_cuda and hasattr(module, "_hook_grad_T_sign"): total = int(module._T_shape[0].item() * module._T_shape[1].item()) _triton_ternary_step( module.T_packed, module._hook_grad_T_sign.contiguous(), module.T_accum, total, accum_threshold, t_step, ) del module._hook_grad_T_sign else: module.ternary_step(accum_threshold=accum_threshold) def build_model(strict_ternary): return ARBModel( tscale_type=TScaleType.T32, enable_image=not strict_ternary, enable_audio=not strict_ternary, enable_vq=not strict_ternary, enable_graph=not strict_ternary, enable_memory_modules=not strict_ternary, enable_moe=True, ) def run_config( name, device, base_state=None, strict_true_ternary=True, update_backend="gpu", scale_update_interval=4, accum_threshold=3, print_every=1, ): torch.manual_seed(SEED) torch.cuda.reset_peak_memory_stats(device) torch.cuda.empty_cache() gc.collect() is_true_ternary = "TrueTernary" in name is_signsgd = "SignSGD" in name or "TrueTernary" in name use_bf16 = "BF16" in name # TrueTernary always uses strict mode (0 float params, no encoders) strict_model = "TrueTernary" in name if strict_model: model = build_model(strict_ternary=True).to(device) freeze_float_parameters(model) elif base_state is not None: model = build_model(strict_ternary=False).to(device) model.load_state_dict(base_state, strict=False) # Re-freeze ViT/audio params that load_state_dict may have unfrozen for param_name, p in model.named_parameters(): bn = param_name.split('.')[0] if bn in ('vit', 'image_sequencer', 'audio_sequencer'): p.requires_grad = False else: model = build_model(strict_ternary=strict_model).to(device) if strict_model: freeze_float_parameters(model) opt_params = trainable_parameters(model) if use_bf16: import bitsandbytes as bnb print(f" Creating Adam8bit optimizer...", flush=True) optimizer = bnb.optim.Adam8bit(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer() elif name == "Adam_FP32": print(f" Creating Adam FP32 optimizer...", flush=True) optimizer = torch.optim.Adam(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer() elif is_signsgd: print(f" Creating SignSGD optimizer...", flush=True) optimizer = SignSGD(opt_params, lr=0.001, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer() else: raise ValueError(f"Unknown config: {name}") n_params = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) # Compute persistent ternary memory ternary_bytes = 0 for buf_name, buf in model.named_buffers(): if 'T_packed' in buf_name: ternary_bytes += buf.numel() e_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E')) e_accum_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E_accum')) ternary_p_unique = ternary_bytes * 5 # 5 trits per byte e_count = e_bytes # int8 E # Memory accounting model_mem = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024) opt_mem = 0 for g in optimizer.param_groups: for p in g["params"]: opt_mem += p.numel() * p.element_size() state = optimizer.state.get(p, {}) for v in state.values(): if isinstance(v, torch.Tensor): opt_mem += v.numel() * v.element_size() opt_mem /= 1024 * 1024 buf_mem = sum(b.numel() * b.element_size() for n, b in model.named_buffers()) / (1024 * 1024) print(f"\n [{name}]", flush=True) print(f" Params: {n_params:,} total, {trainable:,} trainable", flush=True) print(f" Model mode: {'strict ternary text-only' if strict_model else 'full multimodal'}") print(format_audit(audit_model(model), limit=5), flush=True) print(f" Ternary: ~{ternary_p_unique/1e6:.1f}M packed trits, {e_count:,} int8 E values, {e_accum_bytes:,} int8 E_accum values") print(f" Model weights: {model_mem:.1f}MB | Buffers: {buf_mem:.1f}MB | Optimizer: {opt_mem:.1f}MB") print(f" Compiling warmup...", end=" ", flush=True) # Warmup forward pass to trigger JIT compilation x_warm, t_warm = get_batch(train_data, device) with torch.no_grad(): with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16): _ = model(x_warm, targets=t_warm) torch.cuda.synchronize() print(f"done.", flush=True) if device == "cuda": torch.cuda.reset_peak_memory_stats(device) loss_history = [] step_times = [] for step in range(STEPS): lr = get_lr(step) for pg in optimizer.param_groups: pg["lr"] = lr x, targets = get_batch(train_data, device) t0 = time.perf_counter() optimizer.zero_grad() with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16): logits, losses, _, _ = model(x, targets=targets) losses.total.backward() if opt_params: torch.nn.utils.clip_grad_norm_(opt_params, 1.0) optimizer.step() if is_true_ternary: update_scales = scale_update_interval > 0 and step % scale_update_interval == 0 if update_backend == "gpu": model._ternary_update_memory( accum_threshold=accum_threshold, update_scales=update_scales, loss_signal=losses.total, ) elif update_backend == "gpu-signcache": gpu_signcache_update_memory( model, accum_threshold=accum_threshold, update_scales=update_scales, loss_signal=losses.total, ) elif update_backend == "dense-fallback": if update_scales: cpu_update_memory(model, accum_threshold=accum_threshold, loss_signal=losses.total) else: model._ternary_update_memory( accum_threshold=accum_threshold, update_scales=False, loss_signal=losses.total, ) elif update_backend != "none": raise ValueError(f"Unknown update backend: {update_backend}") if device == "cuda": torch.cuda.synchronize() t1 = time.perf_counter() loss = losses.total.item() loss_history.append(loss) step_ms = (t1 - t0) * 1000 step_times.append(step_ms) if step % print_every == 0 or step == STEPS - 1: peak = torch.cuda.max_memory_allocated(device) / (1024 * 1024) allocated = torch.cuda.memory_allocated(device) / (1024 * 1024) reserved = torch.cuda.memory_reserved(device) / (1024 * 1024) toks_sec = BATCH * (CTX_LEN - 3) / (step_ms / 1000) print( f" step {step:>4d}/{STEPS} | loss={loss:.4f} | {step_ms:.0f}ms | " f"{toks_sec:.0f} tok/s | alloc={allocated:.0f}MB reserved={reserved:.0f}MB peak={peak:.0f}MB", flush=True, ) final_window = loss_history[-min(20, len(loss_history)):] final_avg = sum(final_window) / len(final_window) min_loss = min(loss_history) avg_step_ms = sum(step_times[WARMUP:]) / len(step_times[WARMUP:]) avg_toks_sec = BATCH * (CTX_LEN - 3) / (avg_step_ms / 1000) peak_vram = torch.cuda.max_memory_allocated(device) / (1024 * 1024) del model, optimizer gc.collect() torch.cuda.empty_cache() return { "config": name, "n_params": n_params, "trainable_params": trainable, "model_mem_mb": round(model_mem, 1), "optimizer_mem_mb": round(opt_mem, 1), "buffer_mem_mb": round(buf_mem, 1), "peak_vram_mb": round(peak_vram, 1), "final_loss_avg20": round(final_avg, 4), "min_loss": round(min_loss, 4), "avg_step_ms": round(avg_step_ms, 1), "avg_toks_sec": round(avg_toks_sec, 1), "loss_history": [round(l, 4) for l in loss_history], } if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark full or strict true-ternary MORPH configs.") parser.add_argument("--steps", type=int, default=STEPS) parser.add_argument("--warmup", type=int, default=WARMUP) parser.add_argument("--batch", type=int, default=BATCH) parser.add_argument("--ctx", type=int, default=CTX_LEN) parser.add_argument("--configs", type=str, default=",".join(CONFIGS), help="Comma-separated configs: Adam_FP32,SignSGD_Old,TrueTernary") parser.add_argument("--strict-true-ternary", action=argparse.BooleanOptionalAction, default=True, help="Run TrueTernary as text-only strict ternary with frozen float params.") parser.add_argument("--update-backend", choices=["gpu", "gpu-signcache", "dense-fallback", "none"], default="gpu-signcache", help="TrueTernary state update implementation.") parser.add_argument("--scale-update-interval", type=int, default=4, help="Update int8 E every N TrueTernary steps. 0 disables E updates.") parser.add_argument("--accum-threshold", type=int, default=3, help="T_accum threshold for ternary sign flips.") parser.add_argument("--print-every", type=int, default=1) parser.add_argument("--reuse-base", action=argparse.BooleanOptionalAction, default=False, help="Create one full base model on CPU and load it into full-model configs.") args = parser.parse_args() STEPS = args.steps WARMUP = args.warmup BATCH = args.batch CTX_LEN = args.ctx CONFIGS = [item.strip() for item in args.configs.split(",") if item.strip()] device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") if device == "cuda": print(f" GPU: {torch.cuda.get_device_name(0)}") print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB") print("\nDownloading data...") global train_data, val_data train_data, val_data = download_data() print(f" Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes") print(f" Batch={BATCH}, CTX={CTX_LEN}, Steps={STEPS}, Warmup={WARMUP}") results = [] t_all_0 = time.perf_counter() base_state = None if args.reuse_base and any(cfg != "TrueTernary" or not args.strict_true_ternary for cfg in CONFIGS): # Keep reusable initialization on CPU so it does not inflate per-config VRAM. print(f"\nCreating base model (CPU state reuse)...", flush=True) base_model = build_model(strict_ternary=False) base_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} del base_model gc.collect() if device == "cuda": torch.cuda.empty_cache() print(" Done.", flush=True) for cfg in CONFIGS: r = run_config( cfg, device, base_state=base_state, strict_true_ternary=args.strict_true_ternary, update_backend=args.update_backend, scale_update_interval=args.scale_update_interval, accum_threshold=args.accum_threshold, print_every=max(1, args.print_every), ) results.append(r) gc.collect() torch.cuda.empty_cache() t_all = time.perf_counter() - t_all_0 # Summary table print(f"\n{'='*90}") print(f" BENCHMARK RESULTS — {STEPS} steps, {BATCH}x{CTX_LEN} batch") print(f"{'='*90}") print(f" {'Config':<20} {'Loss(avg20)':<12} {'Loss(min)':<10} {'Step(ms)':<10} {'tok/s':<10} {'PeakMB':<8} {'ModelMB':<8} {'OptMB':<8}") print(f" {'-'*86}") for r in results: print(f" {r['config']:<20} {r['final_loss_avg20']:<12} {r['min_loss']:<10} {r['avg_step_ms']:<10} {r['avg_toks_sec']:<10} {r['peak_vram_mb']:<8} {r['model_mem_mb']:<8} {r['optimizer_mem_mb']:<8}") # Compare to baseline baseline = None for r in results: if r['config'] == 'Adam_FP32': baseline = r break if baseline: print(f"\n {'─'*86}") print(f" {'Relative to Adam_FP32':<50}") print(f" {'─'*86}") for r in results: if r['config'] == 'Adam_FP32': continue loss_ratio = r['final_loss_avg20'] / baseline['final_loss_avg20'] speed_ratio = baseline['avg_toks_sec'] / r['avg_toks_sec'] if r['avg_toks_sec'] > 0 else float('inf') vram_ratio = r['peak_vram_mb'] / baseline['peak_vram_mb'] print(f" {r['config']:<20} loss={loss_ratio:.2f}x speed={speed_ratio:.2f}x vram={vram_ratio:.2f}x") # Save results out = { "config": "True Ternary vs Baselines", "steps": STEPS, "batch": BATCH, "context": CTX_LEN, "total_time_s": round(t_all, 1), "results": results, } path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "benchmark", "benchmark_results.json") with open(path, "w") as f: json.dump(out, f, indent=2) print(f"\n Results saved to {path}") print(f" Total benchmark time: {t_all:.0f}s ({t_all/60:.1f}min)")