# ARBS/training/finetuning/text.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""Fine-tune ARB model on text/coding data using LoRA adapters.

Memory-efficient: the base 1.5B weights stay frozen; only the ~50MB of
adapter weights are trained. Designed for 8GB of VRAM with batch_size=1
and gradient accumulation.

Usage:
    python training/finetuning/text.py \\
        --data training/data/coding-Instructions.pt \\
        --steps 1000 --batch 1 --accum 4 --lr 1e-4 \\
        --lora-rank 16 --run my-finetune

Data format: a .pt file of tokenized byte sequences (produce one with
data/tokenize_from_hf.py).
"""
import os, sys, time, math, json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Construct the ARB model on the GPU and attach LoRA adapters to it.

    Args:
        lora_rank: Forwarded to ``apply_lora_to_model`` as ``rank``.
        lora_alpha: Forwarded to ``apply_lora_to_model`` as ``alpha``.
        max_moe_iters: Forwarded to ``ARBModel`` as ``max_moe_iters``.

    Returns:
        ``(model, lora_layers)``: the CUDA-resident model and the value
        returned by ``apply_lora_to_model`` for it.
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    # Image, audio, VQ, graph and memory modules are switched off; only the
    # MoE path is enabled.
    model_flags = dict(
        enable_image=False,
        enable_audio=False,
        enable_vq=False,
        enable_graph=False,
        enable_memory_modules=False,
        enable_moe=True,
        max_moe_iters=max_moe_iters,
    )
    net = ARBModel(**model_flags).cuda()

    # NOTE(review): 'moe' is listed twice; kept exactly as in the original
    # because the matching logic of apply_lora_to_model is not visible here.
    adapter_targets = [
        'moe', 'byte_head', 'head', 'embedding', 'router',
        'output_router', 'moe', 'shared', 'projection',
    ]
    adapters = apply_lora_to_model(
        net,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=adapter_targets,
    )

    # Report how many parameters are adapters versus everything else.
    trainable, total = count_lora_params(net)
    frozen = total - trainable
    print(f" Base frozen: {frozen:,} params", flush=True)
    print(f" LoRA trainable: {trainable:,} params ({trainable/1e6:.2f}M)", flush=True)
    return net, adapters
def load_data(source, ctx=256, device="cuda"):
    """Load a tokenized ``.pt`` dataset and split it 90/10 into train/val.

    Args:
        source: Path (or file-like object) previously written with
            ``torch.save``; it is read with ``weights_only=True``.
        ctx: Context length the caller intends to sample. Used only for a
            sanity check: a warning is emitted when a split is too short
            to draw a window of this length from it.
        device: Device both splits are moved to. Defaults to ``"cuda"``,
            which matches the previous hard-coded behaviour.

    Returns:
        ``(train, val)``: the first 90% and the remaining 10% of the loaded
        sequence, both on ``device``.
    """
    import warnings

    # Stage on the host first so the only device transfer is the one below.
    data = torch.load(source, map_location="cpu", weights_only=True)
    n = int(0.9 * len(data))
    print(f" Data: {len(data):,} tokens, {n:,} train / {len(data)-n:,} val", flush=True)

    train, val = data[:n], data[n:]
    # The training script draws window starts from
    # randint(0, len(split) - ctx - 1), which needs len(split) > ctx + 1.
    for split_name, split in (("train", train), ("val", val)):
        if len(split) <= ctx + 1:
            warnings.warn(
                f"{split_name} split has {len(split)} tokens, too few to "
                f"sample windows of ctx={ctx}",
                stacklevel=2,
            )
    return train.to(device), val.to(device)
def _sample_batch(data, ctx, batch):
    """Stack ``batch`` random windows of ``ctx`` consecutive items of ``data``."""
    # Same sampling range as before; it requires len(data) > ctx + 1.
    starts = torch.randint(0, len(data) - ctx - 1, (batch,))
    return torch.stack([data[s:s + ctx] for s in starts.tolist()])


def main():
    """Command-line entry point: LoRA fine-tuning with gradient accumulation.

    Parses the CLI flags, builds the model and optimizer, then runs
    ``--steps`` optimizer steps, each accumulating ``--accum`` micro-batches.
    Every ``--eval-interval`` steps one random validation batch is scored,
    scalars are written to TensorBoard, and the adapters are saved as
    ``best_lora.pt`` when the validation loss improves. Every ``--save-every``
    steps a ``step{N}_lora.pt`` checkpoint is written, and ``final_lora.pt``
    is written once training completes.
    """
    import argparse

    from training.finetuning.lora import save_lora

    parser = argparse.ArgumentParser(description="ARB LoRA fine-tuning")
    parser.add_argument("--data", type=str, default="training/data/fineweb-sample.pt")
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--ctx", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--save-every", type=int, default=500)
    parser.add_argument("--resume", type=str, default=None, help="LoRA checkpoint to resume from")
    args = parser.parse_args()

    # Both values are used as divisors below; reject them up front instead of
    # failing with ZeroDivisionError in the middle of a run.
    if args.accum < 1:
        parser.error("--accum must be >= 1")
    if args.eval_interval < 1:
        parser.error("--eval-interval must be >= 1")

    print("Building model with LoRA adapters...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)
    if args.resume:
        from training.finetuning.lora import load_lora
        load_lora(model, args.resume)
        print(f" Resumed from {args.resume}", flush=True)

    # Only parameters that still require gradients are optimized and clipped.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    print(f"Loading data: {args.data}", flush=True)
    train_data, val_data = load_data(args.data, args.ctx)

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    best_val = float('inf')
    model.train()
    try:
        for step in range(1, args.steps + 1):
            opt.zero_grad()
            accum_loss = 0.0
            for _ in range(args.accum):
                x = _sample_batch(train_data, args.ctx, args.batch)
                # Targets are the inputs with the first three positions
                # dropped; the alignment is defined by the model's forward
                # pass, which is not visible in this file.
                _, losses, _, _ = model(x, targets=x[:, 3:])
                loss = losses.total / args.accum
                loss.backward()
                # Accumulate the scaled loss so the reported value is the
                # mean over micro-batches, comparable to the eval loss.
                accum_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()

            if step % args.eval_interval == 0:
                model.eval()
                with torch.no_grad():
                    # NOTE(review): a single random batch makes this a noisy
                    # estimate; "best" selection inherits that noise.
                    xv = _sample_batch(val_data, args.ctx, args.batch)
                    _, lv, _, _ = model(xv, targets=xv[:, 3:])
                    val_loss = lv.total.item()
                current_lr = scheduler.get_last_lr()[0]
                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                writer.add_scalar("lr", current_lr, step)
                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")
                print(f"step {step:>5d}/{args.steps} "
                      f"train={accum_loss:.3f} eval={val_loss:.3f} "
                      f"best={best_val:.3f} lr={current_lr:.2e}",
                      flush=True)
                model.train()

            # Periodic checkpoint; a non-positive --save-every disables it.
            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/step{step}_lora.pt")

        # Final save
        save_lora(lora_layers, f"{run_dir}/final_lora.pt")
        print(f"Done. LoRA adapters saved to {run_dir}/", flush=True)
    finally:
        # Flush pending TensorBoard events even if training was interrupted.
        writer.close()


if __name__ == "__main__":
    main()