# ARBS/training/finetuning/diffusion.py - uploaded via huggingface_hub (commit d8bc908).
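"""Fine-tune the ARB model on latent video diffusion using LoRA.

Builds ARBModel with the image/audio/VQ/graph/memory modules disabled and MoE
enabled, then attaches LoRA adapters to the VideoHead and core MoE layers.
Only parameters left trainable after LoRA injection are optimized. Training
targets are pig-vae latents of the input videos.
Designed for 8 GB of VRAM with batch_size=1.

Usage:
    python training/finetuning/diffusion.py \\
        --video-dir ./videos --steps 2000 --batch 1 \\
        --lora-rank 16 --run diffusion-finetune

Data format: a directory of .mp4 or .avi files (encoded to latents via
pig-vae). Without --video-dir, or if pig-vae cannot be loaded, random
synthetic latents are used instead (smoke testing only).
"""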
import os, sys, time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Construct a CUDA ARBModel and inject LoRA adapters into it.

    The model is built with image, audio, VQ, graph and memory modules
    switched off and the MoE core switched on. LoRA is attached to the layers
    named in ``lora_targets`` below. Freezing of the remaining weights is
    presumably done inside ``apply_lora_to_model`` (not visible here).

    Args:
        lora_rank: Rank passed to ``apply_lora_to_model``.
        lora_alpha: Alpha passed to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to ``ARBModel(max_moe_iters=...)``.

    Returns:
        ``(model, lora_layers)``, where ``lora_layers`` is whatever
        ``apply_lora_to_model`` returns (later handed to ``save_lora``).
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    feature_flags = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=True,
    )
    model = ARBModel(max_moe_iters=max_moe_iters, **feature_flags).cuda()

    # Module names that receive LoRA adapters (matching is up to the helper).
    lora_targets = [
        'W_gate', 'W_transform', 'byte_head', 'router',
        'shared_up', 'shared_expert_gate', 'shared_expert_up',
        'video_head', 'diffusion_step', 'cross_attn',
        'halt_unit', 'noise_embed',
    ]
    lora_layers = apply_lora_to_model(
        model,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=lora_targets,
    )

    trainable_count, _total_count = count_lora_params(model)
    print(f" LoRA trainable: {trainable_count:,} params ({trainable_count/1e6:.2f}M)", flush=True)
    return model, lora_layers
def load_video_data(video_dir, max_samples=100, frames=16, res=256):
    """Load videos from ``video_dir`` and encode them to pig-vae latents.

    Args:
        video_dir: Directory searched (non-recursively) for ``*.mp4`` and
            ``*.avi`` files.
        max_samples: Maximum number of files to read. Files are taken in
            sorted order (all ``.mp4`` before all ``.avi``) so the chosen
            subset is reproducible.
        frames: Number of leading frames kept per clip; shorter clips are
            skipped.
        res: Height/width of the top-left crop taken from each frame. This is
            a crop, not a resize, so frames smaller than ``res`` stay smaller.

    Returns:
        A list of latent tensors (moved to CPU), one per successfully encoded
        clip. Falls back to ``_generate_synthetic`` random latents when no
        files are found, pig-vae cannot be loaded, or no clip could be encoded.
    """
    import glob

    # glob order is filesystem-dependent; sort so files[:max_samples] is stable.
    files = (sorted(glob.glob(os.path.join(video_dir, "*.mp4")))
             + sorted(glob.glob(os.path.join(video_dir, "*.avi"))))
    if not files:
        print(f" No video files found in {video_dir}", flush=True)
        print(" Using synthetic random latents for smoke testing", flush=True)
        return _generate_synthetic(frames, res, max_samples)
    print(f" Found {len(files)} video files", flush=True)
    files = files[:max_samples]

    try:
        from arbitor.encoders.pig_vae import load_vae
        vae = load_vae(device='cuda', quantize='int8')
    except Exception as e:  # best-effort: any load failure -> synthetic data
        print(f" pig-vae not available: {e}", flush=True)
        print(" Using random latents (no video encoding)", flush=True)
        return _generate_synthetic(frames, res, min(max_samples, 50))
    print(" pig-vae loaded for encoding", flush=True)

    # Imported only once real decoding is needed.
    import torchvision.io

    data = []
    too_short = 0
    for path in files:
        try:
            # (T, H, W, C) uint8 with torchvision's default output_format.
            video, _, _ = torchvision.io.read_video(path, pts_unit='sec')
            # Crop while still uint8 so the full clip is never copied to float32.
            clip = video[:frames, :res, :res]
            if clip.shape[0] < frames:
                too_short += 1
                continue
            # -> (1, C, T, H, W) float in [0, 1] on the GPU; presumably the
            # layout pig-vae's encode() expects (TODO confirm).
            clip = clip.permute(3, 0, 1, 2).float().div(255.0).unsqueeze(0).cuda()
            with torch.no_grad():
                latents = vae.encode(clip).cpu()
        except Exception as e:  # skip unreadable/unencodable clips, but report them
            print(f" Skipping {path}: {e}", flush=True)
            continue
        data.append(latents)

    if too_short:
        print(f" Skipped {too_short} clip(s) shorter than {frames} frames", flush=True)
    if not data:
        print(" No videos could be encoded; using synthetic random latents", flush=True)
        return _generate_synthetic(frames, res, 50)
    print(f" Encoded {len(data)} videos to latent space", flush=True)
    return data
def _generate_synthetic(frames, res, count):
    """Return ``count`` random stand-in latent targets for smoke testing.

    Every item is a standard-normal CPU tensor of shape (1, 16, 1, 32, 32).
    ``frames`` and ``res`` are accepted for call-site parity with the real
    loader but are currently ignored; targets are reshaped to the model's
    output by ``_match_latents`` during training.
    """
    fake_targets = [torch.randn(1, 16, 1, 32, 32) for _ in range(count)]
    print(f" Generated {count} synthetic latent targets", flush=True)
    return fake_targets
def _match_latents(target, pred):
    """Reshape ``target`` latents so they line up with ``pred`` for an MSE loss.

    Both tensors are treated as 5-D (batch, channels, then three trailing
    dims), which the 5-argument ``expand`` and trilinear ``interpolate``
    require. Steps, in order:

    1. A batch-1 target is broadcast to ``pred``'s batch size.
    2. Extra channels are dropped; missing channels are zero-padded.
    3. If the trailing three dims still differ, they are trilinearly
       resampled to ``pred``'s size.

    The input tensor object is returned unchanged when nothing needs fixing.
    """
    want_batch = pred.shape[0]
    want_channels = pred.shape[1]
    want_trailing = pred.shape[2:]

    if want_batch > 1 and target.shape[0] == 1:
        target = target.expand(want_batch, -1, -1, -1, -1).contiguous()

    have_channels = target.shape[1]
    if have_channels > want_channels:
        target = target[:, :want_channels]
    elif have_channels < want_channels:
        filler = target.new_zeros(
            target.shape[0], want_channels - have_channels, *target.shape[2:]
        )
        target = torch.cat([target, filler], dim=1)

    if target.shape[2:] == want_trailing:
        return target
    return torch.nn.functional.interpolate(
        target, size=want_trailing, mode="trilinear", align_corners=False
    )
def _text_to_relational(model, text):
    """Return the sequencer's ``'text'`` stream for a batch of byte tokens.

    This is the conditioning path shared by training and evaluation:
    ``model.embedding`` -> ``model.multimodal_sequencer({'text': ...})['text']``.
    The result is what gets fed to ``model.video_head``.
    """
    embedded = model.embedding(text)
    seq_out = model.multimodal_sequencer({'text': embedded})
    return seq_out['text']


def _split_train_val(data, train_frac=0.8):
    """Split ``data`` into ``(train, val)`` lists.

    With more than one sample, the split point is clamped so each side gets
    at least one item. With a single sample, that sample is used for both.
    """
    n = int(train_frac * len(data))
    if len(data) > 1:
        n = min(max(1, n), len(data) - 1)
    train_data = data[:n] if n > 0 else data
    val_data = data[n:] if n < len(data) else data[:1]
    return train_data, val_data


def _evaluate(model, val_data, batch, max_items=10):
    """Return the mean latent MSE over the first ``max_items`` validation targets.

    One random byte-token context is drawn per call and reused for every
    target. The caller switches the model into eval mode and back.
    ``val_data`` must be non-empty.
    """
    n_items = min(max_items, len(val_data))
    total = 0.0
    with torch.no_grad():
        text_v = torch.randint(0, 256, (batch, 10)).cuda()
        rel_v = _text_to_relational(model, text_v)
        for target in val_data[:n_items]:
            # video_head is re-run per target rather than hoisted, in case the
            # head is stochastic (not verifiable from this file).
            pred = model.video_head(rel_v)
            # _match_latents broadcasts a batch-1 target to pred's batch size.
            target = _match_latents(target.cuda(), pred)
            total += torch.nn.functional.mse_loss(pred, target).item()
    return total / n_items


def main():
    """CLI entry point: parse args, build the LoRA model, train, save adapters.

    TensorBoard scalars ``loss/train`` (mean micro-batch loss of the latest
    optimizer step) and ``loss/eval`` are written to
    ``models/checkpoints/<run>``. ``best_lora.pt`` is saved whenever eval
    improves, and ``final_lora.pt`` is saved after the last step.
    """
    import argparse
    from training.finetuning.lora import save_lora

    parser = argparse.ArgumentParser(description="ARB video/diffusion fine-tuning")
    parser.add_argument("--video-dir", type=str, default=None,
                        help="Directory with .mp4/.avi files")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="diffusion-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--frames", type=int, default=8)
    parser.add_argument("--res", type=int, default=128)
    parser.add_argument("--max-samples", type=int, default=100)
    args = parser.parse_args()
    # Both are used as divisors / modulus below.
    if args.accum < 1:
        parser.error("--accum must be >= 1")
    if args.eval_interval < 1:
        parser.error("--eval-interval must be >= 1")

    print("Building model with VideoHead + LoRA...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)
    # The optimizer is built from this set, so gradient clipping uses the same
    # list instead of rebuilding it every step.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    if args.video_dir:
        data = load_video_data(args.video_dir, args.max_samples, args.frames, args.res)
    else:
        data = _generate_synthetic(args.frames, args.res, args.max_samples)
    if not data:
        parser.error("no training samples available (check --video-dir / --max-samples)")
    train_data, val_data = _split_train_val(data)

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best_val = float('inf')
    model.train()
    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0
            for _ in range(args.accum):
                # A random byte-token context conditions the VideoHead; it is
                # drawn independently of the target clip.
                text = torch.randint(0, 256, (args.batch, 10)).cuda()
                idx = torch.randint(0, len(train_data), (1,)).item()
                target_latents = train_data[idx].cuda()
                # embedding -> sequencer -> VideoHead -> predicted latents
                pred_latents = model.video_head(_text_to_relational(model, text))
                # _match_latents also broadcasts a batch-1 target to pred's batch.
                target_latents = _match_latents(target_latents, pred_latents)
                loss_val = torch.nn.functional.mse_loss(pred_latents, target_latents)
                (loss_val / args.accum).backward()
                accum_loss += loss_val.item()
            # Mean over micro-batches, so loss/train is on the same scale as loss/eval.
            accum_loss /= args.accum
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if step % args.eval_interval == 0:
                model.eval()
                val_loss = _evaluate(model, val_data, args.batch)
                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")
                print(f"step {step:>5d}/{args.steps} train={accum_loss:.6f} "
                      f"eval={val_loss:.6f} best={best_val:.6f}", flush=True)
                model.train()
    finally:
        # Flush and close the TensorBoard event file even if training aborts.
        writer.close()

    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA saved to {run_dir}/", flush=True)


if __name__ == "__main__":
    main()