# LUNA-Training / lora_sft_train.py
# Uploaded by ASTERIZER with huggingface_hub (commit cd7ee10, verified).
import argparse
import gc
import math
import os
import time
from pathlib import Path
import torch
import torch.nn as nn
import yaml
from huggingface_hub import hf_hub_download
from torch.amp import GradScaler, autocast
from sft_train import LUNAModel, SFTDataset, cosine_lr, probe_hardware, run_eval_prompts
# Horizontal rule (72 '=' characters) that frames the console banner in train().
SEP = 72 * "="
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    Output is ``base(x) + lora_b(lora_a(dropout(x))) * (alpha / rank)``.
    ``lora_b`` is zero-initialised, so a freshly built wrapper reproduces the
    base layer exactly. Every parameter of the wrapped base layer is frozen.

    Args:
        base_layer: the ``nn.Linear`` to wrap.
        rank: inner dimension of the low-rank update.
        alpha: numerator of the update scale ``alpha / rank`` (``rank`` is
            clamped to at least 1 for this division only).
        dropout: dropout probability applied to the LoRA branch input.

    Raises:
        TypeError: if ``base_layer`` is not an ``nn.Linear``.
    """

    def __init__(self, base_layer, rank=16, alpha=32, dropout=0.05):
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError("LoRALinear expects a torch.nn.Linear base layer")
        # Registration order (base, dropout, lora_a, lora_b) fixes state_dict
        # and parameter ordering; keep it stable.
        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        self.scale = alpha / max(rank, 1)
        self.dropout = nn.Dropout(dropout)
        in_dim, out_dim = base_layer.in_features, base_layer.out_features
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        # Zero B => the LoRA branch contributes nothing until trained.
        nn.init.zeros_(self.lora_b.weight)
        self.base.requires_grad_(False)

    def forward(self, x):
        """Return the frozen base projection plus the scaled low-rank update."""
        update = self.lora_a(self.dropout(x))
        update = self.lora_b(update) * self.scale
        return self.base(x) + update
def load_config(config_path):
    """Read a YAML training config and flatten it into one dict.

    Optional top-level keys (``auto_config``, HF repo/file, data/output/
    tokenizer paths, ``eval_prompts``) fall back to built-in defaults. The
    ``model``, ``train``, ``optimizer``, ``batch``, ``dataloader``,
    ``hardware`` and ``lora`` sections are required; within them only
    ``batch.auto_probe_batch`` and ``batch.probe_safety`` have defaults.

    Args:
        config_path: path to a UTF-8 encoded YAML file.

    Returns:
        dict: flat configuration consumed by ``train``.

    Raises:
        KeyError: when a required section or key is missing.
    """
    with open(config_path, encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)

    cfg = {}
    top_level_defaults = (
        ("auto_config", True),
        ("hf_model_repo", "ASTERIZER/LUNA-100M"),
        ("hf_model_file", "sft_v1/final/model.pth"),
        ("pretrained_ckpt", "Base/out/input_models/luna_sft_v1/model.pth"),
        ("train_json", "Base/Datasets/rag_mcp_sft/train.json"),
        ("val_json", "Base/Datasets/rag_mcp_sft/val.json"),
        ("out_dir", "Base/out/sft/rag_mcp_lora"),
        ("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
    )
    for key, default in top_level_defaults:
        cfg[key] = raw.get(key, default)

    model_section = raw["model"]
    for key in ("vocab_size", "seq_len", "n_layer", "n_embd", "n_head"):
        cfg[key] = model_section[key]

    train_section = raw["train"]
    for key in ("epochs", "lr_warmup_steps", "save_interval", "log_interval", "eval_interval", "max_norm"):
        cfg[key] = train_section[key]

    optim_section = raw["optimizer"]
    cfg["lr"] = optim_section["lr"]
    cfg["min_lr"] = optim_section["min_lr"]
    cfg["weight_decay"] = optim_section["weight_decay"]
    cfg["betas"] = tuple(optim_section["betas"])
    cfg["eps"] = optim_section["eps"]

    batch_section = raw["batch"]
    for key in ("global_batch", "micro_batch", "grad_accum"):
        cfg[key] = batch_section[key]
    cfg["auto_probe_batch"] = batch_section.get("auto_probe_batch", True)
    cfg["probe_safety"] = batch_section.get("probe_safety", 0.94)

    loader_section = raw["dataloader"]
    cfg["num_workers"] = loader_section["num_workers"]
    cfg["pin_memory"] = loader_section["pin_memory"]

    cfg["precision"] = raw["hardware"]["precision"]
    cfg["eval_prompts"] = raw.get("eval_prompts", [])

    lora_section = raw["lora"]
    cfg["lora_rank"] = lora_section["rank"]
    cfg["lora_alpha"] = lora_section["alpha"]
    cfg["lora_dropout"] = lora_section["dropout"]
    cfg["target_modules"] = list(lora_section["target_modules"])
    return cfg
def resolve_checkpoint(cfg):
    """Return a local path to the base checkpoint, downloading it if needed.

    Resolution order:
      1. ``cfg["pretrained_ckpt"]`` if that file exists.
      2. ``<parent of pretrained_ckpt>/<hf_model_file>`` if an earlier run
         already downloaded it there (avoids contacting the Hub again).
      3. Download ``cfg["hf_model_file"]`` from ``cfg["hf_model_repo"]`` into
         the parent directory of ``pretrained_ckpt``, authenticating with the
         ``HF_TOKEN`` environment variable when it is set.

    Args:
        cfg: flat config dict with ``pretrained_ckpt``, ``hf_model_repo`` and
            ``hf_model_file``.

    Returns:
        pathlib.Path: path of an existing checkpoint file.

    Raises:
        FileNotFoundError: if the path reported by the download does not exist.
    """
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if ckpt_path.exists():
        return ckpt_path

    local_dir = ckpt_path.parent
    # hf_hub_download keeps the repo-relative filename under local_dir, so the
    # file lands in a subdirectory rather than at ckpt_path itself.
    previously_downloaded = local_dir / cfg["hf_model_file"]
    if previously_downloaded.exists():
        return previously_downloaded

    local_dir.mkdir(parents=True, exist_ok=True)
    # Trust the path the library reports instead of reconstructing it.
    downloaded = Path(
        hf_hub_download(
            repo_id=cfg["hf_model_repo"],
            filename=cfg["hf_model_file"],
            local_dir=str(local_dir),
            token=os.environ.get("HF_TOKEN"),
        )
    )
    if not downloaded.exists():
        raise FileNotFoundError(f"Expected downloaded checkpoint at {downloaded}")
    return downloaded
def inject_lora(model, target_modules, rank, alpha, dropout):
    """Replace matching ``nn.Linear`` submodules of ``model`` with ``LoRALinear``.

    A linear layer is wrapped when its dotted module name ends with any string
    in ``target_modules``. This is a plain string-suffix test, not a match on
    whole name components. Each wrapper is moved to the wrapped layer's device
    and dtype, then installed on the parent module in place.

    Args:
        model: module tree to modify in place.
        target_modules: iterable of name suffixes to match.
        rank, alpha, dropout: forwarded to ``LoRALinear``.

    Returns:
        list[str]: dotted names of the replaced modules, in traversal order.

    Raises:
        RuntimeError: if no module matched.
    """
    # Decide the candidates from a snapshot of the tree, so the lora_a/lora_b
    # layers created below are never themselves considered for wrapping.
    candidates = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(name.endswith(suffix) for suffix in target_modules)
    ]

    replaced = []
    for name, linear in candidates:
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        adapter = LoRALinear(linear, rank=rank, alpha=alpha, dropout=dropout)
        adapter = adapter.to(device=linear.weight.device, dtype=linear.weight.dtype)
        setattr(parent, attr, adapter)
        replaced.append(name)

    if not replaced:
        raise RuntimeError("No target modules matched for LoRA injection")
    return replaced
def get_lora_state_dict(model):
    """Return only the LoRA adapter weights of ``model``, moved to CPU.

    Keeps every ``state_dict`` entry whose name contains ``lora_a.weight`` or
    ``lora_b.weight``; base-model tensors are excluded. Keys are the full
    ``state_dict`` names, so the result can be loaded back with
    ``load_state_dict(..., strict=False)``.
    """
    markers = ("lora_a.weight", "lora_b.weight")
    adapter = {}
    for key, value in model.state_dict().items():
        if any(marker in key for marker in markers):
            adapter[key] = value.cpu()
    return adapter
def count_trainable_parameters(model):
    """Return the total number of elements in parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def probe_max_micro_batch_lora(
    model,
    trainable_parameters,
    device,
    dtype,
    seq_len,
    vocab_size,
    safety=0.94,
    grad_accum_sim=2,
    max_batch=512,
):
    """Binary-search the largest micro-batch that trains without a CUDA OOM.

    Each candidate size runs ``grad_accum_sim`` forward/backward passes on
    random token ids, then one AdamW step over ``trainable_parameters``, so
    optimizer-state memory is included in the measurement. The trainable
    parameters are snapshotted first and restored afterwards (also on error).
    This keeps the probe's throwaway optimizer steps on random data out of the
    adapter's starting weights.

    Args:
        model: model called as ``model(ids, targets=ids, loss_mask=mask,
            return_logits=False)`` returning ``(logits, loss)``.
        trainable_parameters: iterable of parameters the real optimizer updates.
        device: ``torch.device``; anything other than CUDA skips probing.
        dtype: autocast dtype used for the probe passes.
        seq_len: sequence length of the random probe batches.
        vocab_size: exclusive upper bound for the random token ids.
        safety: fraction of the largest fitting size to actually use.
        grad_accum_sim: number of accumulated micro-batches per probe step.
        max_batch: upper bound of the search range.

    Returns:
        int: ``max(1, int(best * safety))``; always 1 on non-CUDA devices.

    Raises:
        RuntimeError: any non-OOM RuntimeError from the model is re-raised.
    """
    if device.type != "cuda":
        return 1
    params = list(trainable_parameters)
    # Snapshot so the probe leaves the real adapter weights untouched.
    snapshot = [param.detach().clone() for param in params]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    lo, hi, best = 1, max_batch, 1
    try:
        while lo <= hi:
            mid = (lo + hi) // 2
            try:
                torch.cuda.empty_cache()
                gc.collect()
                optimizer.zero_grad(set_to_none=True)
                for _ in range(grad_accum_sim):
                    input_ids = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                    loss_mask = torch.ones_like(input_ids)
                    with autocast(device_type="cuda", dtype=dtype):
                        _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                        loss = loss / grad_accum_sim
                    loss.backward()
                    del input_ids, loss_mask, loss
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                best = mid
                lo = mid + 1
            except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
                # Older torch versions surface OOM as a plain RuntimeError.
                is_oom = isinstance(error, torch.cuda.OutOfMemoryError) or "out of memory" in str(error).lower()
                if not is_oom:
                    raise
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
                gc.collect()
                hi = mid - 1
    finally:
        optimizer.zero_grad(set_to_none=True)
        with torch.no_grad():
            for param, saved in zip(params, snapshot):
                param.copy_(saved)
        del optimizer, snapshot
        torch.cuda.empty_cache()
        gc.collect()
    safe = max(1, int(best * safety))
    print(f" LoRA batch probe: max_micro_batch={best}, using {safe} ({int(safety * 100)}% safety)")
    return safe
