"""Vision training. Freezes text/audio pipelines, trains image sequencer features.

Useful for fine-tuning the DINOv2 projection layers on custom image data.

NOTE: the training loop below currently feeds synthetic random images and
random token ids (a smoke test). The ``--data`` option is accepted but not
yet consumed; a warning is printed when it is given.
"""
import os
import sys

import torch

# Make the project root importable before the ``arbitor`` imports below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from torch.utils.tensorboard import SummaryWriter  # noqa: E402

from arbitor import ARBModel  # noqa: E402
from arbitor.kernel.ternary_scale import TScaleType  # noqa: E402,F401
from arbitor.kernel.ternary_audit import (  # noqa: E402
    audit_model,
    format_audit,
    freeze_float_parameters,
    trainable_parameters,
)

# Substrings that mark a parameter as part of the vision path. Matching is a
# plain substring test, so the broad 'image' key already covers
# 'image_sequencer'; the specific entries are kept for readability.
VISION_PARAM_KEYS = ('image_sequencer', 'image', 'patch_proj', 'modality_gate', 'output_router')


def freeze_non_vision(model):
    """Freeze every parameter whose name does not contain a vision key.

    A parameter stays trainable (``requires_grad=True``) when its qualified
    name contains any substring listed in ``VISION_PARAM_KEYS``. That covers
    the image sequencer, ``patch_proj``, ``modality_gate`` and
    ``output_router``, plus any other parameter whose name contains "image".
    Every other parameter gets ``requires_grad=False``.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
    """
    # Single pass: equivalent to "freeze all, then unfreeze the matches".
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in VISION_PARAM_KEYS)


def _parse_args(argv=None):
    """Parse command-line options for the vision trainer."""
    import argparse

    parser = argparse.ArgumentParser(description="ARB vision/image training")
    parser.add_argument("--steps", type=int, default=3000)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--run", type=str, default="vision")
    parser.add_argument("--data", type=str, default=None, help="Image dataset path")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument("--log-every", type=int, default=200,
                        help="Log the loss every N steps")
    return parser.parse_args(argv)


def _build_model():
    """Build the image-enabled ARBModel on CUDA with non-vision features disabled."""
    return ARBModel(
        enable_image=True,
        enable_audio=False,
        enable_vq=True,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
        max_moe_iters=4,
        enable_attention=False,
        enable_output_router=False,
        enable_video_output=False,
        enable_talker_output=False,
    ).cuda()


def main(argv=None):
    """Run vision-only ternary training on synthetic data.

    Raises:
        ValueError: ``--backend tilelang`` without ARB_TILELANG_TRAINING opt-in.
        RuntimeError: trainable (float) parameters remain after freezing.
    """
    args = _parse_args(argv)
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")
    if args.data is not None:
        # The loop below does not read a dataset yet; say so instead of
        # silently ignoring the user's path.
        print(f"warning: --data={args.data!r} is not used yet; training on synthetic images.",
              file=sys.stderr)

    model = _build_model()
    freeze_non_vision(model)
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))
    if trainable_parameters(model):
        raise RuntimeError("Vision trainer is pure ternary; use training/finetuning/vision.py for LoRA adapters.")

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    device = torch.device("cuda")
    log_every = max(1, args.log_every)
    try:
        # Generate synthetic image data for smoke testing.
        for step in range(args.steps):
            # Allocate directly on the GPU instead of creating on CPU and copying.
            images = torch.randn(args.batch, 3, args.image_size, args.image_size, device=device)
            text = torch.randint(0, 256, (args.batch, 16), device=device)
            # NOTE(review): targets drop the first 3 text positions; the reason
            # for this offset is not visible here -- confirm against ARBModel.
            targets = text[:, 3:]

            model.zero_grad(set_to_none=True)
            _, losses, _, _ = model(x=text, images=images, targets=targets)
            loss = losses.total
            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
            model.zero_grad(set_to_none=True)

            if step % log_every == 0:
                loss_value = loss.item()  # single host sync per log step
                writer.add_scalar("loss/vision", loss_value, step)
                print(f"step {step:>5d} loss={loss_value:.3f}")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


if __name__ == "__main__":
    main()