# ARBS / training / text.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
"""Text-only pure-ternary training.
For the maintained CLI entrypoint prefer ``python -m arbitor.train`` or the
``arbs-train`` console script after ``pip install -e .``.
"""
import os, sys, time, math, json, torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from torch.utils.tensorboard import SummaryWriter
from arbitor import ARBModel, VOCAB, CTX
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import freeze_float_parameters, audit_model, format_audit
def load_text_data(source, ctx=CTX):
    """Load a byte-level dataset and split it 90/10 into train/validation.

    Args:
        source: Path to the dataset. A path ending in ``.pt`` is read with
            ``torch.load(..., weights_only=True)``; any other existing file
            is read as raw bytes, one token per byte.
        ctx: Context length. Accepted so existing call sites keep working;
            the loader itself does not use it.

    Returns:
        A ``(train, val)`` tuple holding the first 90% and the remaining 10%
        of the data, each moved to the current CUDA device.

    Raises:
        ValueError: If ``source`` does not end in ``.pt`` and is not an
            existing file. Hugging Face dataset names are not resolved here.
    """
    if source.endswith('.pt'):
        data = torch.load(source, weights_only=True)
    elif os.path.isfile(source):
        # Close the handle deterministically instead of leaking it.
        with open(source, 'rb') as fh:
            raw = fh.read()
        if raw:
            # Build the tensor straight from the byte buffer rather than via
            # a Python list of ints. ``bytearray`` provides the writable
            # buffer torch expects; the cast to long makes an independent copy.
            data = torch.frombuffer(bytearray(raw), dtype=torch.uint8).to(torch.long)
        else:
            # torch.frombuffer rejects zero-length buffers; keep returning an
            # empty long tensor for an empty file.
            data = torch.empty(0, dtype=torch.long)
    else:
        raise ValueError(f"Dataset not found: {source}")
    n = int(0.9 * len(data))
    return data[:n].cuda(), data[n:].cuda()
if __name__ == "__main__":
    # Script entry point: parse CLI flags, build a text-only ARBModel, and run
    # the ternary training loop with periodic TensorBoard logging, evaluation
    # and best-checkpoint saving.
    import argparse

    parser = argparse.ArgumentParser(description="ARB text training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--ctx", type=int, default=min(CTX, 128))
    parser.add_argument("--data", type=str, default="training/data/tinyshakespeare.txt")
    parser.add_argument("--run", type=str, default="text")
    parser.add_argument("--eval-interval", type=int, default=250)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument("--full-text-stack", action="store_true",
                        help="Enable text VQ + KG/MoEGraph + KV attention. Default is byte-text only.")
    parser.add_argument("--with-vq", action="store_true")
    parser.add_argument("--with-graph", action="store_true")
    parser.add_argument("--with-moe", action="store_true")
    parser.add_argument("--with-attention", action="store_true")
    args = parser.parse_args()

    # NOTE(review): this is set after ``arbitor`` has already been imported at
    # the top of the file; confirm the backend variable is read lazily.
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    tilelang_opt_in = os.environ.get("ARB_TILELANG_TRAINING", "0").lower()
    if args.backend == "tilelang" and tilelang_opt_in not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")

    # Resolve feature flags. The graph is only enabled together with VQ, and
    # attention only together with the graph.
    enable_vq = args.full_text_stack or args.with_vq
    enable_graph = (args.full_text_stack or args.with_graph) and enable_vq
    enable_moe = args.full_text_stack or args.with_moe
    enable_attention = (args.full_text_stack or args.with_attention) and enable_graph

    model = ARBModel(enable_image=False, enable_audio=False,
                     enable_vq=enable_vq, enable_graph=enable_graph,
                     enable_memory_modules=False, enable_moe=enable_moe,
                     max_moe_iters=4,
                     enable_attention=enable_attention,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=False).cuda()
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))
    print(
        "Text trainer config: "
        f"backend={args.backend}, vq={enable_vq}, graph={enable_graph}, "
        f"moe={enable_moe}, attention={enable_attention}"
    )

    # After freezing, no parameter may still require gradients; this trainer
    # has no optimizer and relies on the model's own ternary update calls.
    train_params = [p for p in model.parameters() if p.requires_grad]
    if train_params:
        raise RuntimeError("Text trainer is pure ternary; use training/finetuning for float LoRA adapters.")
    print("Pure ternary update path is active")

    train_data, val_data = load_text_data(args.data, args.ctx)
    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best = float('inf')
    try:
        while step < args.steps:
            # Sample a batch of random windows of length ``ctx``.
            ix = torch.randint(0, len(train_data) - args.ctx - 1, (args.batch,))
            x = torch.stack([train_data[i:i+args.ctx] for i in ix])
            # Targets drop the first three positions of each window.
            # NOTE(review): presumably matches the model's output alignment;
            # confirm against ARBModel.
            t = x[:, 3:]
            _, losses, _, _ = model(x, targets=t)
            model.prepare_ternary_backward(losses.total.detach(), update_scales=True)
            losses.total.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=losses.total)
            model.zero_grad(set_to_none=True)
            step += 1

            # Host copy of the training loss, fetched at most once per step so
            # logging does not trigger repeated device synchronisations.
            train_loss = None
            if args.log_interval and step % args.log_interval == 0:
                train_loss = losses.total.item()
                writer.add_scalar("loss/train", train_loss, step)
                print(f"step {step:>5d} train={train_loss:.3f}")

            if args.eval_interval and step % args.eval_interval == 0:
                # Evaluate on a single random validation batch.
                with torch.no_grad():
                    ix_v = torch.randint(0, len(val_data) - args.ctx - 1, (args.batch,))
                    xv = torch.stack([val_data[i:i+args.ctx] for i in ix_v])
                    tv = xv[:, 3:]
                    _, lv, _, _ = model(xv, targets=tv)
                val = lv.total.item()
                if train_loss is None:
                    # Only record the train scalar here when the logging branch
                    # has not already written it for this step.
                    train_loss = losses.total.item()
                    writer.add_scalar("loss/train", train_loss, step)
                writer.add_scalar("loss/eval", val, step)
                if val < best:
                    best = val
                    torch.save(model.state_dict(), f"{run_dir}/best.pt")
                print(f"step {step:>5d} train={train_loss:.3f} eval={val:.3f} best={best:.3f}")
    finally:
        # Flush and close the event file even if training is interrupted, so
        # buffered scalars are not lost.
        writer.close()