| import argparse | |
| import gc | |
| import math | |
| import os | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from huggingface_hub import hf_hub_download | |
| from torch.amp import GradScaler, autocast | |
| from sft_train import LUNAModel, SFTDataset, cosine_lr, probe_hardware, run_eval_prompts | |
| SEP = "=" * 72 | |
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The forward pass computes ``base(x) + scale * B(A(dropout(x)))``, where
    ``A`` projects down to ``rank`` dimensions and ``B`` projects back up.
    ``B`` is zero-initialized, so the wrapper starts as an exact no-op on
    top of the base layer.
    """

    def __init__(self, base_layer, rank=16, alpha=32, dropout=0.05):
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError("LoRALinear expects a torch.nn.Linear base layer")
        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling; max() guards against rank == 0.
        self.scale = alpha / max(rank, 1)
        self.dropout = nn.Dropout(dropout)
        # Adapters are kept as Linear modules so their state-dict keys end in
        # "lora_a.weight" / "lora_b.weight" — the adapter save filter in this
        # file selects by those names.
        self.lora_a = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base_layer.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        # Only the adapter trains; the wrapped base layer stays frozen.
        for frozen in self.base.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Return the base projection plus the scaled low-rank correction."""
        delta = self.lora_b(self.lora_a(self.dropout(x)))
        return self.base(x) + delta * self.scale
def load_config(config_path):
    """Read the YAML training config and flatten it into one flat dict.

    The top-level keys (paths, HF repo ids, eval prompts) are optional and
    fall back to project defaults via ``.get()``; the nested model/train/
    optimizer/batch/dataloader/hardware/lora sections are mandatory and
    raise ``KeyError`` when a field is missing.
    """
    with open(config_path, encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)
    cfg = {
        # --- optional top-level settings (project defaults apply) ---
        "auto_config": raw.get("auto_config", True),
        "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
        "hf_model_file": raw.get("hf_model_file", "sft_v1/final/model.pth"),
        "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/input_models/luna_sft_v1/model.pth"),
        "train_json": raw.get("train_json", "Base/Datasets/rag_mcp_sft/train.json"),
        "val_json": raw.get("val_json", "Base/Datasets/rag_mcp_sft/val.json"),
        "out_dir": raw.get("out_dir", "Base/out/sft/rag_mcp_lora"),
        "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
        # --- model architecture (required) ---
        "vocab_size": raw["model"]["vocab_size"],
        "seq_len": raw["model"]["seq_len"],
        "n_layer": raw["model"]["n_layer"],
        "n_embd": raw["model"]["n_embd"],
        "n_head": raw["model"]["n_head"],
        # --- training schedule (required) ---
        "epochs": raw["train"]["epochs"],
        "lr_warmup_steps": raw["train"]["lr_warmup_steps"],
        "save_interval": raw["train"]["save_interval"],
        "log_interval": raw["train"]["log_interval"],
        "eval_interval": raw["train"]["eval_interval"],
        "max_norm": raw["train"]["max_norm"],
        # --- optimizer (required) ---
        "lr": raw["optimizer"]["lr"],
        "min_lr": raw["optimizer"]["min_lr"],
        "weight_decay": raw["optimizer"]["weight_decay"],
        # AdamW expects betas as a tuple, not the YAML-parsed list.
        "betas": tuple(raw["optimizer"]["betas"]),
        "eps": raw["optimizer"]["eps"],
        # --- batching (auto_probe_batch / probe_safety are optional) ---
        "global_batch": raw["batch"]["global_batch"],
        "micro_batch": raw["batch"]["micro_batch"],
        "grad_accum": raw["batch"]["grad_accum"],
        "auto_probe_batch": raw["batch"].get("auto_probe_batch", True),
        "probe_safety": raw["batch"].get("probe_safety", 0.94),
        # --- dataloader / hardware ---
        "num_workers": raw["dataloader"]["num_workers"],
        "pin_memory": raw["dataloader"]["pin_memory"],
        "precision": raw["hardware"]["precision"],
        "eval_prompts": raw.get("eval_prompts", []),
        # --- LoRA hyperparameters (required) ---
        "lora_rank": raw["lora"]["rank"],
        "lora_alpha": raw["lora"]["alpha"],
        "lora_dropout": raw["lora"]["dropout"],
        "target_modules": list(raw["lora"]["target_modules"]),
    }
    return cfg
