| """Text-only pure-ternary training. |
| |
| For the maintained CLI entrypoint prefer ``python -m arbitor.train`` or the |
| ``arbs-train`` console script after ``pip install -e .``. |
| """ |
| import os, sys, time, math, json, torch |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| from torch.utils.tensorboard import SummaryWriter |
| from arbitor import ARBModel, VOCAB, CTX |
| from arbitor.kernel.ternary_scale import TScaleType |
| from arbitor.kernel.ternary_audit import freeze_float_parameters, audit_model, format_audit |
|
|
|
|
def load_text_data(source, ctx=CTX, *, device="cuda", train_fraction=0.9):
    """Load a byte-level dataset and split it into train/validation tensors.

    Args:
        source: Path to a ``.pt`` file or to any other existing file.
            A ``.pt`` file is loaded with ``torch.load(..., weights_only=True)``.
            For any other file, its raw bytes (0-255) become the sequence.
            HuggingFace dataset names are NOT supported.
        ctx: Context length the caller will sample windows with. Training
            draws start offsets via ``torch.randint(0, len(split) - ctx - 1,
            ...)``, so each split must hold at least ``ctx + 2`` elements.
        device: Device the returned tensors are moved to (default ``"cuda"``).
        train_fraction: Fraction of the sequence, in (0, 1), that goes into
            the training split. The remainder is the validation split.

    Returns:
        ``(train, val)``: two consecutive slices of the loaded sequence,
        moved to ``device``.

    Raises:
        ValueError: if ``source`` is not found, ``train_fraction`` is outside
            (0, 1), or either split is too short for ``ctx``.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    if source.endswith('.pt'):
        data = torch.load(source, weights_only=True)
    elif os.path.isfile(source):
        # Close the handle deterministically instead of leaking it.
        with open(source, 'rb') as fh:
            raw = fh.read()
        # frombuffer avoids materialising one Python int per byte.
        # It rejects empty buffers, so handle that case separately.
        if raw:
            data = torch.frombuffer(bytearray(raw), dtype=torch.uint8).long()
        else:
            data = torch.empty(0, dtype=torch.long)
    else:
        raise ValueError(f"Dataset not found: {source}")

    n = int(train_fraction * len(data))
    train, val = data[:n], data[n:]

    # Fail fast with a clear message instead of a cryptic randint error later.
    min_len = ctx + 2
    if len(train) < min_len or len(val) < min_len:
        raise ValueError(
            f"Dataset {source!r} too small for ctx={ctx}: train={len(train)}, "
            f"val={len(val)} elements, need at least {min_len} in each split"
        )
    return train.to(device), val.to(device)
|
|
|
|
def _sample_batch(data, ctx, batch):
    """Stack ``batch`` random length-``ctx`` windows taken from ``data``.

    Start offsets come from ``torch.randint(0, len(data) - ctx - 1, ...)``,
    so ``data`` must hold at least ``ctx + 2`` elements. The training and
    validation splits both use this helper, so they are sampled identically.
    """
    starts = torch.randint(0, len(data) - ctx - 1, (batch,))
    return torch.stack([data[i:i + ctx] for i in starts])


def _resolve_features(args):
    """Map CLI flags to ``(vq, graph, moe, attention)`` enable booleans.

    ``--full-text-stack`` turns everything on. Graph is only enabled when VQ
    is, and attention only when graph is. A requested feature whose
    prerequisite is missing stays disabled, and a warning goes to stderr
    instead of the flag being ignored silently.
    """
    want_graph = args.full_text_stack or args.with_graph
    want_attention = args.full_text_stack or args.with_attention

    enable_vq = args.full_text_stack or args.with_vq
    enable_graph = want_graph and enable_vq
    enable_moe = args.full_text_stack or args.with_moe
    enable_attention = want_attention and enable_graph

    if want_graph and not enable_graph:
        print("warning: graph requested without VQ (--with-vq); graph disabled",
              file=sys.stderr)
    if want_attention and not enable_attention:
        print("warning: attention requested without graph+VQ; attention disabled",
              file=sys.stderr)
    return enable_vq, enable_graph, enable_moe, enable_attention


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    import argparse
    parser = argparse.ArgumentParser(description="ARB text training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--ctx", type=int, default=min(CTX, 128))
    parser.add_argument("--data", type=str, default="training/data/tinyshakespeare.txt")
    parser.add_argument("--run", type=str, default="text")
    parser.add_argument("--eval-interval", type=int, default=250)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument("--full-text-stack", action="store_true",
                        help="Enable text VQ + KG/MoEGraph + KV attention. Default is byte-text only.")
    parser.add_argument("--with-vq", action="store_true")
    parser.add_argument("--with-graph", action="store_true")
    parser.add_argument("--with-moe", action="store_true")
    parser.add_argument("--with-attention", action="store_true")
    return parser.parse_args(argv)


def main(argv=None):
    """Train a pure-ternary byte-text ARBModel and keep the best checkpoint.

    Builds the model with only float parameters frozen and aborts if any
    parameter still requires grad. It then runs ``--steps`` update steps,
    using the model's own ternary update hooks (no torch optimizer).
    Train and eval losses go to TensorBoard under
    ``models/checkpoints/<run>``. The state dict with the lowest sampled
    eval loss is saved there as ``best.pt``.
    """
    args = _parse_args(argv)
    # NOTE(review): ``arbitor`` is imported at module top, before this
    # assignment. Confirm the backend is read lazily (at build/call time);
    # if it is read at import time, this setting has no effect.
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")

    enable_vq, enable_graph, enable_moe, enable_attention = _resolve_features(args)

    model = ARBModel(enable_image=False, enable_audio=False,
                     enable_vq=enable_vq, enable_graph=enable_graph,
                     enable_memory_modules=False, enable_moe=enable_moe,
                     max_moe_iters=4,
                     enable_attention=enable_attention,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=False).cuda()
    freeze_float_parameters(model)

    print(format_audit(audit_model(model)))
    print(
        "Text trainer config: "
        f"backend={args.backend}, vq={enable_vq}, graph={enable_graph}, "
        f"moe={enable_moe}, attention={enable_attention}"
    )
    # Nothing may require grad after freezing: this trainer has no optimizer
    # for float parameters.
    if any(p.requires_grad for p in model.parameters()):
        raise RuntimeError("Text trainer is pure ternary; use training/finetuning for float LoRA adapters.")
    print("Pure ternary update path is active")

    train_data, val_data = load_text_data(args.data, args.ctx)
    run_dir = os.path.join("models", "checkpoints", args.run)
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    best = float("inf")
    try:
        for step in range(1, args.steps + 1):
            x = _sample_batch(train_data, args.ctx, args.batch)
            # NOTE(review): targets drop the first 3 input positions instead
            # of being the usual 1-step shift. Presumably ARBModel aligns them
            # internally; confirm.
            t = x[:, 3:]
            _, losses, _, _ = model(x, targets=t)
            model.prepare_ternary_backward(losses.total.detach(), update_scales=True)
            losses.total.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=losses.total)
            model.zero_grad(set_to_none=True)

            log_now = bool(args.log_interval) and step % args.log_interval == 0
            eval_now = bool(args.eval_interval) and step % args.eval_interval == 0
            if not (log_now or eval_now):
                continue
            # Sync to host once per reported step instead of once per use.
            train_loss = losses.total.item()

            if log_now:
                writer.add_scalar("loss/train", train_loss, step)
                print(f"step {step:>5d} train={train_loss:.3f}")

            if eval_now:
                # NOTE(review): the model stays in training mode during eval.
                # Confirm ARBModel has no train/eval-dependent layers.
                with torch.no_grad():
                    xv = _sample_batch(val_data, args.ctx, args.batch)
                    _, lv, _, _ = model(xv, targets=xv[:, 3:])
                    val = lv.total.item()
                if not log_now:
                    # Avoid writing the same train point twice for one step.
                    writer.add_scalar("loss/train", train_loss, step)
                writer.add_scalar("loss/eval", val, step)
                if val < best:
                    best = val
                    torch.save(model.state_dict(), os.path.join(run_dir, "best.pt"))
                print(f"step {step:>5d} train={train_loss:.3f} eval={val:.3f} best={best:.3f}")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


if __name__ == "__main__":
    main()
|