| """Latent video diffusion training. |
| Freezes text/audio pipelines, trains VideoHead + OutputRouter. |
| Uses pig-vae to encode target video frames as latent training targets. |
| |
| Dataset: expects video files or pre-encoded .pt latent files. |
| """ |
| import os, sys, torch |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| from torch.utils.tensorboard import SummaryWriter |
| from arbitor import ARBModel |
| from arbitor.kernel.ternary_scale import TScaleType |
| from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters |
| from arbitor.encoders.pig_vae import load_vae |
|
|
|
|
def freeze_non_diffusion(model, trainable_keys=('video_head', 'output_router', 'talker_head')):
    """Freeze every parameter except those of the selected output heads.

    A parameter stays trainable when its name, as reported by
    ``model.named_parameters()``, contains at least one of the substrings in
    *trainable_keys*. Every other parameter gets ``requires_grad = False``.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).
        trainable_keys: Name substrings that mark parameters to keep
            trainable. A single string is treated as one key. Defaults to
            the video head, output router and talker head.

    Returns:
        None. The ``requires_grad`` flags are updated in place.
    """
    # Guard against a bare string, which would otherwise be matched
    # character by character.
    if isinstance(trainable_keys, str):
        keys = (trainable_keys,)
    else:
        keys = tuple(trainable_keys)

    # One pass: the flag is True only for parameters matching a key.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keys)
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    import argparse  # local import: only the script entry point needs it

    parser = argparse.ArgumentParser(description="ARB video diffusion training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--run", type=str, default="diffusion")
    parser.add_argument("--latent-dir", type=str, default=None, help="Dir of .pt latent files")
    parser.add_argument("--video-dir", type=str, default=None, help="Dir of .mp4 files")
    parser.add_argument("--frames", type=int, default=16, help="Frames per video clip")
    parser.add_argument("--res", type=int, default=256, help="Frame resolution")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    return parser.parse_args()


def main():
    """Run the video-head training loop on a CUDA device.

    Steps, in order:
      1. Parse CLI arguments and select the ternary backend through the
         ``ARB_TERNARY_BACKEND`` environment variable.
      2. Build an ``ARBModel`` with only the video output enabled, freeze it
         with ``freeze_non_diffusion`` and ``freeze_float_parameters``, and
         print the audit report.
      3. Abort if ``trainable_parameters(model)`` is non-empty.
      4. Run ``--steps`` iterations, logging the MSE loss to TensorBoard and
         stdout every 100 steps.

    Raises:
        ValueError: ``--backend tilelang`` was requested without
            ``ARB_TILELANG_TRAINING`` set to ``1``/``true``/``yes``.
        RuntimeError: No CUDA device is available, or the model still has
            trainable parameters after freezing.
    """
    args = _parse_args()

    # Validate before mutating the process environment.
    tilelang_opt_in = os.environ.get("ARB_TILELANG_TRAINING", "0").lower()
    if args.backend == "tilelang" and tilelang_opt_in not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")
    os.environ["ARB_TERNARY_BACKEND"] = args.backend

    # The model and every batch are placed on the GPU below, so fail early
    # with a clear message instead of deep inside ``.cuda()``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device required for diffusion training.")
    device = torch.device("cuda")

    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=False,
        max_moe_iters=4,
        enable_attention=False,
        enable_output_router=False,
        enable_video_output=True,
        enable_talker_output=False,
    ).cuda()
    freeze_non_diffusion(model)
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))

    if trainable_parameters(model):
        raise RuntimeError("Diffusion trainer is pure ternary; use training/finetuning/diffusion.py for LoRA adapters.")

    # NOTE(review): the VAE is loaded when --video-dir is given, but the loop
    # below never reads it; --latent-dir, --frames and --res are unused too.
    # Batches are random placeholders — presumably real data loading is
    # still to be wired in.
    vae = load_vae(device='cuda', quantize='int8') if args.video_dir else None

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)

    # Placeholder batch geometry, taken verbatim from the original literals.
    # NOTE(review): the meaning of each latent axis is not visible here —
    # confirm against the video head's output shape.
    vocab_size = 256
    seq_len = 20
    latent_shape = (16, 1, 32, 32)
    log_every = 100

    # The context manager closes the writer (flushing pending events) even
    # if a training step raises.
    with SummaryWriter(run_dir) as writer:
        for step in range(args.steps):
            # Allocate directly on the GPU instead of CPU + host-to-device copy.
            text = torch.randint(0, vocab_size, (args.batch, seq_len), device=device)
            target_latents = torch.randn(args.batch, *latent_shape, device=device)

            model.zero_grad(set_to_none=True)

            embedded = model.embedding(text)
            rel = model.multimodal_sequencer({'text': embedded})['text']
            pred_latents = model.video_head(rel)

            loss = torch.nn.functional.mse_loss(pred_latents, target_latents)
            # Project-specific ternary update hooks; call order kept exactly
            # as in the original (prepare -> backward -> update).
            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
            model.zero_grad(set_to_none=True)

            if step % log_every == 0:
                # Read the scalar once; each .item() forces a device sync.
                loss_value = loss.item()
                writer.add_scalar("loss/diffusion", loss_value, step)
                print(f"step {step:>5d} loss={loss_value:.6f}")


if __name__ == "__main__":
    main()
|
|