def load_base_weights(model, checkpoint_path, device):
    """Load base-model weights from ``checkpoint_path`` into ``model`` (strict).

    The file may hold either a bare state dict or a dict that wraps it under
    the ``"model"`` key. Tensors are mapped onto ``device``. ``weights_only=True``
    restricts unpickling to tensors and plain containers.
    """
    payload = torch.load(checkpoint_path, map_location=device, weights_only=True)
    is_wrapped = isinstance(payload, dict) and "model" in payload
    weights = payload["model"] if is_wrapped else payload
    model.load_state_dict(weights, strict=True)
def _evaluate_val_loss(model, val_loader, device, dtype, max_batches=50):
    """Return the mean per-batch validation loss over at most ``max_batches`` batches.

    Switches ``model`` to eval mode for the pass and back to train mode
    afterwards. Returns ``0.0`` if ``val_loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    batches = 0
    with torch.no_grad():
        for val_ids, val_mask in val_loader:
            val_ids = val_ids.to(device, non_blocking=True)
            val_mask = val_mask.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, val_loss = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
            loss_sum += val_loss.item()
            batches += 1
            if batches >= max_batches:
                break
    model.train()
    return loss_sum / max(batches, 1)


def train(cfg):
    """Fine-tune LoRA adapters on top of a frozen LUNA base model.

    Loads the tokenizer and base checkpoint, freezes the base and injects LoRA
    into ``cfg["target_modules"]``. On CUDA with auto-config it probes the
    micro-batch size. It then trains with gradient accumulation, optional AMP,
    gradient clipping and a learning rate from ``cosine_lr``. Training resumes
    from ``<out_dir>/latest.pt`` when that file exists.

    Files written under ``cfg["out_dir"]``:
      * ``step-NNNNNN/adapter_model.pt`` every ``save_interval`` steps and at the end,
      * ``latest.pt`` (adapter, optimizer, scaler, step, best val loss) for resuming,
      * ``best_adapter_model.pt`` whenever validation loss improves,
      * ``final/adapter_model.pt`` and ``final/adapter_bundle.pt``.

    Args:
        cfg: flat config from ``load_config``. ``micro_batch`` and
            ``grad_accum`` are overwritten in place when the batch probe runs.

    Raises:
        RuntimeError: if ``latest.pt`` holds adapter keys the current model lacks.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    dtype = hw.get("dtype", torch.float32) if cfg["auto_config"] else {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }.get(cfg["precision"], torch.float32)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    ckpt_path = resolve_checkpoint(cfg)
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    load_base_weights(model, ckpt_path, device)
    # Freeze the whole base model; the LoRA layers injected next are the only trainable weights.
    for parameter in model.parameters():
        parameter.requires_grad = False
    replaced = inject_lora(
        model,
        target_modules=cfg["target_modules"],
        rank=cfg["lora_rank"],
        alpha=cfg["lora_alpha"],
        dropout=cfg["lora_dropout"],
    )
    trainable_params = count_trainable_parameters(model)
    total_params = sum(parameter.numel() for parameter in model.parameters())
    trainable_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]

    if cfg["auto_config"] and device.type == "cuda" and cfg["auto_probe_batch"]:
        print(" Probing LoRA micro_batch against available VRAM...")
        cfg["micro_batch"] = probe_max_micro_batch_lora(
            model,
            trainable_parameters=trainable_parameters,
            device=device,
            dtype=dtype,
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            safety=cfg["probe_safety"],
        )
        # Keep the effective batch at (or just above) global_batch for the probed size.
        cfg["grad_accum"] = max(1, math.ceil(cfg["global_batch"] / cfg["micro_batch"]))
        # Do not let the probe's peak show up in the training VRAM readout.
        torch.cuda.reset_peak_memory_stats(device)
    grad_accum = cfg["grad_accum"]
    effective_batch = cfg["micro_batch"] * grad_accum

    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"]) if Path(cfg["val_json"]).exists() else None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg["micro_batch"],
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg["micro_batch"],
            shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"],
            drop_last=False,
        )

    optimizer = torch.optim.AdamW(
        trainable_parameters,
        lr=cfg["lr"],
        weight_decay=cfg["weight_decay"],
        betas=cfg["betas"],
        eps=cfg["eps"],
    )
    # Loss scaling is only enabled for fp16 on CUDA; otherwise the scaler is a pass-through.
    scaler = GradScaler(enabled=(device.type == "cuda" and dtype == torch.float16))
    # Only complete accumulation groups produce an optimizer step.
    steps_per_epoch = max(1, len(train_loader) // grad_accum)
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = min(cfg["lr_warmup_steps"], max(1, total_steps // 5))

    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    best_val_loss = float("inf")
    step = 0
    latest_path = out_dir / "latest.pt"
    if latest_path.exists():
        checkpoint = torch.load(latest_path, map_location=device, weights_only=True)
        # strict=False because the adapter omits base weights; keys that match
        # nothing mean the checkpoint came from a different LoRA layout.
        load_result = model.load_state_dict(checkpoint["adapter"], strict=False)
        if load_result.unexpected_keys:
            raise RuntimeError(
                f"{latest_path} does not match the current LoRA layout; "
                f"unexpected keys: {load_result.unexpected_keys[:5]}"
            )
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Older checkpoints lack these entries, and a disabled scaler saves an
        # empty state that GradScaler.load_state_dict would reject.
        if checkpoint.get("scaler"):
            scaler.load_state_dict(checkpoint["scaler"])
        best_val_loss = checkpoint.get("best_val_loss", best_val_loss)
        step = checkpoint["step"]

    print(SEP)
    print(" LUNA 100M - LoRA SFT")
    print(SEP)
    print(f" Base checkpoint : {ckpt_path}")
    print(f" Train dataset : {cfg['train_json']}")
    print(f" Val dataset : {cfg['val_json']}")
    print(f" Output dir : {out_dir}")
    print(f" Device : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" LoRA modules : {', '.join(replaced)}")
    print(f" Trainable params: {trainable_params:,} / {total_params:,}")
    print(f" micro_batch : {cfg['micro_batch']}")
    print(f" grad_accum : {cfg['grad_accum']}")
    print(f" effective_batch : {effective_batch}")
    print(f" Train samples : {len(train_dataset):,}")
    print(f" Val samples : {len(val_dataset):,}" if val_dataset is not None else " Val samples : 0")
    print(SEP)

    if cfg["eval_prompts"] and step == 0:
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)
    model.train()
    run_t0 = time.perf_counter()
    tokens_per_step = effective_batch * cfg["seq_len"]
    for epoch in range(cfg["epochs"]):
        epoch_first_step = epoch * steps_per_epoch
        if epoch_first_step + steps_per_epoch <= step:
            # Epoch fully completed before a resume: skip it without iterating the loader.
            continue
        # Never carry a partial accumulation across an epoch boundary.
        optimizer.zero_grad(set_to_none=True)
        micro_step = 0
        step_start = time.perf_counter()
        accum_loss = torch.zeros((), device=device)
        for input_ids, loss_mask in train_loader:
            group_idx = micro_step // grad_accum
            if group_idx >= steps_per_epoch:
                # Trailing micro-batches that cannot fill a whole group; their
                # gradients would otherwise leak into the next epoch's first step.
                break
            if epoch_first_step + group_idx < step:
                # Resuming: skip every micro-batch of optimizer steps already taken.
                micro_step += 1
                continue
            if micro_step % grad_accum == 0:
                # Time and loss cover the whole accumulation group, not only its last micro-batch.
                step_start = time.perf_counter()
                accum_loss = torch.zeros((), device=device)
            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum
            scaler.scale(loss).backward()
            accum_loss += loss.detach().float()
            micro_step += 1
            if micro_step % grad_accum != 0:
                continue

            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_parameters, cfg["max_norm"])
            lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_now
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            if device.type == "cuda":
                torch.cuda.synchronize()
            dt = time.perf_counter() - step_start
            step += 1
            # Each micro loss was pre-divided by grad_accum, so the sum is the group mean.
            step_loss = accum_loss.item()

            if step % cfg["log_interval"] == 0 or step <= 3:
                tps = tokens_per_step / max(dt, 1e-6)
                vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                eta_h = (total_steps - step) * dt / 3600
                print(
                    f" step {step:6d}/{total_steps} | epoch {epoch + 1}/{cfg['epochs']} | "
                    f"loss {step_loss:.4f} | lr {lr_now:.2e} | "
                    f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h"
                )

            if step % cfg["save_interval"] == 0 or step == total_steps:
                step_dir = out_dir / f"step-{step:06d}"
                step_dir.mkdir(parents=True, exist_ok=True)
                adapter_state = get_lora_state_dict(model)
                torch.save(adapter_state, step_dir / "adapter_model.pt")
                torch.save(
                    {
                        "step": step,
                        "adapter": adapter_state,
                        "optimizer": optimizer.state_dict(),
                        "scaler": scaler.state_dict(),
                        "epoch": epoch,
                        "loss": step_loss,
                        "best_val_loss": best_val_loss,
                    },
                    latest_path,
                )
                print(f" Saved -> {step_dir}")

            if step % cfg["eval_interval"] == 0 or step == total_steps:
                if val_loader is not None:
                    avg_val = _evaluate_val_loss(model, val_loader, device, dtype)
                    print(f" Val loss: {avg_val:.4f}")
                    if avg_val < best_val_loss:
                        best_val_loss = avg_val
                        torch.save(get_lora_state_dict(model), out_dir / "best_adapter_model.pt")
                        print(" New best! Saved best_adapter_model.pt")
                if cfg["eval_prompts"]:
                    run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, step, out_dir)
                    # The step-0 call site re-enables train mode after this helper,
                    # so do the same here to keep LoRA dropout active.
                    model.train()

    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    final_adapter = get_lora_state_dict(model)
    torch.save(final_adapter, final_dir / "adapter_model.pt")
    torch.save(
        {
            "step": step,
            "adapter": final_adapter,
            "lora_rank": cfg["lora_rank"],
            "lora_alpha": cfg["lora_alpha"],
            "lora_dropout": cfg["lora_dropout"],
            "target_modules": cfg["target_modules"],
            "base_checkpoint": str(ckpt_path),
        },
        final_dir / "adapter_bundle.pt",
    )
    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" LoRA SFT complete in {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)
def parse_args():
    """Parse CLI options; every option except ``--config`` defaults to ``None`` (no override)."""
    parser = argparse.ArgumentParser(description="LUNA 100M - LoRA SFT")
    parser.add_argument("--config", default="rag_mcp_lora_config.yaml")
    # Path overrides, applied on top of the YAML config by main().
    for flag in ("--pretrained_ckpt", "--train_json", "--val_json", "--out_dir"):
        parser.add_argument(flag, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    return parser.parse_args()
def main():
    """CLI entry point: load the YAML config, apply command-line overrides, then train."""
    args = parse_args()
    cfg = load_config(args.config)
    # Only non-empty path overrides replace the config values.
    path_overrides = {
        key: getattr(args, key)
        for key in ("pretrained_ckpt", "train_json", "val_json", "out_dir")
        if getattr(args, key)
    }
    cfg.update(path_overrides)
    # Compare with None so that an explicit --epochs 0 still takes effect.
    if args.epochs is not None:
        cfg["epochs"] = args.epochs
    train(cfg)


if __name__ == "__main__":
    main()