# File size: 5,730 Bytes | d8bc908
"""Fine-tune ARB model on text/coding data using LoRA adapters.

Memory-efficient: base 1.5B weights stay frozen, only ~50MB adapters train.
Designed for 8GB VRAM with batch_size=1 and gradient accumulation.

Usage:
    python training/finetuning/text.py \\
        --data training/data/coding-Instructions.pt \\
        --steps 1000 --batch 1 --accum 4 --lr 1e-4 \\
        --lora-rank 16 --run my-finetune

Data format: .pt file with tokenized byte sequences (use data/tokenize_from_hf.py).
"""
import os, sys, time, math, json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Build a text-only ARB model on the GPU and attach LoRA adapters.

    The model is constructed with the image, audio, VQ, graph and memory
    modules disabled and MoE enabled, moved to CUDA, and then passed to
    ``apply_lora_to_model``. Freezing of the base weights is presumably done
    inside ``apply_lora_to_model`` (it is not done here) -- confirm there.

    Args:
        lora_rank: Forwarded to ``apply_lora_to_model`` as ``rank``.
        lora_alpha: Forwarded to ``apply_lora_to_model`` as ``alpha``.
        max_moe_iters: Forwarded to ``ARBModel`` as ``max_moe_iters``.

    Returns:
        A ``(model, lora_layers)`` tuple, where ``lora_layers`` is the value
        returned by ``apply_lora_to_model``.
    """
    # Project imports are kept local so importing this module stays cheap.
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=True,
        max_moe_iters=max_moe_iters,
    ).cuda()

    # Each target name is listed exactly once ('moe' used to appear twice).
    target_modules = [
        'moe',
        'byte_head',
        'head',
        'embedding',
        'router',
        'output_router',
        'shared',
        'projection',
    ]
    lora_layers = apply_lora_to_model(
        model,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=target_modules,
    )

    lora_p, total_p = count_lora_params(model)
    print(f" Base frozen: {total_p-lora_p:,} params", flush=True)
    print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers
def load_data(source, ctx=256):
    """Load a tokenized ``.pt`` dataset and split it 90/10 into train/val.

    Args:
        source: Path (or file-like object) handed to ``torch.load``.
        ctx: Length of the windows the caller will sample. Each split must
            hold at least ``ctx + 2`` elements so that
            ``torch.randint(0, len(split) - ctx - 1, ...)`` has a non-empty
            range to draw start offsets from.

    Returns:
        ``(train, val)``: the first 90% and the remaining 10% of the loaded
        data, each moved to the CUDA device.

    Raises:
        ValueError: If either split is too short to sample a ``ctx`` window.
    """
    data = torch.load(source, weights_only=True)
    total = len(data)
    n = int(0.9 * total)
    print(f" Data: {len(data):,} tokens, {n:,} train / {len(data)-n:,} val", flush=True)

    # Fail early (before any GPU transfer) with a readable message instead of
    # an opaque torch.randint range error in the middle of training.
    min_len = ctx + 2
    for split_name, split_len in (("train", n), ("val", total - n)):
        if split_len < min_len:
            raise ValueError(
                f"{split_name} split has {split_len} tokens but ctx={ctx} "
                f"needs at least {min_len}"
            )

    return data[:n].cuda(), data[n:].cuda()
if __name__ == "__main__":
    # Command-line entry point: build the LoRA-wrapped model, train it with
    # gradient accumulation, evaluate periodically, and save adapter weights.
    import argparse

    parser = argparse.ArgumentParser(description="ARB LoRA fine-tuning")
    parser.add_argument("--data", type=str, default="training/data/fineweb-sample.pt")
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--ctx", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="finetune")
    parser.add_argument("--eval-interval", type=int, default=100,
                        help="Evaluate every N optimizer steps (0 disables)")
    parser.add_argument("--save-every", type=int, default=500,
                        help="Write a periodic LoRA checkpoint every N steps (0 disables)")
    parser.add_argument("--resume", type=str, default=None, help="LoRA checkpoint to resume from")
    args = parser.parse_args()
    if args.accum < 1:
        parser.error("--accum must be >= 1")

    print("Building model with LoRA adapters...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)
    if args.resume:
        from training.finetuning.lora import load_lora
        load_lora(model, args.resume)
        print(f" Resumed from {args.resume}", flush=True)

    # Only parameters that still require gradients are optimized; the list is
    # built once and reused for gradient clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    print(f"Loading data: {args.data}", flush=True)
    train_data, val_data = load_data(args.data, args.ctx)

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)

    from training.finetuning.lora import save_lora

    writer = SummaryWriter(run_dir)
    try:
        step = 0
        best_val = float('inf')
        model.train()

        # Training loop: one optimizer step per `args.accum` micro-batches.
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0
            for _ in range(args.accum):
                ix = torch.randint(0, len(train_data) - args.ctx - 1, (args.batch,))
                x = torch.stack([train_data[i:i + args.ctx] for i in ix.tolist()])
                # NOTE(review): targets drop the first 3 positions of the input;
                # presumably the model aligns them internally -- confirm
                # against ARBModel.
                t = x[:, 3:]
                _, losses, _, _ = model(x, targets=t)
                loss = losses.total / args.accum
                loss.backward()
                # Accumulate the scaled loss so the logged value is the mean
                # over micro-batches and is comparable to the eval loss.
                accum_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if args.eval_interval > 0 and step % args.eval_interval == 0:
                model.eval()
                with torch.no_grad():
                    ix_v = torch.randint(0, len(val_data) - args.ctx - 1, (args.batch,))
                    xv = torch.stack([val_data[i:i + args.ctx] for i in ix_v.tolist()])
                    tv = xv[:, 3:]
                    _, lv, _, _ = model(xv, targets=tv)
                    val_loss = lv.total.item()

                current_lr = scheduler.get_last_lr()[0]
                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                writer.add_scalar("lr", current_lr, step)

                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps} "
                      f"train={accum_loss:.3f} eval={val_loss:.3f} "
                      f"best={best_val:.3f} lr={current_lr:.2e}",
                      flush=True)
                model.train()

            # Periodic checkpoint so --save-every actually has an effect.
            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/step_{step}_lora.pt")

        # Final save
        save_lora(lora_layers, f"{run_dir}/final_lora.pt")
        print(f"Done. LoRA adapters saved to {run_dir}/", flush=True)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()