"""Text-only pure-ternary training.

For the maintained CLI entrypoint prefer ``python -m arbitor.train`` or the
``arbs-train`` console script after ``pip install -e .``.
"""
import json
import math
import os
import sys
import time

import torch

# Make the parent of this file's directory importable before the project
# imports below (presumably so the in-repo ``arbitor`` package resolves when
# this file is run directly as a script -- confirm against the repo layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from torch.utils.tensorboard import SummaryWriter  # noqa: E402
from arbitor import ARBModel, VOCAB, CTX  # noqa: E402
from arbitor.kernel.ternary_scale import TScaleType  # noqa: E402
from arbitor.kernel.ternary_audit import (  # noqa: E402
    freeze_float_parameters,
    audit_model,
    format_audit,
)


def load_text_data(source, ctx=CTX):
    """Load a byte dataset and split it 90/10 into train and validation.

    Args:
        source: Path to the dataset. A path ending in ``.pt`` is loaded with
            ``torch.load(..., weights_only=True)``; any other existing file is
            read as raw bytes and converted to a ``torch.long`` tensor with
            one element per byte.
        ctx: Unused by this function; kept so existing callers that pass a
            context length keep working.

    Returns:
        ``(train, val)``: the first 90% and the remaining 10% of the data,
        both moved to the current CUDA device.

    Raises:
        ValueError: If ``source`` does not end in ``.pt`` and is not an
            existing file. Only local files are supported.
    """
    if source.endswith('.pt'):
        # NOTE(review): assumes the file holds a single 1-D tensor -- confirm.
        data = torch.load(source, weights_only=True)
    elif os.path.isfile(source):
        # Use a context manager so the file handle is closed deterministically.
        with open(source, 'rb') as fh:
            raw = bytearray(fh.read())
        if raw:
            # frombuffer avoids building a Python list with one int per byte;
            # the cast to long copies, so the result does not alias ``raw``.
            data = torch.frombuffer(raw, dtype=torch.uint8).to(torch.long)
        else:
            # frombuffer rejects empty buffers; an empty file still yields an
            # empty long tensor.
            data = torch.empty(0, dtype=torch.long)
    else:
        raise ValueError(f"Dataset not found: {source}")
    n = int(0.9 * len(data))
    return data[:n].cuda(), data[n:].cuda()


def _parse_args():
    """Build the command-line parser for the text trainer and parse argv."""
    import argparse

    parser = argparse.ArgumentParser(description="ARB text training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--ctx", type=int, default=min(CTX, 128))
    parser.add_argument("--data", type=str, default="training/data/tinyshakespeare.txt")
    parser.add_argument("--run", type=str, default="text")
    parser.add_argument("--eval-interval", type=int, default=250)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument(
        "--full-text-stack",
        action="store_true",
        help="Enable text VQ + KG/MoEGraph + KV attention. "
             "Default is byte-text only.",
    )
    parser.add_argument("--with-vq", action="store_true")
    parser.add_argument("--with-graph", action="store_true")
    parser.add_argument("--with-moe", action="store_true")
    parser.add_argument("--with-attention", action="store_true")
    return parser.parse_args()


def _require_min_length(data, ctx, label):
    """Raise ValueError if ``data`` is too short to sample a ``ctx`` window."""
    # torch.randint(0, high) needs high > 0, i.e. len(data) > ctx + 1.
    if len(data) - ctx - 1 <= 0:
        raise ValueError(
            f"{label} split has {len(data)} bytes, which is too short for "
            f"--ctx {ctx}; need at least {ctx + 2}."
        )


def _sample_batch(data, ctx, batch_size):
    """Draw ``batch_size`` random windows of length ``ctx`` from ``data``.

    Returns ``(x, targets)`` where ``targets`` is ``x`` with its first three
    positions dropped.
    """
    starts = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    x = torch.stack([data[i:i + ctx] for i in starts])
    # NOTE(review): the 3-position offset presumably matches the model's
    # internal input/target alignment -- confirm against ARBModel.
    return x, x[:, 3:]


def main():
    """Run pure-ternary text training as configured on the command line."""
    args = _parse_args()

    # NOTE(review): this is set after ``arbitor`` has been imported; confirm
    # the backend is read lazily rather than at import time.
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    tilelang_opt_in = os.environ.get("ARB_TILELANG_TRAINING", "0").lower()
    if args.backend == "tilelang" and tilelang_opt_in not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")

    # Graph requires VQ and attention requires graph, so those flags are
    # silently dropped when their prerequisite is disabled.
    enable_vq = args.full_text_stack or args.with_vq
    enable_graph = (args.full_text_stack or args.with_graph) and enable_vq
    enable_moe = args.full_text_stack or args.with_moe
    enable_attention = (args.full_text_stack or args.with_attention) and enable_graph

    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=enable_vq,
        enable_graph=enable_graph,
        enable_memory_modules=False,
        enable_moe=enable_moe,
        max_moe_iters=4,
        enable_attention=enable_attention,
        enable_output_router=False,
        enable_video_output=False,
        enable_talker_output=False,
    ).cuda()
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))
    print(
        "Text trainer config: "
        f"backend={args.backend}, vq={enable_vq}, graph={enable_graph}, "
        f"moe={enable_moe}, attention={enable_attention}"
    )

    # After freezing, nothing may still require grad: this trainer only uses
    # the model's own ternary update path, not an optimizer.
    train_params = [p for p in model.parameters() if p.requires_grad]
    if train_params:
        raise RuntimeError("Text trainer is pure ternary; use training/finetuning for float LoRA adapters.")
    print("Pure ternary update path is active")

    train_data, val_data = load_text_data(args.data, args.ctx)
    # Fail early with a clear message instead of an opaque randint range
    # error partway through the run.
    if args.steps > 0:
        _require_min_length(train_data, args.ctx, "Training")
    if args.eval_interval and args.eval_interval <= args.steps:
        _require_min_length(val_data, args.ctx, "Validation")

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    step = 0
    best = float('inf')
    try:
        while step < args.steps:
            x, t = _sample_batch(train_data, args.ctx, args.batch)
            _, losses, _, _ = model(x, targets=t)
            model.prepare_ternary_backward(losses.total.detach(), update_scales=True)
            losses.total.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=losses.total)
            model.zero_grad(set_to_none=True)
            step += 1

            log_now = bool(args.log_interval) and step % args.log_interval == 0
            eval_now = bool(args.eval_interval) and step % args.eval_interval == 0

            if log_now or eval_now:
                # Record the training loss exactly once per step, even when a
                # step is both a log step and an eval step.
                train_loss = losses.total.item()
                writer.add_scalar("loss/train", train_loss, step)

            if log_now:
                print(f"step {step:>5d} train={train_loss:.3f}")

            if eval_now:
                with torch.no_grad():
                    xv, tv = _sample_batch(val_data, args.ctx, args.batch)
                    _, lv, _, _ = model(xv, targets=tv)
                    val = lv.total.item()
                writer.add_scalar("loss/eval", val, step)
                if val < best:
                    best = val
                    torch.save(model.state_dict(), f"{run_dir}/best.pt")
                print(f"step {step:>5d} train={train_loss:.3f} eval={val:.3f} best={best:.3f}")
    finally:
        # Flush pending TensorBoard events even if training raised.
        writer.close()


if __name__ == "__main__":
    main()