File size: 3,384 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Vision training.
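Freezes the text/audio pipelines and trains the image-sequencer features.
Intended for fine-tuning the DINOv2 projection layers on custom image data;
note that the training loop currently feeds synthetic random images (smoke
test) and does not read ``--data``.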
"""
import os, sys, torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from torch.utils.tensorboard import SummaryWriter
from arbitor import ARBModel
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters


def freeze_non_vision(model, keys=('image', 'patch_proj', 'modality_gate', 'output_router')):
    """Freeze every parameter except those whose name matches a vision key.

    A parameter ends up trainable (``requires_grad=True``) iff its qualified
    name from ``model.named_parameters()`` contains at least one entry of
    ``keys`` as a substring and its dtype can carry gradients (floating point
    or complex). Every other parameter is frozen.

    The default keys cover the image sequencer / patch projection as well as
    the modality gate and output router. ``'image'`` already matches
    ``'image_sequencer'``, so that key is not listed separately.

    Args:
        model: module whose parameters are toggled in place.
        keys: iterable of name substrings that mark a parameter as trainable.
    """
    # Materialize once so a generator argument is not exhausted after the first parameter.
    keys = tuple(keys)
    for name, param in model.named_parameters():
        # Only float/complex tensors may require grad. Without this check, a matching
        # integer-typed parameter would make the assignment raise RuntimeError.
        can_train = param.is_floating_point() or param.is_complex()
        param.requires_grad = can_train and any(key in name for key in keys)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="ARB vision/image training")
    parser.add_argument("--steps", type=int, default=3000)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--run", type=str, default="vision")
    parser.add_argument("--data", type=str, default=None,
                        help="Image dataset path (currently ignored: synthetic images are used)")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument("--log-every", type=int, default=200,
                        help="Log the loss every N steps; the final step is always logged")
    args = parser.parse_args()
    if args.log_every < 1:
        parser.error("--log-every must be >= 1")
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")
    if args.data is not None:
        # The loop below only generates synthetic data; say so rather than silently dropping --data.
        print(f"warning: --data {args.data!r} is ignored; training on synthetic random images",
              file=sys.stderr)

    model = ARBModel(enable_image=True, enable_audio=False,
                     enable_vq=True, enable_graph=False,
                     enable_memory_modules=False, enable_moe=False,
                     max_moe_iters=4,
                     enable_attention=False,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=False).cuda()
    freeze_non_vision(model)
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))

    # After freezing, no parameter may remain trainable: this script only drives the
    # model's own ternary update path (prepare_ternary_backward / _ternary_update_memory).
    if trainable_parameters(model):
        raise RuntimeError("Vision trainer is pure ternary; use training/finetuning/vision.py for LoRA adapters.")
    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    device = torch.device("cuda")
    try:
        # Synthetic random images/tokens for smoke testing (not a real dataset).
        for step in range(args.steps):
            # Allocate directly on the GPU instead of creating on CPU and copying.
            images = torch.randn(args.batch, 3, args.image_size, args.image_size, device=device)
            text = torch.randint(0, 256, (args.batch, 16), device=device)
            # NOTE(review): targets drop the first 3 tokens; presumably this matches an
            # offset applied inside the model's forward — confirm against ARBModel.
            targets = text[:, 3:]
            model.zero_grad(set_to_none=True)
            _, losses, _, _ = model(x=text, images=images, targets=targets)

            loss = losses.total
            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
            model.zero_grad(set_to_none=True)

            if step % args.log_every == 0 or step == args.steps - 1:
                loss_value = loss.item()  # single host sync per logged step
                writer.add_scalar("loss/vision", loss_value, step)
                print(f"step {step:>5d}  loss={loss_value:.3f}")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()