"""Fine-tune ARB model on vision/image understanding tasks using LoRA.

Freezes text/audio pipelines, adapts vision encoder + core MoE.
Designed for 8GB VRAM with batch_size=1.

Usage:
    python training/finetuning/vision.py \\
        --data ./coco-captions \\
        --steps 2000 --batch 1 --accum 4 --lr 1e-4 \\
        --lora-rank 16 --run vision-finetune

Data format: directory of .jpg images + captions.json
    captions.json: [{"image": "img001.jpg", "caption": "a cat sitting on..."}]
"""
import os, sys, time, json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Construct the ARB model on the GPU and attach LoRA adapters to it.

    The model is built with the image, VQ, graph and MoE features enabled and
    the audio and memory-module features disabled, moved to CUDA and put in
    eval mode. LoRA adapters are then attached to the sub-modules named in
    ``target_modules`` and the frozen/trainable parameter counts are printed.

    Args:
        lora_rank: Rank forwarded to ``apply_lora_to_model``.
        lora_alpha: Alpha (scaling) value forwarded to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to the ``ARBModel`` constructor.

    Returns:
        ``(model, lora_layers)`` where ``lora_layers`` is whatever
        ``apply_lora_to_model`` returns.
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    model_kwargs = dict(
        enable_image=True,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
        enable_memory_modules=False,
        enable_moe=True,
        max_moe_iters=max_moe_iters,
    )
    model = ARBModel(**model_kwargs).cuda()
    model.eval()

    # Names of the sub-modules that receive an adapter.
    # NOTE(review): nothing here freezes the base weights explicitly;
    # presumably apply_lora_to_model does so -- confirm in lora.py.
    target_modules = [
        'W_gate',
        'W_transform',
        'byte_head',
        'head',
        'router',
        'shared_up',
        'shared_expert_gate',
        'shared_expert_up',
        'patch_proj',
        'image_sequencer',
        'projection',
    ]
    lora_layers = apply_lora_to_model(
        model,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=target_modules,
    )

    lora_p, total_p = count_lora_params(model)
    print(f" Base frozen: {total_p-lora_p:,} params", flush=True)
    print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers
def load_image_data(data_dir, max_samples=None):
    """Load image-caption pairs from a directory into memory.

    Expects ``{data_dir}/captions.json`` (a JSON list of objects with
    ``"image"`` and ``"caption"`` keys) and the referenced image files inside
    ``data_dir``. Every image is decoded eagerly, so memory use grows with the
    number of entries.

    Each image is converted to RGB, resized to 224x224, turned into a tensor,
    normalized with the mean/std constants below and given a leading batch
    dimension. Each caption is encoded here as its UTF-8 byte values wrapped
    in the ``BOS``/``EOS`` ids from ``SPECIAL_VOCAB`` and right-padded with
    ``PAD`` to at least four tokens.

    Args:
        data_dir: Directory holding ``captions.json`` and the images.
        max_samples: If truthy, keep only the first ``max_samples`` entries.

    Returns:
        A list of ``(img_tensor, text_tensor)`` tuples; ``text_tensor`` is a
        1-D ``torch.long`` tensor.

    Raises:
        FileNotFoundError: If ``captions.json`` or an image file is missing.
        KeyError: If an entry lacks the ``"image"`` or ``"caption"`` key.
    """
    cap_path = os.path.join(data_dir, "captions.json")
    # Pin the encoding: captions may contain non-ASCII text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(cap_path, "r", encoding="utf-8") as f:
        entries = json.load(f)
    if max_samples:
        entries = entries[:max_samples]

    from torchvision import transforms
    from arbitor.config import SPECIAL_VOCAB

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The training loop in this file slices targets from index 3, so a
    # sequence needs at least four tokens to yield a non-empty target.
    min_tokens = 4

    data = []
    for entry in entries:
        img_path = os.path.join(data_dir, entry["image"])
        caption = entry["caption"]

        # The context manager releases the file handle immediately instead of
        # leaving one open descriptor per image until garbage collection.
        with Image.open(img_path) as img:
            img_tensor = transform(img.convert("RGB")).unsqueeze(0)

        # BOS + raw UTF-8 byte values + EOS, then PAD up to the minimum length.
        tokens = [SPECIAL_VOCAB['BOS']]
        tokens.extend(caption.encode('utf-8'))
        tokens.append(SPECIAL_VOCAB['EOS'])
        if len(tokens) < min_tokens:
            tokens.extend([SPECIAL_VOCAB['PAD']] * (min_tokens - len(tokens)))

        text_tensor = torch.tensor(tokens, dtype=torch.long)
        data.append((img_tensor, text_tensor))

    print(f" Loaded {len(data)} image-caption pairs from {data_dir}", flush=True)
    return data
