| """Fine-tune ARB model on video/latent diffusion tasks using LoRA. |
| |
| Freezes text/audio pipelines, adapts VideoHead + core MoE for |
| latent video diffusion fine-tuning. Uses pig-vae to encode training targets. |
| |
| Designed for 8GB VRAM with batch_size=1. |
| |
| Usage: |
| python training/finetuning/diffusion.py \\ |
| --video-dir ./videos --steps 2000 --batch 1 \\ |
| --lora-rank 16 --run diffusion-finetune |
| |
| Data format: directory of .mp4 or .avi files (encoded to latents via pig-vae). |
| """ |
| import os, sys, time |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Build the ARB model on CUDA and attach LoRA adapters to it.

    The model is constructed with the image, audio, VQ, graph and memory
    modules disabled and MoE enabled, moved to the GPU, and then handed to
    ``apply_lora_to_model`` together with the target-module names below.

    NOTE(review): which base weights end up frozen is decided inside
    ``apply_lora_to_model`` (not visible from this file) -- confirm there
    before relying on it.

    Args:
        lora_rank: Rank forwarded to ``apply_lora_to_model``.
        lora_alpha: Alpha forwarded to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to ``ARBModel(max_moe_iters=...)``.

    Returns:
        ``(model, lora_layers)`` where ``lora_layers`` is the value returned
        by ``apply_lora_to_model``.
    """
    # Project imports stay local so importing this module does not pull in
    # the whole model stack.
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    model = ARBModel(
        enable_image=False, enable_audio=False,
        enable_vq=False, enable_graph=False,
        enable_memory_modules=False, enable_moe=True,
        max_moe_iters=max_moe_iters,
    ).cuda()

    # Module names that receive LoRA adapters.
    target_modules = ['W_gate', 'W_transform', 'byte_head', 'router',
                      'shared_up', 'shared_expert_gate', 'shared_expert_up',
                      'video_head', 'diffusion_step', 'cross_attn',
                      'halt_unit', 'noise_embed']
    lora_layers = apply_lora_to_model(model, rank=lora_rank, alpha=lora_alpha,
                                      target_modules=target_modules)

    # Only the trainable count is reported; the second value is not needed.
    lora_p, _ = count_lora_params(model)
    print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers
|
|
|
|
def load_video_data(video_dir, max_samples=100, frames=16, res=256):
    """Load video files from a directory and encode them to pig-vae latents.

    Every ``*.mp4`` / ``*.avi`` file directly inside ``video_dir`` is decoded,
    cut to the first ``frames`` frames and the top-left ``res`` x ``res``
    pixels (a crop, not a resize), scaled to ``[0, 1]`` and passed through
    ``vae.encode``. Clips with fewer than ``frames`` frames, and files that
    fail to decode or encode, are skipped with a log line.

    Falls back to ``_generate_synthetic`` random latents when no files are
    found, when pig-vae cannot be loaded, or when no file could be encoded.

    Args:
        video_dir: Directory to scan (non-recursive).
        max_samples: Maximum number of files to process.
        frames: Number of leading frames required and kept per clip.
        res: Maximum height/width kept from the top-left corner.

    Returns:
        A list of CPU tensors, one per encoded clip (or synthetic latents).
        NOTE(review): the latent shape is whatever ``vae.encode`` returns for
        a ``(1, C, T, H, W)`` input -- confirm against pig-vae.
    """
    import glob

    # Sorted because glob order is filesystem-dependent; this keeps the
    # max_samples subset and any later positional split reproducible.
    files = sorted(
        glob.glob(os.path.join(video_dir, "*.mp4"))
        + glob.glob(os.path.join(video_dir, "*.avi"))
    )

    if not files:
        print(f" No video files found in {video_dir}", flush=True)
        print(f" Using synthetic random latents for smoke testing", flush=True)
        return _generate_synthetic(frames, res, max_samples)

    print(f" Found {len(files)} video files", flush=True)
    files = files[:max_samples]

    # Loading the VAE is best-effort: any failure degrades to synthetic data.
    try:
        from arbitor.encoders.pig_vae import load_vae
        vae = load_vae(device='cuda', quantize='int8')
        print(f" pig-vae loaded for encoding", flush=True)
    except Exception as e:
        print(f" pig-vae not available: {e}", flush=True)
        print(f" Using random latents (no video encoding)", flush=True)
        return _generate_synthetic(frames, res, min(max_samples, 50))

    # Imported only once we know real decoding will happen, so the synthetic
    # fallbacks above do not require torchvision.
    import torchvision.io

    data = []
    skipped = 0
    for f in files:
        try:
            # read_video yields frames as (T, H, W, C); reorder to (C, T, H, W).
            video, _, _ = torchvision.io.read_video(f, pts_unit='sec')
            video = video.permute(3, 0, 1, 2).float() / 255.0
            video = video[:, :frames, :res, :res]

            if video.shape[1] < frames:
                print(f" Skipping {f}: only {video.shape[1]} frames "
                      f"(need {frames})", flush=True)
                skipped += 1
                continue

            video = video.unsqueeze(0).cuda()
            with torch.no_grad():
                latents = vae.encode(video).cpu()
            data.append(latents)
        except Exception as e:
            # Best-effort per file: report the problem instead of hiding it.
            print(f" Skipping {f}: {e}", flush=True)
            skipped += 1
            continue

    if skipped:
        print(f" Skipped {skipped} of {len(files)} video files", flush=True)

    if not data:
        return _generate_synthetic(frames, res, 50)

    print(f" Encoded {len(data)} videos to latent space", flush=True)
    return data
|
|
|
|
def _generate_synthetic(frames, res, count):
    """Fallback: generate random latent targets for testing.

    Args:
        frames: Accepted for call-site symmetry with the real loader; not
            used -- the latent shape below is fixed.
        res: Accepted for call-site symmetry; not used.
        count: Number of latent tensors to create.

    Returns:
        A list of ``count`` CPU tensors of shape ``(1, 16, 1, 32, 32)`` drawn
        from a standard normal distribution.
    """
    data = [torch.randn(1, 16, 1, 32, 32) for _ in range(count)]
    print(f" Generated {count} synthetic latent targets", flush=True)
    return data
|
|
|
|
def _match_latents(target, pred):
    """Coerce ``target`` to the shape of ``pred`` for an element-wise loss.

    Up to three adjustments are applied, in order, each only when needed:

    1. Dim 0: a size-1 target is broadcast to ``pred.shape[0]``.
    2. Dim 1: surplus entries are truncated, or zeros are appended.
    3. Dims 2+: resized with linear / bilinear / trilinear interpolation
       (chosen by rank, ``align_corners=False``).

    Args:
        target: Tensor to adapt (3-D, 4-D or 5-D when dims 2+ need resizing).
        pred: Tensor whose shape is the goal; never modified.

    Returns:
        A tensor shaped like ``pred`` along every adjusted dimension. When
        nothing needs adjusting, ``target`` itself is returned.

    Raises:
        ValueError: If the two tensors have different ranks, or dims 2+ need
            resizing for a rank that ``interpolate`` has no mode for.
    """
    if target.dim() != pred.dim():
        raise ValueError(
            f"rank mismatch: target is {target.dim()}-D, pred is {pred.dim()}-D"
        )

    # Broadcast a single-sample target across the predicted batch.
    batch = pred.shape[0]
    if target.shape[0] == 1 and batch > 1:
        target = target.expand(batch, *target.shape[1:]).contiguous()

    # Align dim 1 by truncating or zero-padding at the end.
    have, want = target.shape[1], pred.shape[1]
    if have > want:
        target = target[:, :want]
    elif have < want:
        pad = target.new_zeros(target.shape[0], want - have, *target.shape[2:])
        target = torch.cat([target, pad], dim=1)

    # Resize the remaining dims; the interpolation mode depends on the rank.
    if target.shape[2:] != pred.shape[2:]:
        mode = {3: "linear", 4: "bilinear", 5: "trilinear"}.get(target.dim())
        if mode is None:
            raise ValueError(
                f"cannot resize a {target.dim()}-D target; expected 3-D to 5-D"
            )
        target = torch.nn.functional.interpolate(
            target, size=pred.shape[2:], mode=mode, align_corners=False
        )
    return target
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the LoRA-wrapped model,
    # load (or synthesize) latent targets, then run a gradient-accumulated
    # MSE training loop with periodic evaluation and checkpointing.
    import argparse

    # Imported once up front so a broken import fails before training starts.
    from training.finetuning.lora import save_lora

    parser = argparse.ArgumentParser(description="ARB video/diffusion fine-tuning")
    parser.add_argument("--video-dir", type=str, default=None,
                        help="Directory with .mp4/.avi files")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="diffusion-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--frames", type=int, default=8)
    parser.add_argument("--res", type=int, default=128)
    parser.add_argument("--max-samples", type=int, default=100)
    args = parser.parse_args()

    # These values are used as loop counts, divisors or tensor sizes below;
    # zero or negative values would crash or silently train on nothing.
    if min(args.steps, args.batch, args.accum, args.eval_interval) < 1:
        parser.error("--steps, --batch, --accum and --eval-interval must be >= 1")

    print("Building model with VideoHead + LoRA...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)

    # Collected once: reused by the optimizer and by gradient clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    if args.video_dir:
        data = load_video_data(args.video_dir, args.max_samples, args.frames, args.res)
    else:
        data = _generate_synthetic(args.frames, args.res, 100)

    if not data:
        raise SystemExit("No training data available (is --max-samples 0?)")

    # 80/20 positional split, keeping at least one sample on each side when
    # there are two or more; a single sample is used for both.
    n = int(0.8 * len(data))
    if len(data) > 1:
        n = min(max(1, n), len(data) - 1)
    train_data = data[:n] if n > 0 else data
    val_data = data[n:] if n < len(data) else data[:1]
    n_val = min(10, len(val_data))  # evaluation looks at no more than 10 samples

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best_val = float('inf')
    model.train()

    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0

            for _ in range(args.accum):
                # Conditioning input is random byte-valued tokens of length 10.
                text = torch.randint(0, 256, (args.batch, 10)).cuda()

                # One random training target per micro-batch, repeated across
                # the batch dimension when batch > 1.
                idx = torch.randint(0, len(train_data), (1,)).item()
                target_latents = train_data[idx].cuda()
                if target_latents.shape[0] == 1 and args.batch > 1:
                    target_latents = target_latents.expand(
                        args.batch, -1, -1, -1, -1).contiguous()

                embedded = model.embedding(text)
                seq_out = model.multimodal_sequencer({'text': embedded})
                rel = seq_out['text']

                pred_latents = model.video_head(rel)
                target_latents = _match_latents(target_latents, pred_latents)

                loss_val = torch.nn.functional.mse_loss(pred_latents, target_latents)
                # Scale so the accumulated gradient is the micro-batch mean.
                loss = loss_val / args.accum
                loss.backward()
                accum_loss += loss_val.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if step % args.eval_interval == 0:
                # Report the mean micro-batch loss so it is on the same scale
                # as the (mean) evaluation loss.
                train_loss = accum_loss / args.accum

                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    text_v = torch.randint(0, 256, (args.batch, 10)).cuda()
                    embedded_v = model.embedding(text_v)
                    seq_v = model.multimodal_sequencer({'text': embedded_v})
                    rel_v = seq_v['text']

                    for sample in val_data[:n_val]:
                        target = sample.cuda()
                        if target.shape[0] == 1 and args.batch > 1:
                            target = target.expand(
                                args.batch, -1, -1, -1, -1).contiguous()
                        pred = model.video_head(rel_v)
                        target = _match_latents(target, pred)
                        val_loss += torch.nn.functional.mse_loss(pred, target).item()
                val_loss /= n_val

                writer.add_scalar("loss/train", train_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)

                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps} train={train_loss:.6f} "
                      f"eval={val_loss:.6f} best={best_val:.6f}", flush=True)
                model.train()

        save_lora(lora_layers, f"{run_dir}/final_lora.pt")
        print(f"Done. LoRA saved to {run_dir}/", flush=True)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
|
|