File size: 3,384 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | """Vision training.
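Freezes the text/audio pipelines and leaves only image-related parameters (image sequencer,
patch projection, modality gate, output router) selectable for training.
Intended for fine-tuning the DINOv2 projection layers on custom image data; the current loop
is a smoke test on synthetic random images and tokens and does not yet read --data.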
"""
import os, sys, torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from torch.utils.tensorboard import SummaryWriter
from arbitor import ARBModel
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
def freeze_non_vision(model: torch.nn.Module, *,
                      keywords: tuple = ('image_sequencer', 'image', 'patch_proj',
                                         'modality_gate', 'output_router')) -> None:
    """Freeze every parameter except the image-related ones.

    A parameter stays trainable (``requires_grad=True``) exactly when its
    qualified name from ``model.named_parameters()`` contains one of
    ``keywords`` as a substring. Every other parameter is frozen.

    With the default keywords, the bare ``'image'`` key already matches any
    name containing ``'image_sequencer'``. So the selection is broader than
    the image sequencer alone: it also covers patch projection, the modality
    gate and the output router.

    Args:
        model: module whose parameters are (un)frozen in place.
        keywords: name substrings that mark a parameter as vision-side.

    Returns:
        None; ``model`` is modified in place.
    """
    for name, param in model.named_parameters():
        # A single pass gives the same result as "freeze all, then unfreeze matches".
        param.requires_grad = any(key in name for key in keywords)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="ARB vision/image training")
parser.add_argument("--steps", type=int, default=3000)
parser.add_argument("--batch", type=int, default=4)
parser.add_argument("--image-size", type=int, default=224)
parser.add_argument("--run", type=str, default="vision")
parser.add_argument("--data", type=str, default=None, help="Image dataset path")
parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
args = parser.parse_args()
os.environ["ARB_TERNARY_BACKEND"] = args.backend
if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")
model = ARBModel(enable_image=True, enable_audio=False,
enable_vq=True, enable_graph=False,
enable_memory_modules=False, enable_moe=False,
max_moe_iters=4,
enable_attention=False,
enable_output_router=False,
enable_video_output=False,
enable_talker_output=False).cuda()
freeze_non_vision(model)
freeze_float_parameters(model)
print(format_audit(audit_model(model)))
if trainable_parameters(model):
raise RuntimeError("Vision trainer is pure ternary; use training/finetuning/vision.py for LoRA adapters.")
run_dir = f"models/checkpoints/{args.run}"
os.makedirs(run_dir, exist_ok=True)
writer = SummaryWriter(run_dir)
# Generate synthetic image data for smoke testing
for step in range(args.steps):
images = torch.randn(args.batch, 3, args.image_size, args.image_size).cuda()
text = torch.randint(0, 256, (args.batch, 16), device=images.device)
targets = text[:, 3:]
model.zero_grad(set_to_none=True)
_, losses, _, _ = model(x=text, images=images, targets=targets)
loss = losses.total
model.prepare_ternary_backward(loss.detach(), update_scales=True)
loss.backward()
model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
model.zero_grad(set_to_none=True)
if step % 200 == 0:
writer.add_scalar("loss/vision", loss.item(), step)
print(f"step {step:>5d} loss={loss.item():.3f}")
|