ARBS / testing /benchmarks /benchmark_true_ternary.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""
True Ternary Benchmark: Compare training methods on ARBModel.
Configs:
1. Adam_FP32 — standard FP32 Adam (full model, float params)
2. SignSGD_Old — SignSGD optimizer (full model, float params)
3. TrueTernary — pure ternary training (0 float params, T flips + E_accum)
Metrics: loss curve, step time, peak VRAM, model/optimizer memory, convergence
After REFACTOR6 (architecture ternarization), the internal model has 0 trainable
float params. Adam_FP32 and SignSGD_Old use the pre-ternarization float weights.
TrueTernary uses the post-REFACTOR6 strict ternary-only path.
"""
import os, sys, time, json, math, gc, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.main import ARBModel, VOCAB, CTX, LossComponents
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_scale import _triton_ternary_grad_sign, _triton_update_e, _triton_ternary_step
from arbitor.optim.sign_sgd import SignSGD
from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
# Default benchmark settings. The __main__ block rebinds STEPS, WARMUP, BATCH,
# CTX_LEN and CONFIGS from the command-line arguments before any config runs.
STEPS = 50  # number of training steps executed per config
WARMUP = 10  # linear LR warmup length; these leading steps are also excluded from the average step time
BATCH = 8  # sequences sampled per batch
CTX_LEN = 66  # tokens (bytes) per sampled sequence
SEED = 42  # torch RNG seed re-applied at the start of every config run
# Source of the training text and the local path it is cached at (next to this script).
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
DATA_PATH = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt")
# Config names run by default; also the default value of --configs.
CONFIGS = [
"Adam_FP32",
"SignSGD_Old",
"TrueTernary",
]
class NoTrainableParametersOptimizer:
    """Inert optimizer stand-in for models that expose no trainable parameters.

    It carries the ``param_groups`` and ``state`` attributes and the
    ``zero_grad`` / ``step`` methods that the benchmark loop touches, so the
    loop can run unchanged when there is nothing to optimize.
    """

    def __init__(self):
        # Both containers stay empty: there are no parameters to track.
        self.state = {}
        self.param_groups = []

    def step(self, *args, **kwargs):
        """Accept any arguments and do nothing."""

    def zero_grad(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
def download_data():
    """Load the tinyshakespeare corpus as byte tokens, downloading it first if absent.

    The text is cached at DATA_PATH. It is read as UTF-8, re-encoded to bytes
    and turned into a 1-D ``torch.long`` tensor with one element per byte.

    Returns:
        A ``(train, val)`` pair: the first 90% of the tokens and the remainder.
    """
    if not os.path.exists(DATA_PATH):
        from urllib.request import urlretrieve
        print(" Downloading tinyshakespeare...")
        urlretrieve(DATA_URL, DATA_PATH)

    with open(DATA_PATH, "r", encoding="utf-8") as handle:
        raw_bytes = handle.read().encode("utf-8")

    tokens = torch.tensor(list(raw_bytes), dtype=torch.long)
    split_at = int(0.9 * len(tokens))
    train_part = tokens[:split_at]
    val_part = tokens[split_at:]
    return train_part, val_part
def get_batch(data, device):
    """Sample BATCH random windows of CTX_LEN tokens from ``data``.

    Returns:
        ``(x, targets)`` moved to ``device``; ``targets`` is ``x`` with the
        first three positions of every sequence dropped.
    """
    upper_start = len(data) - CTX_LEN - 1
    starts = torch.randint(0, upper_start, (BATCH,))
    windows = [data[s: s + CTX_LEN] for s in starts]
    x = torch.stack(windows)
    # Slice the targets from the host-side batch before either tensor is moved.
    targets = x[:, 3:]
    x_dev = x.to(device, non_blocking=True)
    targets_dev = targets.to(device, non_blocking=True)
    return x_dev, targets_dev
def get_lr(step, max_lr=1e-4, min_lr=1e-6):
    """Return the learning rate for ``step``.

    The rate ramps linearly up to ``max_lr`` over the first WARMUP steps, then
    follows a half-cosine from ``max_lr`` down to ``min_lr`` over the remaining
    ``STEPS - WARMUP`` steps.
    """
    if step < WARMUP:
        warm_fraction = (step + 1) / WARMUP
        return max_lr * (step + 1) / WARMUP if warm_fraction else 0.0
    decay_span = max(1, STEPS - WARMUP)  # never divide by zero when STEPS <= WARMUP
    progress = (step - WARMUP) / decay_span
    cosine_term = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine_term
def cpu_update_memory(model, accum_threshold=3, loss_signal=None):
    """Apply the ternary E / T updates with plain torch ops instead of the Triton kernels.

    Intended as a fallback that avoids the Triton compilation bug (14s/step).

    For every submodule that exposes ``update_E`` and/or ``ternary_step`` and
    carries backward-hook data (either a cached ``_hook_grad_T_sign`` or the
    raw ``_hook_grad_2d`` / ``_hook_x_2d`` pair), this:

    * derives an int8 gradient sign per weight,
    * nudges the int8 ``E_accum`` per group and steps ``E`` by +-1 whenever the
      accumulator reaches the module's ``_e_accum_threshold`` (default 4),
    * adds ``grad_sign * t_step`` to ``T_accum`` and, where it exceeds
      ``accum_threshold`` in magnitude, overwrites the ternary weight with
      +-1, repacks ``T_packed`` and resets the accumulator entry,
    * deletes the hook attributes it consumed.

    Args:
        model: module tree to update in place.
        accum_threshold: magnitude ``T_accum`` must exceed to flip a weight.
        loss_signal: optional scalar loss tensor; it scales the T accumulation
            step to ``loss // 2 + 1`` clamped to 1..4.

    Returns:
        None. All changes are made in place on the modules.
    """
    t_step = 1
    if loss_signal is not None:
        loss_val = float(loss_signal.detach().clamp(min=0, max=32).item())
        t_step = max(1, min(4, int(loss_val // 2) + 1))

    for module in model.modules():
        if not hasattr(module, 'update_E') and not hasattr(module, 'ternary_step'):
            continue
        has_grad = hasattr(module, '_hook_grad_T_sign')
        has_direct = hasattr(module, '_hook_grad_2d') and hasattr(module, '_hook_x_2d')
        if not has_grad and not has_direct:
            continue

        device = module.T_accum.device
        N, K = tuple(module._T_shape.tolist())

        if has_direct:
            # Rebuild the weight-gradient sign from the saved output grads and inputs.
            grads = module._hook_grad_2d
            xs = module._hook_x_2d
            grad_W = torch.matmul(grads.float().t(), xs.float())
            grad_sign = grad_W.sign().to(torch.int8)
        else:
            grad_sign = module._hook_grad_T_sign.to(device=device)

        # --- update_E (CPU fixed-point residual path) ---
        if hasattr(module, 'update_E'):
            T_source = module._get_T() if not hasattr(module, '_hook_T') else module._hook_T
            T = T_source.to(device=device)
            grad_T = grad_sign.float() * T.float()
            gpr = (K + module.group_size - 1) // module.group_size
            total_in = gpr * module.group_size
            # Pad the input dimension so it splits evenly into groups.
            padded = F.pad(grad_T, (0, total_in - K))
            grouped = padded.view(N, gpr, module.group_size)
            group_score = grouped.sum(dim=2)
            delta = -group_score.sign().to(torch.int8).flatten()
            if not hasattr(module, "E_accum"):
                module.register_buffer("E_accum", torch.zeros_like(module.E, dtype=torch.int8))
            e_accum_threshold = int(getattr(module, "_e_accum_threshold", 4))
            # Widen before adding: an int8 + int8 sum wraps around, which would
            # make the clamp below a no-op instead of a saturation.
            new_accum = torch.clamp(
                module.E_accum.to(torch.int16) + delta.to(torch.int16), -128, 127
            ).to(torch.int8)
            step_up = new_accum >= e_accum_threshold
            step_down = new_accum <= -e_accum_threshold
            e_step = torch.where(
                step_up,
                torch.ones_like(new_accum),
                torch.where(step_down, -torch.ones_like(new_accum), torch.zeros_like(new_accum)),
            )
            module.E = torch.clamp(module.E.to(torch.int16) + e_step.to(torch.int16), -128, 127).to(torch.int8)
            # Entries that stepped give back one threshold's worth of accumulation.
            module.E_accum = (new_accum.to(torch.int16) - e_step.to(torch.int16) * e_accum_threshold).to(torch.int8)

        # --- ternary_step (CPU T flip) ---
        if hasattr(module, 'ternary_step'):
            # Same widening as for E_accum so the accumulator saturates rather than wraps.
            module.T_accum = torch.clamp(
                module.T_accum.to(torch.int16)
                + grad_sign.to(device=device, dtype=torch.int16) * t_step,
                -128,
                127,
            ).to(torch.int8)
            fu = module.T_accum > accum_threshold
            fd = module.T_accum < -accum_threshold
            if fu.any() or fd.any():
                # Imported here because repacking is the only step that needs it.
                from arbitor.converters.convert_to_ternary8 import pack_ternary
                T = module._get_T().to(device)
                T[fu] = torch.tensor(1, dtype=T.dtype, device=device)
                T[fd] = torch.tensor(-1, dtype=T.dtype, device=device)
                # Only synchronize when the state lives on a CUDA device;
                # calling this on a CUDA-less host raises.
                if device.type == "cuda":
                    torch.cuda.synchronize(device)
                module.T_packed = pack_ternary(T.cpu())[0].to(device=device)
                module.T_accum = torch.where(fu | fd, torch.zeros_like(module.T_accum), module.T_accum)

        # Clean up hooks
        if has_direct:
            del module._hook_grad_2d, module._hook_x_2d
        else:
            del module._hook_grad_T_sign
def gpu_signcache_update_memory(model, accum_threshold=3, update_scales=True, loss_signal=None):
    """GPU update that builds one temporary int8 grad_sign per module and then frees it.

    This sidesteps the very slow per-packed-byte direct reduction path for
    benchmark shapes with large M = batch * sequence, while persistent model
    state stays ternary-first: packed T, int8 E, int8 accumulators, no FP
    master weights.

    Args:
        model: module tree to update in place.
        accum_threshold: threshold forwarded to the T flip step.
        update_scales: when False, the E update is skipped for every module.
        loss_signal: optional scalar loss tensor; it sets the T accumulation
            step to ``loss // 2 + 1`` clamped to 1..4 and is forwarded to the
            module's own ``update_E`` when the fused kernel is not used.
    """
    if loss_signal is None:
        t_step = 1
    else:
        clamped_loss = float(loss_signal.detach().clamp(min=0, max=32).item())
        t_step = max(1, min(4, int(clamped_loss // 2) + 1))

    for layer in model.modules():
        sign_cached = hasattr(layer, '_hook_grad_T_sign')
        raw_hooked = hasattr(layer, '_hook_grad_2d') and hasattr(layer, '_hook_x_2d')
        if not (sign_cached or raw_hooked):
            continue

        if raw_hooked:
            # Collapse the raw hook tensors into a single sign tensor, then drop them.
            rows, cols = tuple(layer._T_shape.tolist())
            sign = _triton_ternary_grad_sign(layer._hook_grad_2d, layer._hook_x_2d, rows, cols)
            layer._hook_grad_T_sign = sign
            del layer._hook_grad_2d, layer._hook_x_2d

        if update_scales and hasattr(layer, 'update_E'):
            scale_buf = getattr(layer, "E", None)
            fused_e = (
                scale_buf is not None
                and scale_buf.is_cuda
                and hasattr(layer, "_hook_grad_T_sign")
            )
            if not fused_e:
                layer.update_E(loss_signal=loss_signal)
            else:
                rows, cols = tuple(layer._T_shape.tolist())
                if not hasattr(layer, "E_accum"):
                    layer.register_buffer("E_accum", torch.zeros_like(layer.E, dtype=torch.int8))
                _triton_update_e(
                    layer.T_packed.contiguous(),
                    layer._hook_grad_T_sign.contiguous(),
                    layer.E,
                    layer.E_accum,
                    rows,
                    cols,
                    layer.group_size,
                    int(getattr(layer, "_e_accum_threshold", 4)),
                )

        if not hasattr(layer, 'ternary_step'):
            continue

        packed = getattr(layer, "T_packed", None)
        fused_t = (
            packed is not None
            and packed.is_cuda
            and hasattr(layer, "_hook_grad_T_sign")
        )
        if not fused_t:
            layer.ternary_step(accum_threshold=accum_threshold)
            continue

        shape = layer._T_shape
        total = int(shape[0].item() * shape[1].item())
        _triton_ternary_step(
            layer.T_packed,
            layer._hook_grad_T_sign.contiguous(),
            layer.T_accum,
            total,
            accum_threshold,
            t_step,
        )
        # The cached sign is temporary; release it once the flip step has consumed it.
        del layer._hook_grad_T_sign
def build_model(strict_ternary):
    """Construct an ARBModel with T32 scales for the benchmark.

    MoE is always enabled. The image, audio, VQ, graph and memory modules are
    enabled only when ``strict_ternary`` is falsy.
    """
    optional_modules = not strict_ternary
    model_kwargs = {
        "tscale_type": TScaleType.T32,
        "enable_image": optional_modules,
        "enable_audio": optional_modules,
        "enable_vq": optional_modules,
        "enable_graph": optional_modules,
        "enable_memory_modules": optional_modules,
        "enable_moe": True,
    }
    return ARBModel(**model_kwargs)
def _device_is_cuda(device):
    """Return True when ``device`` (a string or ``torch.device``) is a CUDA device."""
    return torch.device(device).type == "cuda"


def run_config(
    name,
    device,
    base_state=None,
    strict_true_ternary=True,
    update_backend="gpu",
    scale_update_interval=4,
    accum_threshold=3,
    print_every=1,
):
    """Train one benchmark config for STEPS steps and return its metrics.

    The config is selected by substring of ``name``: ``"TrueTernary"`` builds
    the strict text-only model with float parameters frozen and applies the
    ternary state update after each step; ``"BF16"`` enables bf16 autocast with
    a bitsandbytes Adam8bit optimizer; ``"Adam_FP32"`` uses ``torch.optim.Adam``;
    ``"SignSGD"`` / ``"TrueTernary"`` use SignSGD. When the model has no
    trainable parameters a no-op optimizer is substituted.

    Args:
        name: config name (see above).
        device: device the model and batches are placed on.
        base_state: optional state dict loaded (non-strictly) into non-strict
            models so configs share an initialization.
        strict_true_ternary: accepted for interface compatibility; it is not
            consulted, because TrueTernary always uses strict mode.
        update_backend: ``"gpu"``, ``"gpu-signcache"``, ``"dense-fallback"`` or
            ``"none"``; selects the TrueTernary state update implementation.
        scale_update_interval: update E every N steps; 0 disables E updates.
        accum_threshold: T_accum threshold forwarded to the update backend.
        print_every: print a progress line every N steps (and on the last step).

    Returns:
        dict of parameter counts, memory figures (MB), loss statistics, timing
        and the per-step loss history.

    Raises:
        ValueError: if STEPS is below 1, ``name`` matches no known optimizer
            choice, or ``update_backend`` is unknown.
    """
    if STEPS < 1:
        raise ValueError(f"STEPS must be >= 1, got {STEPS}")

    # CUDA-only bookkeeping (memory stats, synchronize) is skipped on other
    # devices, where those calls would raise.
    on_cuda = _device_is_cuda(device)

    torch.manual_seed(SEED)
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    gc.collect()

    is_true_ternary = "TrueTernary" in name
    is_signsgd = "SignSGD" in name or is_true_ternary
    use_bf16 = "BF16" in name
    # TrueTernary always uses strict mode (0 float params, no encoders)
    strict_model = is_true_ternary

    if strict_model:
        model = build_model(strict_ternary=True).to(device)
        freeze_float_parameters(model)
    else:
        model = build_model(strict_ternary=False).to(device)
        if base_state is not None:
            model.load_state_dict(base_state, strict=False)
            # Re-freeze ViT/audio params that load_state_dict may have unfrozen
            for param_name, p in model.named_parameters():
                bn = param_name.split('.')[0]
                if bn in ('vit', 'image_sequencer', 'audio_sequencer'):
                    p.requires_grad = False

    opt_params = trainable_parameters(model)
    if use_bf16:
        import bitsandbytes as bnb
        print(" Creating Adam8bit optimizer...", flush=True)
        optimizer = bnb.optim.Adam8bit(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    elif name == "Adam_FP32":
        print(" Creating Adam FP32 optimizer...", flush=True)
        optimizer = torch.optim.Adam(opt_params, lr=1e-4, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    elif is_signsgd:
        print(" Creating SignSGD optimizer...", flush=True)
        optimizer = SignSGD(opt_params, lr=0.001, weight_decay=0.01) if opt_params else NoTrainableParametersOptimizer()
    else:
        raise ValueError(f"Unknown config: {name}")

    n_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Compute persistent ternary memory
    ternary_bytes = 0
    for buf_name, buf in model.named_buffers():
        if 'T_packed' in buf_name:
            ternary_bytes += buf.numel()
    e_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E'))
    e_accum_bytes = sum(b.numel() for n, b in model.named_buffers() if n.endswith('.E_accum'))
    ternary_p_unique = ternary_bytes * 5  # 5 trits per byte
    e_count = e_bytes  # int8 E

    # Memory accounting
    mib = 1024 * 1024
    model_mem = sum(p.numel() * p.element_size() for p in model.parameters()) / mib
    opt_mem = 0
    for g in optimizer.param_groups:
        for p in g["params"]:
            opt_mem += p.numel() * p.element_size()
            state = optimizer.state.get(p, {})
            for v in state.values():
                if isinstance(v, torch.Tensor):
                    opt_mem += v.numel() * v.element_size()
    opt_mem /= mib
    buf_mem = sum(b.numel() * b.element_size() for n, b in model.named_buffers()) / mib

    print(f"\n [{name}]", flush=True)
    print(f" Params: {n_params:,} total, {trainable:,} trainable", flush=True)
    print(f" Model mode: {'strict ternary text-only' if strict_model else 'full multimodal'}")
    print(format_audit(audit_model(model), limit=5), flush=True)
    print(f" Ternary: ~{ternary_p_unique/1e6:.1f}M packed trits, {e_count:,} int8 E values, {e_accum_bytes:,} int8 E_accum values")
    print(f" Model weights: {model_mem:.1f}MB | Buffers: {buf_mem:.1f}MB | Optimizer: {opt_mem:.1f}MB")
    print(" Compiling warmup...", end=" ", flush=True)

    # Warmup forward pass to trigger JIT compilation
    x_warm, t_warm = get_batch(train_data, device)
    with torch.no_grad():
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16):
            _ = model(x_warm, targets=t_warm)
    if on_cuda:
        torch.cuda.synchronize()
    print("done.", flush=True)
    if on_cuda:
        # Exclude the warmup pass from the reported peak.
        torch.cuda.reset_peak_memory_stats(device)

    loss_history = []
    step_times = []
    for step in range(STEPS):
        lr = get_lr(step)
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        x, targets = get_batch(train_data, device)

        t0 = time.perf_counter()
        optimizer.zero_grad()
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16):
            _, losses, _, _ = model(x, targets=targets)
        losses.total.backward()
        if opt_params:
            torch.nn.utils.clip_grad_norm_(opt_params, 1.0)
        optimizer.step()

        if is_true_ternary:
            update_scales = scale_update_interval > 0 and step % scale_update_interval == 0
            if update_backend == "gpu":
                model._ternary_update_memory(
                    accum_threshold=accum_threshold,
                    update_scales=update_scales,
                    loss_signal=losses.total,
                )
            elif update_backend == "gpu-signcache":
                gpu_signcache_update_memory(
                    model,
                    accum_threshold=accum_threshold,
                    update_scales=update_scales,
                    loss_signal=losses.total,
                )
            elif update_backend == "dense-fallback":
                if update_scales:
                    cpu_update_memory(model, accum_threshold=accum_threshold, loss_signal=losses.total)
                else:
                    model._ternary_update_memory(
                        accum_threshold=accum_threshold,
                        update_scales=False,
                        loss_signal=losses.total,
                    )
            elif update_backend != "none":
                raise ValueError(f"Unknown update backend: {update_backend}")

        if on_cuda:
            # Wait for queued kernels so the timer covers the whole step.
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        loss = losses.total.item()
        loss_history.append(loss)
        step_ms = (t1 - t0) * 1000
        step_times.append(step_ms)

        if step % print_every == 0 or step == STEPS - 1:
            if on_cuda:
                peak = torch.cuda.max_memory_allocated(device) / mib
                allocated = torch.cuda.memory_allocated(device) / mib
                reserved = torch.cuda.memory_reserved(device) / mib
            else:
                peak = allocated = reserved = 0.0
            toks_sec = BATCH * (CTX_LEN - 3) / (step_ms / 1000)
            print(
                f" step {step:>4d}/{STEPS} | loss={loss:.4f} | {step_ms:.0f}ms | "
                f"{toks_sec:.0f} tok/s | alloc={allocated:.0f}MB reserved={reserved:.0f}MB peak={peak:.0f}MB",
                flush=True,
            )

    final_window = loss_history[-20:]
    final_avg = sum(final_window) / len(final_window)
    min_loss = min(loss_history)
    # Skip the warmup steps when averaging; if the run is no longer than the
    # warmup, average every step instead of dividing by zero.
    timed_steps = step_times[WARMUP:] or step_times
    avg_step_ms = sum(timed_steps) / len(timed_steps)
    avg_toks_sec = BATCH * (CTX_LEN - 3) / (avg_step_ms / 1000)
    peak_vram = torch.cuda.max_memory_allocated(device) / mib if on_cuda else 0.0

    del model, optimizer
    gc.collect()
    if on_cuda:
        torch.cuda.empty_cache()

    return {
        "config": name,
        "n_params": n_params,
        "trainable_params": trainable,
        "model_mem_mb": round(model_mem, 1),
        "optimizer_mem_mb": round(opt_mem, 1),
        "buffer_mem_mb": round(buf_mem, 1),
        "peak_vram_mb": round(peak_vram, 1),
        "final_loss_avg20": round(final_avg, 4),
        "min_loss": round(min_loss, 4),
        "avg_step_ms": round(avg_step_ms, 1),
        "avg_toks_sec": round(avg_toks_sec, 1),
        "loss_history": [round(l, 4) for l in loss_history],
    }
if __name__ == "__main__":
    # Script entry point: parse options, run every requested config, print a
    # comparison table and write the results to a JSON file.
    parser = argparse.ArgumentParser(description="Benchmark full or strict true-ternary MORPH configs.")
    parser.add_argument("--steps", type=int, default=STEPS)
    parser.add_argument("--warmup", type=int, default=WARMUP)
    parser.add_argument("--batch", type=int, default=BATCH)
    parser.add_argument("--ctx", type=int, default=CTX_LEN)
    parser.add_argument("--configs", type=str, default=",".join(CONFIGS),
                        help="Comma-separated configs: Adam_FP32,SignSGD_Old,TrueTernary")
    parser.add_argument("--strict-true-ternary", action=argparse.BooleanOptionalAction, default=True,
                        help="Run TrueTernary as text-only strict ternary with frozen float params.")
    parser.add_argument("--update-backend", choices=["gpu", "gpu-signcache", "dense-fallback", "none"], default="gpu-signcache",
                        help="TrueTernary state update implementation.")
    parser.add_argument("--scale-update-interval", type=int, default=4,
                        help="Update int8 E every N TrueTernary steps. 0 disables E updates.")
    parser.add_argument("--accum-threshold", type=int, default=3,
                        help="T_accum threshold for ternary sign flips.")
    parser.add_argument("--print-every", type=int, default=1)
    parser.add_argument("--reuse-base", action=argparse.BooleanOptionalAction, default=False,
                        help="Create one full base model on CPU and load it into full-model configs.")
    args = parser.parse_args()
    if args.steps < 1:
        # The loss and timing averages need at least one recorded step.
        parser.error("--steps must be >= 1")

    # Rebind the module-level settings that the helper functions read.
    STEPS = args.steps
    WARMUP = args.warmup
    BATCH = args.batch
    CTX_LEN = args.ctx
    CONFIGS = [item.strip() for item in args.configs.split(",") if item.strip()]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

    print("\nDownloading data...")
    # Assigned at module scope, so run_config can read train_data as a global.
    train_data, val_data = download_data()
    print(f" Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes")
    print(f" Batch={BATCH}, CTX={CTX_LEN}, Steps={STEPS}, Warmup={WARMUP}")

    results = []
    t_all_0 = time.perf_counter()
    base_state = None
    if args.reuse_base and any(cfg != "TrueTernary" or not args.strict_true_ternary for cfg in CONFIGS):
        # Keep reusable initialization on CPU so it does not inflate per-config VRAM.
        print("\nCreating base model (CPU state reuse)...", flush=True)
        base_model = build_model(strict_ternary=False)
        base_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()}
        del base_model
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        print(" Done.", flush=True)

    for cfg in CONFIGS:
        r = run_config(
            cfg,
            device,
            base_state=base_state,
            strict_true_ternary=args.strict_true_ternary,
            update_backend=args.update_backend,
            scale_update_interval=args.scale_update_interval,
            accum_threshold=args.accum_threshold,
            print_every=max(1, args.print_every),
        )
        results.append(r)
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    t_all = time.perf_counter() - t_all_0

    # Summary table
    print(f"\n{'='*90}")
    print(f" BENCHMARK RESULTS — {STEPS} steps, {BATCH}x{CTX_LEN} batch")
    print(f"{'='*90}")
    print(f" {'Config':<20} {'Loss(avg20)':<12} {'Loss(min)':<10} {'Step(ms)':<10} {'tok/s':<10} {'PeakMB':<8} {'ModelMB':<8} {'OptMB':<8}")
    print(f" {'-'*86}")
    for r in results:
        print(f" {r['config']:<20} {r['final_loss_avg20']:<12} {r['min_loss']:<10} {r['avg_step_ms']:<10} {r['avg_toks_sec']:<10} {r['peak_vram_mb']:<8} {r['model_mem_mb']:<8} {r['optimizer_mem_mb']:<8}")

    def _ratio(numerator, denominator):
        # A zero denominator (e.g. no VRAM recorded) reports inf instead of raising.
        return numerator / denominator if denominator else float('inf')

    # Compare to baseline
    baseline = next((r for r in results if r['config'] == 'Adam_FP32'), None)
    if baseline:
        print(f"\n {'─'*86}")
        print(f" {'Relative to Adam_FP32':<50}")
        print(f" {'─'*86}")
        for r in results:
            if r['config'] == 'Adam_FP32':
                continue
            loss_ratio = _ratio(r['final_loss_avg20'], baseline['final_loss_avg20'])
            speed_ratio = _ratio(baseline['avg_toks_sec'], r['avg_toks_sec'])
            vram_ratio = _ratio(r['peak_vram_mb'], baseline['peak_vram_mb'])
            print(f" {r['config']:<20} loss={loss_ratio:.2f}x speed={speed_ratio:.2f}x vram={vram_ratio:.2f}x")

    # Save results
    out = {
        "config": "True Ternary vs Baselines",
        "steps": STEPS,
        "batch": BATCH,
        "context": CTX_LEN,
        "total_time_s": round(t_all, 1),
        "results": results,
    }
    results_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "results", "benchmark"
    )
    # Create the output directory up front so a finished benchmark is not lost
    # to a missing folder.
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, "benchmark_results.json")
    with open(path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"\n Results saved to {path}")
    print(f" Total benchmark time: {t_all:.0f}s ({t_all/60:.1f}min)")