"""Fine-tune ARB model on vision/image understanding tasks using LoRA. Freezes text/audio pipelines, adapts vision encoder + core MoE. Designed for 8GB VRAM with batch_size=1. Usage: python training/finetuning/vision.py \\ --data ./coco-captions \\ --steps 2000 --batch 1 --accum 4 --lr 1e-4 \\ --lora-rank 16 --run vision-finetune Data format: directory of .jpg images + captions.json captions.json: [{"image": "img001.jpg", "caption": "a cat sitting on..."}] """ import os, sys, time, json sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) import torch from torch.utils.tensorboard import SummaryWriter from PIL import Image def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1): """Build ARB model with vision + LoRA, freeze non-vision parts.""" from arbitor import ARBModel from training.finetuning.lora import apply_lora_to_model, count_lora_params model = ARBModel( enable_image=True, enable_audio=False, enable_vq=True, enable_graph=True, enable_memory_modules=False, enable_moe=True, max_moe_iters=max_moe_iters, ).cuda() model.eval() # Freeze everything, then enable gradients only for LoRA adapters target_modules = ['W_gate', 'W_transform', 'byte_head', 'head', 'router', 'shared_up', 'shared_expert_gate', 'shared_expert_up', 'patch_proj', 'image_sequencer', 'projection'] lora_layers = apply_lora_to_model(model, rank=lora_rank, alpha=lora_alpha, target_modules=target_modules) lora_p, total_p = count_lora_params(model) print(f" Base frozen: {total_p-lora_p:,} params", flush=True) print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True) return model, lora_layers def load_image_data(data_dir, max_samples=None): """Load image-caption pairs from directory. Expects {data_dir}/captions.json and images in {data_dir}/. Each caption is tokenized to byte sequence by the model's ByteEmbedding. """ cap_path = os.path.join(data_dir, "captions.json") with open(cap_path, "r") as f: entries = json.load(f) if max_samples: entries = entries[:max_samples] from torchvision import transforms from arbitor.config import SPECIAL_VOCAB transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) data = [] for entry in entries: img_path = os.path.join(data_dir, entry["image"]) caption = entry["caption"] # Load and transform image img = Image.open(img_path).convert("RGB") img_tensor = transform(img).unsqueeze(0) # Encode caption as byte tokens with BOS/EOS tokens = [SPECIAL_VOCAB['BOS']] for byte in caption.encode('utf-8'): tokens.append(byte) tokens.append(SPECIAL_VOCAB['EOS']) while len(tokens) < 4: tokens.append(SPECIAL_VOCAB['PAD']) text_tensor = torch.tensor(tokens, dtype=torch.long) data.append((img_tensor, text_tensor)) print(f" Loaded {len(data)} image-caption pairs from {data_dir}", flush=True) return data if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="ARB vision fine-tuning") parser.add_argument("--data", type=str, required=True, help="Image directory with captions.json") parser.add_argument("--steps", type=int, default=2000) parser.add_argument("--batch", type=int, default=1) parser.add_argument("--accum", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--lora-rank", type=int, default=16) parser.add_argument("--lora-alpha", type=float, default=32.0) parser.add_argument("--max-moe-iters", type=int, default=1) parser.add_argument("--run", type=str, default="vision-finetune") parser.add_argument("--eval-interval", type=int, default=100) parser.add_argument("--save-every", type=int, default=500) parser.add_argument("--max-samples", type=int, default=None) args = parser.parse_args() print("Building model with vision + LoRA adapters...", flush=True) model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters) opt = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=0.01 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps) print(f"Loading data from {args.data}...", flush=True) data = load_image_data(args.data, args.max_samples) n = int(0.8 * len(data)) if len(data) > 1: n = min(max(1, n), len(data) - 1) train_data = data[:n] if n > 0 else data val_data = data[n:] if n < len(data) else data[:1] run_dir = f"models/checkpoints/{args.run}" os.makedirs(run_dir, exist_ok=True) writer = SummaryWriter(run_dir) step = 0 best_val = float('inf') model.train() while step < args.steps: opt.zero_grad() accum_loss = 0.0 for micro in range(args.accum): idx = torch.randint(0, len(train_data), (args.batch,)).item() img_tensor, text_tokens = train_data[idx] img_tensor = img_tensor.cuda() text_tokens = text_tokens.cuda().unsqueeze(0) _, losses, _, _ = model(x=text_tokens, images=img_tensor, targets=text_tokens[:, 3:]) loss = losses.total / args.accum loss.backward() accum_loss += losses.total.item() torch.nn.utils.clip_grad_norm_( [p for p in model.parameters() if p.requires_grad], 1.0 ) opt.step() scheduler.step() step += 1 if step % args.eval_interval == 0: model.eval() val_loss = 0.0 with torch.no_grad(): for idx in range(min(10, len(val_data))): img, txt = val_data[idx] img, txt = img.cuda(), txt.cuda().unsqueeze(0) txt_ctx = txt[:, :max(4, min(txt.shape[1], 16))] _, lv, _, _ = model(x=txt_ctx, images=img, targets=txt_ctx[:, 3:]) val_loss += lv.total.item() val_loss /= min(10, len(val_data)) writer.add_scalar("loss/train", accum_loss, step) writer.add_scalar("loss/eval", val_loss, step) if val_loss < best_val: best_val = val_loss from training.finetuning.lora import save_lora save_lora(lora_layers, f"{run_dir}/best_lora.pt") print(f"step {step:>5d}/{args.steps} train={accum_loss:.3f} " f"eval={val_loss:.3f} best={best_val:.3f}", flush=True) model.train() from training.finetuning.lora import save_lora save_lora(lora_layers, f"{run_dir}/final_lora.pt") print(f"Done. LoRA saved to {run_dir}/", flush=True)