def _evaluate(model, val_data, max_items=10, max_ctx=16):
    """Return the mean ``total`` loss over the first ``max_items`` validation pairs.

    Captions are truncated to at most ``max_ctx`` tokens (never fewer than
    four) before the forward pass. The caller is responsible for switching
    the model between eval and train mode.
    """
    n_eval = min(max_items, len(val_data))
    total = 0.0
    with torch.no_grad():
        for img, txt in val_data[:n_eval]:
            img, txt = img.cuda(), txt.cuda().unsqueeze(0)
            txt_ctx = txt[:, :max(4, min(txt.shape[1], max_ctx))]
            _, lv, _, _ = model(x=txt_ctx, images=img, targets=txt_ctx[:, 3:])
            total += lv.total.item()
    return total / n_eval


def main():
    """Parse CLI arguments and run LoRA fine-tuning on image-caption pairs.

    Steps: build the model with adapters, load the data, split it roughly
    80/20 into train/validation, then run ``--steps`` optimizer steps. Each
    optimizer step accumulates gradients over ``--accum`` micro-steps of
    ``--batch`` randomly drawn samples. Samples are run one at a time because
    captions have different lengths and are not padded to a common size.

    Every ``--eval-interval`` steps the validation loss is computed, both
    losses are written to TensorBoard, and ``best_lora.pt`` is saved when the
    validation loss improves. Every ``--save-every`` steps a
    ``step_{step}_lora.pt`` checkpoint is written (disabled when the value is
    not positive). ``final_lora.pt`` is written once training finishes.
    """
    import argparse
    from training.finetuning.lora import save_lora

    parser = argparse.ArgumentParser(description="ARB vision fine-tuning")
    parser.add_argument("--data", type=str, required=True, help="Image directory with captions.json")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="vision-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--save-every", type=int, default=500)
    parser.add_argument("--max-samples", type=int, default=None)
    args = parser.parse_args()

    # Fail early with a usage message instead of a ZeroDivisionError or an
    # optimizer step without gradients deep inside the loop.
    if args.batch < 1 or args.accum < 1:
        parser.error("--batch and --accum must be >= 1")
    if args.eval_interval < 1:
        parser.error("--eval-interval must be >= 1")

    print("Building model with vision + LoRA adapters...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)

    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    print(f"Loading data from {args.data}...", flush=True)
    data = load_image_data(args.data, args.max_samples)
    if not data:
        # torch.randint(0, 0, ...) would otherwise raise an opaque error.
        parser.error(f"no image-caption pairs found in {args.data}")

    if len(data) > 1:
        # ~80% train, but always keep at least one sample on each side.
        n = min(max(1, int(0.8 * len(data))), len(data) - 1)
        train_data, val_data = data[:n], data[n:]
    else:
        # A single sample serves as both the train and the validation set.
        train_data = val_data = data

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    samples_per_step = args.accum * args.batch
    step = 0
    best_val = float('inf')
    model.train()
    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0
            for _ in range(args.accum):
                # .tolist() works for any batch size; the previous .item()
                # raised as soon as --batch was greater than 1.
                batch_indices = torch.randint(0, len(train_data), (args.batch,)).tolist()
                for idx in batch_indices:
                    img_tensor, text_tokens = train_data[idx]
                    img_tensor = img_tensor.cuda()
                    text_tokens = text_tokens.cuda().unsqueeze(0)
                    # NOTE(review): targets drop the first three tokens;
                    # presumably this matches the model's internal
                    # input/target alignment -- confirm against ARBModel.
                    _, losses, _, _ = model(x=text_tokens, images=img_tensor,
                                            targets=text_tokens[:, 3:])
                    # Scale so the accumulated gradient is the mean over all
                    # samples of this optimizer step.
                    (losses.total / samples_per_step).backward()
                    accum_loss += losses.total.item()
            # Per-sample mean, so it is comparable with the eval loss.
            train_loss = accum_loss / samples_per_step

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if step % args.eval_interval == 0:
                model.eval()
                val_loss = _evaluate(model, val_data)
                writer.add_scalar("loss/train", train_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")
                print(f"step {step:>5d}/{args.steps} train={train_loss:.3f} "
                      f"eval={val_loss:.3f} best={best_val:.3f}", flush=True)
                model.train()

            # Honour --save-every, which was parsed but previously ignored.
            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/step_{step}_lora.pt")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA saved to {run_dir}/", flush=True)


if __name__ == "__main__":
    main()
|