def resolve_checkpoint(cfg):
    """Return a local path to the base checkpoint, downloading it if absent.

    Checks ``cfg["pretrained_ckpt"]`` on disk first; otherwise downloads
    ``cfg["hf_model_file"]`` from the ``cfg["hf_model_repo"]`` Hugging Face
    repo into the checkpoint's parent directory (authenticating with the
    ``HF_TOKEN`` environment variable when set).

    Raises:
        FileNotFoundError: if the download completed but the expected file
            is still missing on disk.
    """
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if ckpt_path.exists():
        return ckpt_path
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    # hf_hub_download returns the actual local file path — use it instead of
    # re-deriving the location from local_dir + filename by hand, which can
    # drift from the hub library's layout.
    downloaded = Path(
        hf_hub_download(
            repo_id=cfg["hf_model_repo"],
            filename=cfg["hf_model_file"],
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
    )
    if not downloaded.exists():
        raise FileNotFoundError(f"Expected downloaded checkpoint at {downloaded}")
    return downloaded
def inject_lora(model, target_modules, rank, alpha, dropout):
    """Replace every targeted ``nn.Linear`` in *model* with a LoRA wrapper.

    A module is targeted when its dotted name ends with any entry of
    *target_modules*. Returns the list of replaced module names; raises
    ``RuntimeError`` when nothing matched (almost certainly a config typo).
    """
    replaced = []
    # Snapshot named_modules() up front so the freshly inserted wrappers
    # (whose inner ``.base`` is itself a Linear) are never revisited.
    for full_name, candidate in list(model.named_modules()):
        is_target = isinstance(candidate, nn.Linear) and any(
            full_name.endswith(suffix) for suffix in target_modules
        )
        if not is_target:
            continue
        owner_name, _, attr = full_name.rpartition(".")
        owner = model.get_submodule(owner_name) if owner_name else model
        adapter = LoRALinear(candidate, rank=rank, alpha=alpha, dropout=dropout)
        # Match the wrapped layer's placement/precision before swapping in.
        adapter = adapter.to(device=candidate.weight.device, dtype=candidate.weight.dtype)
        setattr(owner, attr, adapter)
        replaced.append(full_name)
    if not replaced:
        raise RuntimeError("No target modules matched for LoRA injection")
    return replaced
def get_lora_state_dict(model):
    """Extract only the LoRA adapter weights from *model*, moved to CPU.

    Filters the full state dict down to the ``lora_a`` / ``lora_b``
    projection weights so saved adapters stay tiny (base weights excluded).
    """
    adapter_only = {}
    for key, value in model.state_dict().items():
        if "lora_a.weight" in key or "lora_b.weight" in key:
            adapter_only[key] = value.cpu()
    return adapter_only
def count_trainable_parameters(model):
    """Return the total number of elements in parameters that require grad."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
def probe_max_micro_batch_lora(model, trainable_parameters, device, dtype, seq_len, vocab_size, safety=0.94, grad_accum_sim=2):
    """Binary-search the largest micro-batch that fits in GPU memory.

    For each candidate size it simulates ``grad_accum_sim`` accumulation
    micro-steps (full forward/backward on synthetic token ids) plus one
    optimizer step, shrinking the search window on CUDA OOM. Returns the
    found maximum scaled down by *safety* (never below 1). On non-CUDA
    devices probing is not possible, so it conservatively returns 1.

    NOTE(review): the throwaway AdamW step mutates the real trainable
    parameters and discards the optimizer moments afterwards — presumably
    acceptable before training starts, but confirm callers run this only
    prior to (re)loading/ignoring those weights.
    """
    if device.type != "cuda":
        return 1
    # Throwaway optimizer so the probe also accounts for AdamW state memory.
    optimizer = torch.optim.AdamW(trainable_parameters, lr=1e-4)
    lo, hi, best = 1, 512, 1
    while lo <= hi:
        mid = (lo + hi) // 2
        try:
            torch.cuda.empty_cache()
            gc.collect()
            optimizer.zero_grad(set_to_none=True)
            for _ in range(grad_accum_sim):
                # Synthetic batch at the candidate size; targets reuse the inputs.
                input_ids = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                loss_mask = torch.ones_like(input_ids)
                with autocast(device_type="cuda", dtype=dtype):
                    _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum_sim
                loss.backward()
                # Free the probe tensors eagerly before the next attempt.
                del input_ids, loss_mask, loss
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            best = mid  # this size fits — search upward
            lo = mid + 1
        except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
            # Only treat genuine OOMs as "too big"; re-raise everything else.
            if "out of memory" not in str(error).lower() and not isinstance(error, torch.cuda.OutOfMemoryError):
                raise
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            gc.collect()
            hi = mid - 1
    # Drop the probe optimizer and return the VRAM before real training.
    del optimizer
    torch.cuda.empty_cache()
    gc.collect()
    safe = max(1, int(best * safety))
    print(f" LoRA batch probe: max_micro_batch={best}, using {safe} ({int(safety * 100)}% safety)")
    return safe
def load_base_weights(model, checkpoint_path, device):
    """Load pretrained base weights into *model* with strict key matching.

    Accepts either a raw state dict or a checkpoint dict wrapping it under
    a "model" key.
    """
    payload = torch.load(checkpoint_path, map_location=device, weights_only=True)
    if isinstance(payload, dict) and "model" in payload:
        payload = payload["model"]
    model.load_state_dict(payload, strict=True)
def train(cfg):
    """Run LoRA supervised fine-tuning of the LUNA base model.

    Loads the frozen base checkpoint, injects LoRA adapters into the
    configured target modules, trains only the adapter weights with AMP and
    gradient accumulation, periodically saves/evaluates, and writes a final
    adapter bundle. Resumable via ``<out_dir>/latest.pt``.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    # auto_config defers the dtype choice to the hardware probe; otherwise
    # the YAML "precision" string picks it (unknown values fall back to fp32).
    dtype = hw.get("dtype", torch.float32) if cfg["auto_config"] else {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }.get(cfg["precision"], torch.float32)
    # Imported lazily: transformers is heavy and only needed at train time.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    ckpt_path = resolve_checkpoint(cfg)
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    load_base_weights(model, ckpt_path, device)
    # Freeze everything first; inject_lora then adds the only trainable weights.
    for parameter in model.parameters():
        parameter.requires_grad = False
    replaced = inject_lora(
        model,
        target_modules=cfg["target_modules"],
        rank=cfg["lora_rank"],
        alpha=cfg["lora_alpha"],
        dropout=cfg["lora_dropout"],
    )
    trainable_params = count_trainable_parameters(model)
    total_params = sum(parameter.numel() for parameter in model.parameters())
    trainable_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]
    if cfg["auto_config"] and device.type == "cuda" and cfg["auto_probe_batch"]:
        print(" Probing LoRA micro_batch against available VRAM...")
        cfg["micro_batch"] = probe_max_micro_batch_lora(
            model,
            trainable_parameters=trainable_parameters,
            device=device,
            dtype=dtype,
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            safety=cfg["probe_safety"],
        )
        # Re-derive accumulation so micro_batch * grad_accum ~= global_batch.
        cfg["grad_accum"] = max(1, math.ceil(cfg["global_batch"] / cfg["micro_batch"]))
        # Clear the probe's memory high-water mark before real training.
        torch.cuda.reset_peak_memory_stats(device)
    effective_batch = cfg["micro_batch"] * cfg["grad_accum"]
    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    # Validation is optional: skipped entirely when the file is absent.
    val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"]) if Path(cfg["val_json"]).exists() else None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg["micro_batch"],
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        # prefetch_factor/persistent_workers are only valid with workers > 0.
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg["micro_batch"],
            shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"],
            drop_last=False,
        )
    # Only the LoRA adapter parameters reach the optimizer.
    optimizer = torch.optim.AdamW(
        trainable_parameters,
        lr=cfg["lr"],
        weight_decay=cfg["weight_decay"],
        betas=cfg["betas"],
        eps=cfg["eps"],
    )
    # Loss scaling is only needed for fp16; bf16/fp32 run with it disabled.
    scaler = GradScaler(enabled=(device.type == "cuda" and dtype == torch.float16))
    steps_per_epoch = max(1, len(train_loader) // cfg["grad_accum"])
    total_steps = steps_per_epoch * cfg["epochs"]
    # Cap warmup at 20% of the run so short schedules still get a decay phase.
    warmup_steps = min(cfg["lr_warmup_steps"], max(1, total_steps // 5))
    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    best_val_loss = float("inf")
    step = 0
    # Resume adapter + optimizer state when a previous run left latest.pt.
    latest_path = out_dir / "latest.pt"
    if latest_path.exists():
        checkpoint = torch.load(latest_path, map_location=device, weights_only=True)
        # strict=False: the checkpoint holds only the adapter weights.
        model.load_state_dict(checkpoint["adapter"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        step = checkpoint["step"]
    print(SEP)
    print(" LUNA 100M - LoRA SFT")
    print(SEP)
    print(f" Base checkpoint : {ckpt_path}")
    print(f" Train dataset : {cfg['train_json']}")
    print(f" Val dataset : {cfg['val_json']}")
    print(f" Output dir : {out_dir}")
    print(f" Device : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" LoRA modules : {', '.join(replaced)}")
    print(f" Trainable params: {trainable_params:,} / {total_params:,}")
    print(f" micro_batch : {cfg['micro_batch']}")
    print(f" grad_accum : {cfg['grad_accum']}")
    print(f" effective_batch : {effective_batch}")
    print(f" Train samples : {len(train_dataset):,}")
    print(f" Val samples : {len(val_dataset):,}" if val_dataset is not None else " Val samples : 0")
    print(SEP)
    # Baseline generations before any LoRA training (fresh runs only).
    if cfg["eval_prompts"] and step == 0:
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)
    model.train()
    run_t0 = time.perf_counter()
    for epoch in range(cfg["epochs"]):
        micro_step = 0
        for input_ids, loss_mask in train_loader:
            # Fast-forward when resuming: skip optimizer steps already taken.
            current_global_step = epoch * steps_per_epoch + (micro_step // cfg["grad_accum"])
            if current_global_step < step and (micro_step % cfg["grad_accum"] == cfg["grad_accum"] - 1):
                micro_step += 1
                continue
            # NOTE(review): during this fast-forward, non-boundary micro-steps
            # fall through to the forward/backward below, so gradients from
            # skipped steps accumulate and leak into the first real optimizer
            # step — confirm whether resume should skip all micro-steps or
            # zero grads when the catch-up phase ends.
            if current_global_step >= total_steps:
                break
            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)
            step_start = time.perf_counter()
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                # Pre-divide so accumulated gradients average over micro-steps.
                loss = loss / cfg["grad_accum"]
            scaler.scale(loss).backward()
            micro_step += 1
            # Keep accumulating until a full effective batch has been seen.
            if micro_step % cfg["grad_accum"] != 0:
                continue
            # Unscale before clipping so max_norm applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_parameters, cfg["max_norm"])
            lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_now
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            # Synchronize so the step timing below measures real GPU work.
            if device.type == "cuda":
                torch.cuda.synchronize()
            dt = time.perf_counter() - step_start
            step += 1
            if step % cfg["log_interval"] == 0 or step <= 3:
                tokens_step = effective_batch * cfg["seq_len"]
                tps = tokens_step / max(dt, 1e-6)
                vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                eta_h = (total_steps - step) * dt / 3600
                # Multiply the logged loss back up to undo the pre-division.
                print(
                    f" step {step:6d}/{total_steps} | epoch {epoch + 1}/{cfg['epochs']} | "
                    f"loss {loss.item() * cfg['grad_accum']:.4f} | lr {lr_now:.2e} | "
                    f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h"
                )
            if step % cfg["save_interval"] == 0 or step == total_steps:
                # Per-step adapter snapshot plus a rolling resume checkpoint.
                step_dir = out_dir / f"step-{step:06d}"
                step_dir.mkdir(parents=True, exist_ok=True)
                adapter_state = get_lora_state_dict(model)
                torch.save(adapter_state, step_dir / "adapter_model.pt")
                torch.save(
                    {
                        "step": step,
                        "adapter": adapter_state,
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "loss": loss.item() * cfg["grad_accum"],
                    },
                    latest_path,
                )
                print(f" Saved -> {step_dir}")
            if step % cfg["eval_interval"] == 0 or step == total_steps:
                if val_loader is not None:
                    model.eval()
                    val_loss_sum = 0.0
                    val_count = 0
                    with torch.no_grad():
                        for val_ids, val_mask in val_loader:
                            val_ids = val_ids.to(device, non_blocking=True)
                            val_mask = val_mask.to(device, non_blocking=True)
                            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                                _, val_loss = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
                            val_loss_sum += val_loss.item()
                            val_count += 1
                            # Cap validation at 50 batches to bound eval time.
                            if val_count >= 50:
                                break
                    avg_val = val_loss_sum / max(val_count, 1)
                    print(f" Val loss: {avg_val:.4f}")
                    if avg_val < best_val_loss:
                        best_val_loss = avg_val
                        torch.save(get_lora_state_dict(model), out_dir / "best_adapter_model.pt")
                        print(" New best! Saved best_adapter_model.pt")
                    model.train()
                if cfg["eval_prompts"]:
                    run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, step, out_dir)
    # Final artifacts: bare adapter weights plus a self-describing bundle
    # carrying the LoRA hyperparameters needed to re-inject the adapters.
    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    torch.save(get_lora_state_dict(model), final_dir / "adapter_model.pt")
    torch.save(
        {
            "step": step,
            "adapter": get_lora_state_dict(model),
            "lora_rank": cfg["lora_rank"],
            "lora_alpha": cfg["lora_alpha"],
            "lora_dropout": cfg["lora_dropout"],
            "target_modules": cfg["target_modules"],
            "base_checkpoint": str(ckpt_path),
        },
        final_dir / "adapter_bundle.pt",
    )
    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" LoRA SFT complete in {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)
def parse_args():
    """Parse CLI arguments; flags left at None defer to the YAML config."""
    parser = argparse.ArgumentParser(description="LUNA 100M - LoRA SFT")
    parser.add_argument("--config", default="rag_mcp_lora_config.yaml")
    # Optional path overrides; main() only applies the non-empty ones.
    for flag in ("--pretrained_ckpt", "--train_json", "--val_json", "--out_dir"):
        parser.add_argument(flag, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    return parser.parse_args()
def main():
    """CLI entry point: load the YAML config, apply overrides, run training."""
    args = parse_args()
    cfg = load_config(args.config)
    # Non-empty CLI path flags win over the YAML values.
    for key in ("pretrained_ckpt", "train_json", "val_json", "out_dir"):
        override = getattr(args, key)
        if override:
            cfg[key] = override
    if args.epochs is not None:
        cfg["epochs"] = args.epochs
    train(cfg)
| if __name__ == "__main__": | |
| main() |