File size: 5,730 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Fine-tune ARB model on text/coding data using LoRA adapters.

Memory-efficient: base 1.5B weights stay frozen, only ~50MB adapters train.
Designed for 8GB VRAM with batch_size=1 and gradient accumulation.

Usage:
    python training/finetuning/text.py \\
        --data training/data/coding-Instructions.pt \\
        --steps 1000 --batch 1 --accum 4 --lr 1e-4 \\
        --lora-rank 16 --run my-finetune

Data format: .pt file with tokenized byte sequences (use data/tokenize_from_hf.py).
"""
import os, sys, time, math, json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter


def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Create a text-only ARBModel on the GPU and attach LoRA adapters to it.

    Image, audio, VQ, graph and memory modules are switched off; only the
    MoE path is enabled.

    Args:
        lora_rank: Rank forwarded to ``apply_lora_to_model``.
        lora_alpha: Alpha (scaling) forwarded to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to the ``ARBModel`` constructor.

    Returns:
        ``(model, lora_layers)`` -- the CUDA model and whatever
        ``apply_lora_to_model`` returns for the injected adapters.
    """
    # Project imports stay function-local, as in the rest of this script.
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    model_kwargs = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=True,
        max_moe_iters=max_moe_iters,
    )
    net = ARBModel(**model_kwargs).cuda()

    # Module names that should receive adapters.
    # NOTE(review): 'moe' is listed twice; kept as-is because the effect of a
    # duplicate depends on apply_lora_to_model, which is not visible here.
    adapted_names = [
        'moe', 'byte_head', 'head', 'embedding', 'router',
        'output_router', 'moe', 'shared', 'projection',
    ]
    adapters = apply_lora_to_model(
        net,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=adapted_names,
    )

    # Report the split between base and adapter parameters.
    # Presumably apply_lora_to_model freezes the base weights -- TODO confirm.
    n_lora, n_total = count_lora_params(net)
    n_frozen = n_total - n_lora
    print(f"  Base frozen: {n_frozen:,} params", flush=True)
    print(f"  LoRA trainable: {n_lora:,} params ({n_lora/1e6:.2f}M)", flush=True)
    return net, adapters


def load_data(source, ctx=256, device="cuda"):
    """Load a tokenized ``.pt`` dataset and split it 90/10 into train/val.

    Args:
        source: Path or file-like object accepted by ``torch.load``.
        ctx: Context length of the caller. Currently unused by the split;
            kept so existing ``load_data(path, ctx)`` calls keep working.
        device: Device both splits are moved to. Defaults to ``"cuda"``,
            which matches the previous hard-coded ``.cuda()`` behaviour;
            pass ``"cpu"`` to run without a GPU.

    Returns:
        ``(train, val)`` -- the first 90% and the remaining 10% of the
        loaded data along its first dimension, both on ``device``.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    data = torch.load(source, weights_only=True)
    total = len(data)
    n_train = int(0.9 * total)
    print(f"  Data: {total:,} tokens, {n_train:,} train / {total - n_train:,} val",
          flush=True)
    return data[:n_train].to(device), data[n_train:].to(device)


def _parse_args():
    """Parse the command-line options for a LoRA fine-tuning run."""
    import argparse
    parser = argparse.ArgumentParser(description="ARB LoRA fine-tuning")
    parser.add_argument("--data", type=str, default="training/data/fineweb-sample.pt")
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--ctx", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--save-every", type=int, default=500)
    parser.add_argument("--resume", type=str, default=None, help="LoRA checkpoint to resume from")
    return parser.parse_args()


def _sample_batch(data, ctx, batch):
    """Draw ``batch`` random windows of ``ctx`` tokens from ``data``.

    Returns:
        ``(x, targets)`` where ``targets`` is ``x`` with its first three
        positions along dim 1 dropped.

    Raises:
        ValueError: If ``data`` is too short to hold a single window.
    """
    high = len(data) - ctx - 1
    if high <= 0:
        # torch.randint would fail here with an opaque range error.
        raise ValueError(
            f"dataset split has {len(data):,} tokens but --ctx {ctx} "
            f"needs at least {ctx + 2:,}"
        )
    starts = torch.randint(0, high, (batch,))
    x = torch.stack([data[i:i + ctx] for i in starts])
    # NOTE(review): the 3-position offset presumably mirrors how ARBModel
    # aligns its predictions with targets -- confirm against the model.
    return x, x[:, 3:]


def main():
    """Run LoRA fine-tuning: build the model, train, evaluate, checkpoint.

    Writes TensorBoard scalars and adapter checkpoints under
    ``models/checkpoints/<run>/``: ``best_lora.pt`` whenever the eval loss
    improves, ``lora_step<N>.pt`` every ``--save-every`` steps and
    ``final_lora.pt`` at the end.
    """
    args = _parse_args()
    from training.finetuning.lora import save_lora

    print("Building model with LoRA adapters...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)

    if args.resume:
        # Only adapter weights are restored; the step counter, optimizer and
        # LR schedule below always start fresh.
        from training.finetuning.lora import load_lora
        load_lora(model, args.resume)
        print(f"  Resumed from {args.resume}", flush=True)

    # Only parameters that still require grad are optimized and clipped.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    print(f"Loading data: {args.data}", flush=True)
    train_data, val_data = load_data(args.data, args.ctx)

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best_val = float('inf')
    model.train()

    try:
        # Training loop: one optimizer step per `accum` micro-batches.
        while step < args.steps:
            opt.zero_grad()
            train_loss = 0.0

            for _ in range(args.accum):
                x, t = _sample_batch(train_data, args.ctx, args.batch)
                _, losses, _, _ = model(x, targets=t)
                # Scale so the accumulated gradient is the micro-batch mean.
                loss = losses.total / args.accum
                loss.backward()
                # Accumulate the scaled loss so the logged value is a mean,
                # directly comparable with the single-batch eval loss.
                train_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            # An interval of 0 disables the feature instead of dividing by 0.
            if args.eval_interval and step % args.eval_interval == 0:
                model.eval()
                with torch.no_grad():
                    # Evaluation uses one random validation batch.
                    xv, tv = _sample_batch(val_data, args.ctx, args.batch)
                    _, lv, _, _ = model(xv, targets=tv)
                    val_loss = lv.total.item()

                current_lr = scheduler.get_last_lr()[0]
                writer.add_scalar("loss/train", train_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                writer.add_scalar("lr", current_lr, step)

                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps}  "
                      f"train={train_loss:.3f}  eval={val_loss:.3f}  "
                      f"best={best_val:.3f}  lr={current_lr:.2e}",
                      flush=True)
                model.train()

            # Periodic checkpoint; --save-every was previously parsed but ignored.
            if args.save_every and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/lora_step{step}.pt")
    finally:
        # Flush and release the TensorBoard event file even on failure.
        writer.close()

    # Final save
    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA adapters saved to {run_dir}/", flush=True)


if __name__ == "__main__":
    main()