| """Fine-tune ARB model on text/coding data using LoRA adapters. |
| |
| Memory-efficient: base 1.5B weights stay frozen, only ~50MB adapters train. |
| Designed for 8GB VRAM with batch_size=1 and gradient accumulation. |
| |
| Usage: |
| python training/finetuning/text.py \\ |
| --data training/data/coding-Instructions.pt \\ |
| --steps 1000 --batch 1 --accum 4 --lr 1e-4 \\ |
| --lora-rank 16 --run my-finetune |
| |
| Data format: .pt file with tokenized byte sequences (use data/tokenize_from_hf.py). |
| """ |
| import os, sys, time, math, json |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
# Module-name patterns that receive LoRA adapters by default. They are handed to
# apply_lora_to_model unchanged; how they are matched against module names is
# decided there. (The original inline list named 'moe' twice; deduplicated here.)
DEFAULT_LORA_TARGETS = ('moe', 'byte_head', 'head', 'embedding', 'router',
                        'output_router', 'shared', 'projection')


def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1,
               target_modules=None, device="cuda"):
    """Build a text-only ARBModel with MoE enabled and attach LoRA adapters.

    Image, audio, VQ, graph and memory modules are disabled. The model is moved to
    ``device`` before adapters are applied. Freezing of the base weights is
    presumably done by ``apply_lora_to_model``; the printed "Base frozen" figure is
    simply total params minus LoRA params as reported by ``count_lora_params``.

    Args:
        lora_rank: adapter rank forwarded to ``apply_lora_to_model``.
        lora_alpha: adapter alpha forwarded to ``apply_lora_to_model``.
        max_moe_iters: forwarded to ``ARBModel``.
        target_modules: module-name patterns to adapt; ``None`` uses
            ``DEFAULT_LORA_TARGETS``.
        device: device for the model (default ``"cuda"``, same as ``.cuda()``).

    Returns:
        ``(model, lora_layers)``: the model and the value returned by
        ``apply_lora_to_model``.
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    model = ARBModel(
        enable_image=False, enable_audio=False,
        enable_vq=False, enable_graph=False,
        enable_memory_modules=False, enable_moe=True,
        max_moe_iters=max_moe_iters,
    ).to(device)

    if target_modules is None:
        target_modules = DEFAULT_LORA_TARGETS
    lora_layers = apply_lora_to_model(model, rank=lora_rank, alpha=lora_alpha,
                                      target_modules=list(target_modules))

    lora_p, total_p = count_lora_params(model)
    print(f" Base frozen: {total_p-lora_p:,} params", flush=True)
    print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers
|
|
|
|
def load_data(source, ctx=256, device="cuda", train_fraction=0.9):
    """Load a tokenized ``.pt`` dataset and split it into train/validation parts.

    The first ``train_fraction`` of the data becomes the training split and the
    remainder the validation split; both are moved to ``device``.

    Args:
        source: path to a ``.pt`` file, loaded with ``torch.load(weights_only=True)``.
            Presumably a 1-D tensor of token ids (the training loop slices
            ``data[i:i + ctx]``) -- TODO confirm against the tokenizer script.
        ctx: context length windows will be drawn with; used to verify that each
            split can yield at least one window.
        device: target device for both splits (default ``"cuda"``).
        train_fraction: fraction of the data assigned to the training split.

    Returns:
        ``(train, val)`` tensors on ``device``.

    Raises:
        ValueError: if either split has fewer than ``ctx + 2`` elements. The
            training loop draws starts with ``torch.randint(0, len(split) - ctx - 1, ...)``,
            which fails for shorter splits -- possibly only after hours of training
            when it is the validation split.
    """
    data = torch.load(source, weights_only=True)
    n = int(train_fraction * len(data))
    print(f" Data: {len(data):,} tokens, {n:,} train / {len(data)-n:,} val", flush=True)
    train, val = data[:n], data[n:]

    min_len = ctx + 2
    for name, split in (("train", train), ("val", val)):
        if len(split) < min_len:
            raise ValueError(
                f"{name} split has {len(split):,} tokens but ctx={ctx} needs at "
                f"least {min_len:,}; use a larger dataset or a smaller --ctx"
            )
    return train.to(device), val.to(device)
|
|
|
|
def sample_batch(data, ctx, batch):
    """Draw ``batch`` random windows of ``ctx`` consecutive elements from ``data``.

    Args:
        data: sequence sliceable as ``data[i:i + ctx]`` (presumably the 1-D token
            tensor returned by ``load_data`` -- TODO confirm).
        ctx: window length.
        batch: number of windows.

    Returns:
        ``(x, t)``: ``x`` stacks the windows along a new leading dimension and
        ``t = x[:, 3:]`` is what gets passed to the model as ``targets``.
    """
    # Exclusive upper bound keeps every window fully inside ``data``.
    starts = torch.randint(0, len(data) - ctx - 1, (batch,))
    x = torch.stack([data[i:i + ctx] for i in starts])
    # NOTE(review): targets are the inputs minus their first 3 positions, not a
    # one-token shift; presumably ARBModel aligns predictions internally -- confirm.
    t = x[:, 3:]
    return x, t


@torch.no_grad()
def evaluate(model, data, ctx, batch, iters=1):
    """Return the mean ``losses.total`` over ``iters`` random batches of ``data``.

    The model is put in eval mode for the measurement and always returned to
    train mode afterwards, even if the forward pass raises.
    """
    model.eval()
    try:
        total = 0.0
        for _ in range(iters):
            x, t = sample_batch(data, ctx, batch)
            _, losses, _, _ = model(x, targets=t)
            total += losses.total.item()
    finally:
        model.train()
    return total / iters


def main():
    """Parse CLI arguments and run LoRA fine-tuning with gradient accumulation."""
    import argparse
    from training.finetuning.lora import load_lora, save_lora

    parser = argparse.ArgumentParser(description="ARB LoRA fine-tuning")
    parser.add_argument("--data", type=str, default="training/data/fineweb-sample.pt")
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--ctx", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="finetune")
    parser.add_argument("--eval-interval", type=int, default=100,
                        help="Evaluate every N steps (<= 0 disables evaluation)")
    parser.add_argument("--eval-iters", type=int, default=1,
                        help="Validation batches averaged per evaluation")
    parser.add_argument("--save-every", type=int, default=500,
                        help="Save a step checkpoint every N steps (<= 0 disables)")
    parser.add_argument("--resume", type=str, default=None, help="LoRA checkpoint to resume from")
    args = parser.parse_args()
    if args.accum < 1:
        parser.error("--accum must be >= 1")
    if args.eval_iters < 1:
        parser.error("--eval-iters must be >= 1")

    print("Building model with LoRA adapters...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)

    if args.resume:
        load_lora(model, args.resume)
        print(f" Resumed from {args.resume}", flush=True)

    # Collect the trainable (adapter) parameters once; reused for AdamW and clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    print(f"Loading data: {args.data}", flush=True)
    train_data, val_data = load_data(args.data, args.ctx)

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    best_val = float('inf')
    model.train()

    try:
        for step in range(1, args.steps + 1):
            opt.zero_grad()
            train_loss = 0.0

            for _ in range(args.accum):
                x, t = sample_batch(train_data, args.ctx, args.batch)
                _, losses, _, _ = model(x, targets=t)
                # Scale each micro-batch so the accumulated gradient is their mean.
                (losses.total / args.accum).backward()
                # Mean (not sum) micro-batch loss, so it is comparable to eval loss.
                train_loss += losses.total.item() / args.accum

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            lr = scheduler.get_last_lr()[0]
            writer.add_scalar("loss/train", train_loss, step)
            writer.add_scalar("lr", lr, step)

            if args.eval_interval > 0 and step % args.eval_interval == 0:
                val_loss = evaluate(model, val_data, args.ctx, args.batch, args.eval_iters)
                writer.add_scalar("loss/eval", val_loss, step)

                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps} "
                      f"train={train_loss:.3f} eval={val_loss:.3f} "
                      f"best={best_val:.3f} lr={lr:.2e}",
                      flush=True)

            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/lora_step{step}.pt")
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()

    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA adapters saved to {run_dir}/", flush=True)


if __name__ == "__main__":
    main()
|
|