""" LUNA 100M — SFT Fine-Tuning Script ==================================== Fine-tunes the pretrained LUNA-100M on instruction-following (SFT) data. Features: - Loads pretrained checkpoint (latest.pt from pretraining) - JSON-based SFT dataset (instruction/input/output format) - Prompt masking: loss computed only on the output portion - Checkpoint eval: runs identity + knowledge prompts after each save - Cosine LR with warmup - Auto hardware detection (same as train.py) Usage: python sft_train.py # uses sft_config.yaml python sft_train.py --config sft_config.yaml # explicit config python sft_train.py --train_json /data/train.json # override data path """ import os import gc import sys import math import time import json import argparse import yaml import psutil import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from pathlib import Path os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") # ─── Model (identical to train.py) ─────────────────────────────────────────── class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=1024): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len).float() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos()) self.register_buffer("sin_cached", emb.sin()) def forward(self, seq_len): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary(x, cos, sin): c = cos.unsqueeze(0).unsqueeze(0) s = sin.unsqueeze(0).unsqueeze(0) return x * c + rotate_half(x) * s class CausalSelfAttention(nn.Module): def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.rot_dim = int(self.head_dim * rotary_pct) self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) self.c_proj = nn.Linear(n_embd, n_embd, bias=True) self.rotary = RotaryEmbedding(self.rot_dim, block_size) def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) cos, sin = self.rotary(T) q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1) k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C)) class MLP(nn.Module): def __init__(self, n_embd): super().__init__() self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True) self.gelu = nn.GELU() self.proj = nn.Linear(4 * n_embd, n_embd, bias=True) def forward(self, x): return self.proj(self.gelu(self.fc(x))) class Block(nn.Module): def __init__(self, n_embd, n_head, block_size): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head, block_size) self.ln2 = nn.LayerNorm(n_embd) self.mlp = MLP(n_embd) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class LUNAModel(nn.Module): def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head): super().__init__() self.block_size = block_size self.wte = nn.Embedding(vocab_size, n_embd) self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) self.lm_head.weight = self.wte.weight # tied def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Embedding)): m.weight.data.normal_(mean=0.0, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: m.bias.data.zero_() def forward(self, idx, targets=None, loss_mask=None, return_logits=True): x = self.wte(idx) for block in self.blocks: x = block(x) x = self.ln_f(x) logits = self.lm_head(x) loss = None if targets is not None: shift_logits = logits[:, :-1, :].contiguous() shift_targets = targets[:, 1:].contiguous() if loss_mask is not None: shift_mask = loss_mask[:, 1:].contiguous() # Only compute loss on output tokens flat_logits = shift_logits.view(-1, shift_logits.size(-1)) flat_targets = shift_targets.view(-1) flat_mask = shift_mask.view(-1).float() per_token_loss = F.cross_entropy(flat_logits, flat_targets, reduction='none') loss = (per_token_loss * flat_mask).sum() / flat_mask.sum().clamp(min=1) else: loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_targets.view(-1) ) if not return_logits: logits = None return logits, loss @property def num_params(self): return sum(p.numel() for p in self.parameters()) - self.wte.weight.numel() # ─── SFT Dataset ────────────────────────────────────────────────────────────── class SFTDataset(torch.utils.data.Dataset): """ Loads JSON SFT data (instruction/input/output) and tokenizes with prompt masking. Format per entry: {"instruction": "...", "input": "...", "output": "..."} Prompt template (Alpaca-style): ### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}<|endoftext|> Loss mask: 0 for prompt tokens, 1 for response tokens (including EOS). """ def __init__(self, json_path, tokenizer, max_len=1024): with open(json_path, "r", encoding="utf-8") as f: self.data = json.load(f) self.tokenizer = tokenizer self.max_len = max_len self.eos_id = tokenizer.eos_token_id or 0 def __len__(self): return len(self.data) def _format_prompt(self, entry): inst = entry.get("instruction", "").strip() inp = entry.get("input", "").strip() out = entry.get("output", "").strip() if inst and inp: prompt = f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n### Response:\n" elif inst: prompt = f"### Instruction:\n{inst}\n\n### Response:\n" else: # input-only format prompt = f"### Input:\n{inp}\n\n### Response:\n" return prompt, out def __getitem__(self, idx): entry = self.data[idx] prompt, response = self._format_prompt(entry) prompt_ids = self.tokenizer.encode(prompt) response_ids = self.tokenizer.encode(response) + [self.eos_id] total_ids = prompt_ids + response_ids # Truncate to max_len if len(total_ids) > self.max_len: total_ids = total_ids[:self.max_len] # Ensure EOS at end total_ids[-1] = self.eos_id # Recalculate prompt boundary prompt_len = min(len(prompt_ids), self.max_len) else: prompt_len = len(prompt_ids) # Build loss mask: 0 for prompt, 1 for response loss_mask = [0] * prompt_len + [1] * (len(total_ids) - prompt_len) # Pad to max_len pad_len = self.max_len - len(total_ids) total_ids = total_ids + [self.eos_id] * pad_len loss_mask = loss_mask + [0] * pad_len # don't compute loss on padding input_ids = torch.tensor(total_ids, dtype=torch.long) loss_mask = torch.tensor(loss_mask, dtype=torch.long) return input_ids, loss_mask # ─── Generation (for eval) ─────────────────────────────────────────────────── @torch.no_grad() def generate(model, input_ids, max_new=150, temperature=0.7, top_p=0.9, top_k=40, device="cpu"): model.eval() ids = input_ids.clone().to(device) for _ in range(max_new): ctx = ids[:, -model.block_size:] logits, _ = model(ctx) logits = logits[:, -1, :] / max(temperature, 1e-8) if top_k > 0: vals, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < vals[:, -1:]] = -float("inf") probs = torch.softmax(logits, dim=-1) if top_p < 1.0: sorted_probs, sorted_idx = torch.sort(probs, descending=True) cum = torch.cumsum(sorted_probs, dim=-1) mask = cum - sorted_probs > top_p sorted_probs[mask] = 0.0 sorted_probs /= sorted_probs.sum() next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)] else: next_token = torch.multinomial(probs[0], 1) ids = torch.cat([ids, next_token.view(1, 1)], dim=1) if next_token.item() == 0: # EOS break model.train() return ids[0, input_ids.size(1):] def run_eval_prompts(model, tokenizer, prompts, device, step, out_dir): """Run eval prompts and print + log results.""" model.eval() results = [] sep = "─" * 60 print(f"\n{sep}") print(f" EVAL @ step {step}") print(sep) for prompt_text in prompts: # Format as instruction formatted = f"### Instruction:\n{prompt_text}\n\n### Response:\n" ids = tokenizer.encode(formatted, return_tensors="pt").to(device) out_ids = generate(model, ids, max_new=150, temperature=0.7, device=device) response = tokenizer.decode(out_ids.tolist(), skip_special_tokens=True).strip() print(f" Q: {prompt_text}") print(f" A: {response[:200]}") print() results.append({"prompt": prompt_text, "response": response[:500]}) print(sep) # Save eval log eval_dir = Path(out_dir) / "evals" eval_dir.mkdir(parents=True, exist_ok=True) with open(eval_dir / f"eval_step_{step:06d}.json", "w", encoding="utf-8") as f: json.dump(results, f, indent=2, ensure_ascii=False) model.train() return results # ─── Hardware Detection (same as train.py) ──────────────────────────────────── def probe_hardware(): info = { "cpu_cores": os.cpu_count() or 4, "ram_gb": psutil.virtual_memory().total / 1024**3, } if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) info.update({ "device": "cuda", "gpu_name": props.name, "vram_gb": props.total_memory / 1024**3, "sm_major": props.major, }) if props.major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True info["precision"] = "bf16" info["dtype"] = torch.bfloat16 else: info["precision"] = "fp16" info["dtype"] = torch.float16 else: info.update({ "device": "cpu", "gpu_name": "CPU", "vram_gb": 0, "sm_major": 0, "precision": "fp32", "dtype": torch.float32, }) return info def probe_max_batch(model, device, dtype, seq_len, vocab_size, grad_accum_sim=4): """Binary search for max micro_batch. Safety: x0.70.""" tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4) lo, hi, best = 1, 512, 1 while lo <= hi: mid = (lo + hi) // 2 try: torch.cuda.empty_cache(); gc.collect() tmp_opt.zero_grad(set_to_none=True) for _ in range(grad_accum_sim): x = torch.randint(0, vocab_size, (mid, seq_len), device=device) mask = torch.ones_like(x) with autocast(device_type="cuda", dtype=dtype): _, loss = model(x, x, loss_mask=mask, return_logits=False) loss = loss / grad_accum_sim loss.backward() del x, mask, loss tmp_opt.step() tmp_opt.zero_grad(set_to_none=True) best = mid; lo = mid + 1 torch.cuda.empty_cache() except (torch.cuda.OutOfMemoryError, RuntimeError) as e: if "out of memory" in str(e).lower() or isinstance(e, torch.cuda.OutOfMemoryError): try: del x, mask, loss except: pass torch.cuda.empty_cache() tmp_opt.zero_grad(set_to_none=True) hi = mid - 1 else: raise del tmp_opt; torch.cuda.empty_cache(); gc.collect() safe = max(1, int(best * 0.70)) print(f" Probe: max_batch={best}, using {safe} (70% safety)") return safe # ─── LR Schedule ────────────────────────────────────────────────────────────── def cosine_lr(step, warmup, total, lr_max, lr_min): if step < warmup: return lr_max * (step + 1) / warmup p = (step - warmup) / max(1, total - warmup) return lr_min + 0.5 * (1 + math.cos(math.pi * p)) * (lr_max - lr_min) # ─── Config ─────────────────────────────────────────────────────────────────── def load_sft_config(config_path): with open(config_path, encoding="utf-8") as f: raw = yaml.safe_load(f) cfg = { "auto_config": raw.get("auto_config", True), "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"), "hf_model_file": raw.get("hf_model_file", "latest.pt"), "hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"), "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"), "train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"), "val_json": raw.get("val_json", "Base/Datasets/sft_clean/val.json"), "out_dir": raw.get("out_dir", "Base/out/sft/luna_100m_sft"), "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"), # model "vocab_size": raw["model"]["vocab_size"], "seq_len": raw["model"]["seq_len"], "n_layer": raw["model"]["n_layer"], "n_embd": raw["model"]["n_embd"], "n_head": raw["model"]["n_head"], # train "epochs": raw["train"]["epochs"], "max_tokens": raw["train"].get("max_tokens", 0), "lr_warmup_steps": raw["train"]["lr_warmup_steps"], "save_interval": raw["train"]["save_interval"], "log_interval": raw["train"]["log_interval"], "eval_interval": raw["train"]["eval_interval"], "max_norm": raw["train"]["max_norm"], # optimizer "lr": raw["optimizer"]["lr"], "min_lr": raw["optimizer"]["min_lr"], "weight_decay": raw["optimizer"]["weight_decay"], "betas": tuple(raw["optimizer"]["betas"]), "eps": raw["optimizer"]["eps"], # batch "global_batch": raw["batch"]["global_batch"], "micro_batch": raw["batch"]["micro_batch"], "grad_accum": raw["batch"]["grad_accum"], # dataloader "num_workers": raw["dataloader"]["num_workers"], "pin_memory": raw["dataloader"]["pin_memory"], # hardware "precision": raw["hardware"]["precision"], # eval prompts "eval_prompts": raw.get("eval_prompts", []), } return cfg # ─── Training ───────────────────────────────────────────────────────────────── SEP = "=" * 72 def sft_train(cfg): hw = probe_hardware() device = torch.device(hw["device"]) if device.type == "cuda": torch.cuda.empty_cache(); gc.collect() # Precision if cfg["auto_config"]: dtype = hw.get("dtype", torch.float32) cfg["precision"] = hw["precision"] else: dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}.get(cfg["precision"], torch.float32) print(SEP) print(" LUNA 100M - SFT Fine-Tuning") print(SEP) print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)") print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores") print(f" Precision : {cfg['precision']} dtype={dtype}") print(f" Pretrained : {cfg['pretrained_ckpt']}") # ── Tokenizer ───────────────────────────────────────────────────────────── from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"]) print(f" Tokenizer : {cfg['tokenizer_dir']} (vocab={tokenizer.vocab_size})") # ── Model ───────────────────────────────────────────────────────────────── print(f"\n Building LUNA-100M...") model = LUNAModel( vocab_size=cfg["vocab_size"], block_size=cfg["seq_len"], n_layer=cfg["n_layer"], n_embd=cfg["n_embd"], n_head=cfg["n_head"], ).to(device) print(f" Parameters: {model.num_params:,} (unique)") # ── Load pretrained weights ─────────────────────────────────────────────── ckpt_path = Path(cfg["pretrained_ckpt"]) if not ckpt_path.exists() and cfg.get("hf_model_repo"): # Auto-download from HuggingFace model repo print(f"\n Pretrained checkpoint not found locally.") print(f" Downloading from HuggingFace: {cfg['hf_model_repo']} ({cfg['hf_model_file']})") from huggingface_hub import hf_hub_download ckpt_path.parent.mkdir(parents=True, exist_ok=True) downloaded = hf_hub_download( repo_id=cfg["hf_model_repo"], filename=cfg["hf_model_file"], local_dir=str(ckpt_path.parent), token=os.environ.get("HF_TOKEN"), ) downloaded_path = Path(downloaded) if not ckpt_path.exists() and downloaded_path.exists(): ckpt_path = downloaded_path print(f" Downloaded to: {ckpt_path}") if ckpt_path.exists(): print(f"\n Loading pretrained checkpoint: {ckpt_path}") ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) state = ckpt["model"] if "model" in ckpt else ckpt model.load_state_dict(state, strict=True) pretrain_step = ckpt.get("step", "?") pretrain_tokens = ckpt.get("tokens_seen", 0) print(f" Pretrained @ step {pretrain_step}, tokens seen: {pretrain_tokens:,}") # Do NOT load optimizer state — we start fresh for SFT else: print(f"\n WARNING: No pretrained checkpoint at {ckpt_path}") print(f" Training from scratch (not recommended for SFT)!") # ── Dataset (auto-download from HF if missing) ───────────────────────────── train_path = Path(cfg["train_json"]) val_path = Path(cfg["val_json"]) if cfg["val_json"] else None if not train_path.exists() and cfg.get("hf_dataset_repo"): print(f"\n SFT dataset not found locally.") print(f" Downloading from HuggingFace: {cfg['hf_dataset_repo']}") from huggingface_hub import hf_hub_download train_path.parent.mkdir(parents=True, exist_ok=True) hf_hub_download( repo_id=cfg["hf_dataset_repo"], repo_type="dataset", filename="train.json", local_dir=str(train_path.parent), token=os.environ.get("HF_TOKEN"), ) print(f" Downloaded train.json") if val_path: hf_hub_download( repo_id=cfg["hf_dataset_repo"], repo_type="dataset", filename="val.json", local_dir=str(val_path.parent), token=os.environ.get("HF_TOKEN"), ) print(f" Downloaded val.json") print(f"\n Train data: {cfg['train_json']}") train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"]) print(f" Train entries: {len(train_dataset):,}") val_dataset = None if cfg["val_json"] and Path(cfg["val_json"]).exists(): val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"]) print(f" Val entries: {len(val_dataset):,}") # ── Batch sizing ────────────────────────────────────────────────────────── if cfg["auto_config"] and device.type == "cuda": print(f"\n Probing max micro_batch_size...") max_mbs = probe_max_batch(model, device, dtype, cfg["seq_len"], cfg["vocab_size"]) model.load_state_dict(state, strict=True) # reinit after probe torch.cuda.empty_cache(); gc.collect() grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs)) effective_batch = max_mbs * grad_accum else: max_mbs = cfg["micro_batch"] grad_accum = cfg["grad_accum"] effective_batch = max_mbs * grad_accum print(f" micro_batch={max_mbs}, grad_accum={grad_accum}, effective={effective_batch}") # ── DataLoader ──────────────────────────────────────────────────────────── train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=max_mbs, shuffle=True, num_workers=cfg["num_workers"], pin_memory=cfg["pin_memory"], drop_last=True, prefetch_factor=4 if cfg["num_workers"] > 0 else None, persistent_workers=cfg["num_workers"] > 0, ) val_loader = None if val_dataset: val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=max_mbs, shuffle=False, num_workers=min(2, cfg["num_workers"]), pin_memory=cfg["pin_memory"], drop_last=False, ) # ── Optimizer ───────────────────────────────────────────────────────────── try: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"], betas=cfg["betas"], eps=cfg["eps"], fused=True, ) except TypeError: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"], betas=cfg["betas"], eps=cfg["eps"], ) use_scaler = dtype == torch.float16 scaler = GradScaler(enabled=use_scaler) # ── Schedule ────────────────────────────────────────────────────────────── steps_per_epoch = len(train_loader) // grad_accum total_steps = steps_per_epoch * cfg["epochs"] warmup_steps = min(cfg["lr_warmup_steps"], total_steps // 5) out_dir = Path(cfg["out_dir"]) out_dir.mkdir(parents=True, exist_ok=True) print(f"\n Epochs : {cfg['epochs']}") print(f" Steps/epoch : {steps_per_epoch:,}") print(f" Total steps : {total_steps:,}") print(f" Warmup steps : {warmup_steps}") print(f" LR : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}") print(f" Save every : {cfg['save_interval']} steps") print(f" Eval every : {cfg['eval_interval']} steps") print(f" Eval prompts : {len(cfg['eval_prompts'])}") print(f" Out dir : {out_dir}") print(SEP) # ── Resume SFT ──────────────────────────────────────────────────────────── start_step = 0 sft_ckpt_path = out_dir / "latest.pt" if sft_ckpt_path.exists(): print(f"\n Resuming SFT from {sft_ckpt_path}...") sft_ckpt = torch.load(sft_ckpt_path, map_location=device, weights_only=True) model.load_state_dict(sft_ckpt["model"]) optimizer.load_state_dict(sft_ckpt["optimizer"]) start_step = sft_ckpt["step"] print(f" Resumed at SFT step {start_step}") # ── Eval at start ───────────────────────────────────────────────────────── if cfg["eval_prompts"] and start_step == 0: print("\n Running initial eval (before SFT)...") run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir) # ── Training loop ───────────────────────────────────────────────────────── model.train() run_t0 = time.perf_counter() step = start_step best_val_loss = float("inf") print(f"\n Starting SFT training (step {start_step} -> {total_steps})...") for epoch in range(cfg["epochs"]): data_iter = iter(train_loader) micro_step = 0 for batch_idx, (input_ids, loss_mask) in enumerate(data_iter): # Skip already-done steps on resume current_global_step = epoch * steps_per_epoch + (micro_step // grad_accum) if current_global_step < start_step and (micro_step % grad_accum == grad_accum - 1): micro_step += 1 continue if current_global_step >= total_steps: break input_ids = input_ids.to(device, non_blocking=True) loss_mask = loss_mask.to(device, non_blocking=True) t0 = time.perf_counter() # Accumulation step with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")): _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False) loss = loss / grad_accum scaler.scale(loss).backward() micro_step += 1 # Optimizer step after grad_accum micro-batches if micro_step % grad_accum == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"]) # LR schedule lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"]) for pg in optimizer.param_groups: pg["lr"] = lr_now scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) if device.type == "cuda": torch.cuda.synchronize() dt = time.perf_counter() - t0 step += 1 # ── Log ─────────────────────────────────────────────────────── if step % cfg["log_interval"] == 0 or step <= 3: tokens_step = effective_batch * cfg["seq_len"] tps = tokens_step / dt if dt > 0 else 0 vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0 eta_h = (total_steps - step) * dt / 3600 print(f" step {step:6d}/{total_steps} | epoch {epoch+1}/{cfg['epochs']} | " f"loss {loss.item()*grad_accum:.4f} | lr {lr_now:.2e} | " f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h") # ── Save checkpoint ─────────────────────────────────────────── if step % cfg["save_interval"] == 0 or step == total_steps: raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model step_dir = out_dir / f"step-{step:06d}" step_dir.mkdir(parents=True, exist_ok=True) torch.save(raw_model.state_dict(), step_dir / "model.pth") torch.save({ "step": step, "model": raw_model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "sft_loss": loss.item() * grad_accum, }, out_dir / "latest.pt") print(f" Saved -> {step_dir}") # ── Eval ────────────────────────────────────────────────────── if step % cfg["eval_interval"] == 0 or step == total_steps: # Validation loss if val_loader: model.eval() val_loss_sum = 0.0 val_count = 0 with torch.no_grad(): for val_ids, val_mask in val_loader: val_ids = val_ids.to(device, non_blocking=True) val_mask = val_mask.to(device, non_blocking=True) with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")): _, vl = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False) val_loss_sum += vl.item() val_count += 1 if val_count >= 50: # cap eval to 50 batches break avg_val = val_loss_sum / max(val_count, 1) print(f" Val loss: {avg_val:.4f}") if avg_val < best_val_loss: best_val_loss = avg_val raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model torch.save(raw_model.state_dict(), out_dir / "best_model.pth") print(f" New best! Saved best_model.pth") model.train() # Run eval prompts if cfg["eval_prompts"]: run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, step, out_dir) # ── Final ───────────────────────────────────────────────────────────────── final_dir = out_dir / "final" final_dir.mkdir(parents=True, exist_ok=True) raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model torch.save(raw_model.state_dict(), final_dir / "model.pth") torch.save({ "step": step, "model": raw_model.state_dict(), "sft_complete": True, }, out_dir / "latest.pt") # Copy tokenizer import shutil tok_src = Path(cfg["tokenizer_dir"]) if tok_src.exists(): shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True) total_h = (time.perf_counter() - run_t0) / 3600 print(SEP) print(f" SFT Complete! {total_h:.2f}h -> {final_dir}") print(f" Best val loss: {best_val_loss:.4f}") print(SEP) # ─── Entry ──────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description="LUNA 100M — SFT Fine-Tuning") p.add_argument("--config", default="sft_config.yaml") p.add_argument("--pretrained_ckpt", default=None) p.add_argument("--train_json", default=None) p.add_argument("--val_json", default=None) p.add_argument("--out_dir", default=None) p.add_argument("--epochs", type=int, default=None) p.add_argument("--lr", type=float, default=None) p.add_argument("--micro_batch",type=int, default=None) p.add_argument("--global_batch",type=int, default=None) p.add_argument("--save_interval", type=int, default=None) p.add_argument("--eval_interval", type=int, default=None) p.add_argument("--auto_config", type=lambda x: x.lower() in ("1","true","yes"), default=None) return p.parse_args() if __name__ == "__main__": args = parse_args() cfg = load_sft_config(args.config) # CLI overrides for key, val in vars(args).items(): if key != "config" and val is not None: cfg[key] = val sft_train(cfg)