| """
|
LUNA 100M — SFT Fine-Tuning Script
|
| ====================================
|
| Fine-tunes the pretrained LUNA-100M on instruction-following (SFT) data.
|
|
|
| Features:
|
| - Loads pretrained checkpoint (latest.pt from pretraining)
|
| - JSON-based SFT dataset (instruction/input/output format)
|
| - Prompt masking: loss computed only on the output portion
|
| - Checkpoint eval: runs identity + knowledge prompts after each save
|
| - Cosine LR with warmup
|
| - Auto hardware detection (same as train.py)
|
|
|
| Usage:
|
| python sft_train.py # uses sft_config.yaml
|
| python sft_train.py --config sft_config.yaml # explicit config
|
| python sft_train.py --train_json /data/train.json # override data path
|
| """
|
|
|
| import os
|
| import gc
|
| import sys
|
| import math
|
| import time
|
| import json
|
| import argparse
|
| import yaml
|
| import psutil
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch.amp import autocast, GradScaler
|
| from pathlib import Path
|
|
|
# Opt into the expandable-segments CUDA allocator unless the user already
# configured PYTORCH_CUDA_ALLOC_CONF. This runs at import time, before any
# CUDA work done by this module.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings.

    Args:
        dim: Number of channels the rotation is applied to (table width).
        max_seq_len: Number of positions to precompute.

    ``inv_freq``, ``cos_cached`` and ``sin_cached`` are registered buffers, so
    they move with the module and appear in its ``state_dict``.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer("inv_freq", inv_freq)

        positions = torch.arange(max_seq_len).float()
        # One row per position, one column per inverse frequency.
        angles = torch.outer(positions, inv_freq)
        # Duplicate the angles so the table matches rotate_half's two halves.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions."""
        cos_table, sin_table = self.cos_cached, self.sin_cached
        return cos_table[:seq_len], sin_table[:seq_len]
|
|
|
|
|
def rotate_half(x):
    """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``.

    Splitting follows ``torch.chunk``: for an odd-sized last dimension the
    first half is the larger one.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
|
def apply_rotary(x, cos, sin):
    """Rotate ``x`` with precomputed rotary tables.

    ``cos``/``sin`` receive two leading singleton axes so they broadcast over
    the leading axes of ``x``. In this file ``x`` is a slice of q/k laid out
    as (B, n_head, T, rot_dim), as built in CausalSelfAttention.forward.
    """
    cos_b = cos[None, None, ...]
    sin_b = sin[None, None, ...]
    return (x * cos_b) + (rotate_half(x) * sin_b)
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embedding.

    Only the first ``rot_dim = int(head_dim * rotary_pct)`` channels of every
    query/key head are rotated; the remaining channels pass through unchanged.
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Single fused projection producing q, k and v.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        batch, seq_len, width = x.size()
        qkv = self.c_attn(x).reshape(batch, seq_len, 3, self.n_head, self.head_dim)
        # -> (3, B, n_head, T, head_dim), then split into q, k, v.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        cos, sin = self.rotary(seq_len)
        r = self.rot_dim
        q = torch.cat((apply_rotary(q[..., :r], cos, sin), q[..., r:]), dim=-1)
        k = torch.cat((apply_rotary(k[..., :r], cos, sin), k[..., r:]), dim=-1)

        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward: Linear(n, 4n) -> GELU -> Linear(4n, n)."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = n_embd * 4
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        h = self.fc(x)
        h = self.gelu(h)
        return self.proj(h)
|
|
|
|
|
class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
|
|
|
|
|
class LUNAModel(nn.Module):
    """Decoder-only language model: token embedding -> Blocks -> LayerNorm -> LM head.

    The LM head shares its weight with the token embedding (weight tying).
    There is no learned positional embedding; positions come from the rotary
    embedding inside each Block's attention layer.
    """

    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.wte.weight

    def _init_weights(self, m):
        """Initialise one module: N(0, 0.02) weights and zero Linear biases.

        NOTE(review): nothing in this script applies this (e.g. via
        ``self.apply``), so a model trained "from scratch" keeps PyTorch's
        default initialisation.
        """
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            return
        m.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()

    def forward(self, idx, targets=None, loss_mask=None, return_logits=True):
        """Run the model and optionally compute a next-token cross-entropy.

        Args:
            idx: Token ids indexed as (batch, time).
            targets: Token ids aligned with ``idx``; the prediction at position
                t is scored against ``targets[:, t + 1]``.
            loss_mask: Optional mask aligned with ``targets``. When given, the
                per-token losses are weighted by the shifted mask and divided
                by its sum (at least 1), i.e. a masked mean for 0/1 masks.
            return_logits: When False, ``None`` is returned instead of logits.

        Returns:
            ``(logits, loss)``; ``loss`` is None when ``targets`` is None.
        """
        hidden = self.wte(idx)
        for block in self.blocks:
            hidden = block(hidden)
        logits = self.lm_head(self.ln_f(hidden))

        loss = None
        if targets is not None:
            pred = logits[:, :-1, :].contiguous()
            gold = targets[:, 1:].contiguous()
            flat_pred = pred.view(-1, pred.size(-1))
            flat_gold = gold.view(-1)
            if loss_mask is None:
                loss = F.cross_entropy(flat_pred, flat_gold)
            else:
                weights = loss_mask[:, 1:].contiguous().view(-1).float()
                token_losses = F.cross_entropy(flat_pred, flat_gold, reduction='none')
                loss = (token_losses * weights).sum() / weights.sum().clamp(min=1)

        return (logits if return_logits else None), loss

    @property
    def num_params(self):
        """Parameter count excluding the token-embedding matrix.

        ``parameters()`` yields the tied wte/lm_head weight once, so subtracting
        it leaves every non-embedding parameter.
        """
        total = sum(p.numel() for p in self.parameters())
        return total - self.wte.weight.numel()
|
|
|
|
|
|
|
|
|
class SFTDataset(torch.utils.data.Dataset):
    """
    Loads JSON SFT data (instruction/input/output) and tokenizes with prompt masking.

    Format per entry: {"instruction": "...", "input": "...", "output": "..."}
    Missing or null (JSON ``null``) fields are treated as empty strings.

    Prompt template (Alpaca-style):
    ### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}<|endoftext|>

    Loss mask: 0 for prompt tokens, 1 for response tokens (including EOS).
    Every example is truncated/padded to exactly ``max_len`` tokens; padding
    uses the EOS id and is masked out.
    """

    def __init__(self, json_path, tokenizer, max_len=1024):
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.max_len = max_len
        eos = tokenizer.eos_token_id
        # Fall back to id 0 when the tokenizer defines no EOS token.
        self.eos_id = 0 if eos is None else eos

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _field(entry, key):
        """Return ``entry[key]`` as a stripped string; missing or null -> ''."""
        value = entry.get(key)
        return "" if value is None else str(value).strip()

    def _format_prompt(self, entry):
        """Build the prompt text for ``entry`` and return ``(prompt, output)``."""
        inst = self._field(entry, "instruction")
        inp = self._field(entry, "input")
        out = self._field(entry, "output")

        if inst and inp:
            prompt = f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n### Response:\n"
        elif inst:
            prompt = f"### Instruction:\n{inst}\n\n### Response:\n"
        else:
            prompt = f"### Input:\n{inp}\n\n### Response:\n"
        return prompt, out

    def __getitem__(self, idx):
        """Return ``(input_ids, loss_mask)``, both long tensors of length ``max_len``."""
        prompt, response = self._format_prompt(self.data[idx])

        prompt_ids = self.tokenizer.encode(prompt)
        response_ids = self.tokenizer.encode(response) + [self.eos_id]
        total_ids = prompt_ids + response_ids

        if len(total_ids) > self.max_len:
            total_ids = total_ids[:self.max_len]
            # A truncated example still ends with EOS.
            total_ids[-1] = self.eos_id
        # The prompt may itself be longer than max_len; then nothing is trained on.
        prompt_len = min(len(prompt_ids), self.max_len)

        loss_mask = [0] * prompt_len + [1] * (len(total_ids) - prompt_len)

        pad_len = self.max_len - len(total_ids)
        total_ids = total_ids + [self.eos_id] * pad_len
        loss_mask = loss_mask + [0] * pad_len

        input_ids = torch.tensor(total_ids, dtype=torch.long)
        loss_mask = torch.tensor(loss_mask, dtype=torch.long)
        return input_ids, loss_mask
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, device="cpu", eos_id=0):
    """Autoregressively sample up to ``max_new`` tokens from ``model``.

    Args:
        model: Module returning ``(logits, loss)`` and exposing ``block_size``;
            the context fed to it is cropped to the last ``block_size`` tokens.
        input_ids: Prompt token ids shaped (1, T). Only batch size 1 is
            supported (sampling reads row 0).
        max_new: Maximum number of new tokens.
        temperature: Logit divisor, floored at 1e-8.
        top_p: Nucleus threshold; values >= 1.0 disable nucleus filtering.
        top_k: Keep only the k highest logits; values <= 0 disable it.
        device: Device the prompt is copied to.
        eos_id: Token id that stops generation (default 0). The stopping token
            is included in the returned ids.

    Returns:
        1-D tensor of generated token ids (the prompt is excluded).

    The model's train/eval mode on entry is restored on exit.
    """
    was_training = model.training
    model.eval()
    ids = input_ids.clone().to(device)
    try:
        for _ in range(max_new):
            ctx = ids[:, -model.block_size:]
            logits, _ = model(ctx)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k > 0:
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
                logits[logits < kth] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            if top_p < 1.0:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cum = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass *before* it already exceeds top_p,
                # so the single most likely token is always kept.
                sorted_probs[cum - sorted_probs > top_p] = 0.0
                sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)
            ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
            if next_token.item() == eos_id:
                break
    finally:
        model.train(was_training)
    return ids[0, input_ids.size(1):]
|
|
|
|
|
def run_eval_prompts(model, tokenizer, prompts, device, step, out_dir):
    """Run eval prompts and print + log results.

    Each prompt is wrapped in the instruction-only response template and
    sampled with ``generate``. Responses are printed (first 200 chars) and
    saved (first 500 chars) to ``<out_dir>/evals/eval_step_{step:06d}.json``.
    The model is switched to eval mode for the run and to train mode at the end.

    Returns:
        List of ``{"prompt": ..., "response": ...}`` dicts.
    """
    model.eval()
    rule = "β" * 60

    print(f"\n{rule}")
    print(f" EVAL @ step {step}")
    print(rule)

    results = []
    for prompt_text in prompts:
        formatted = f"### Instruction:\n{prompt_text}\n\n### Response:\n"
        prompt_ids = tokenizer.encode(formatted, return_tensors="pt").to(device)
        generated = generate(model, prompt_ids, max_new=150, temperature=0.7, device=device)
        response = tokenizer.decode(generated.tolist(), skip_special_tokens=True).strip()

        print(f" Q: {prompt_text}")
        print(f" A: {response[:200]}")
        print()
        results.append({"prompt": prompt_text, "response": response[:500]})

    print(rule)

    eval_dir = Path(out_dir) / "evals"
    eval_dir.mkdir(parents=True, exist_ok=True)
    report_path = eval_dir / f"eval_step_{step:06d}.json"
    with open(report_path, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2, ensure_ascii=False)

    model.train()
    return results
|
|
|
|
|
|
|
|
|
def probe_hardware():
    """Describe the compute device and choose a training precision.

    Returns a dict with ``cpu_cores``, ``ram_gb``, ``device`` ("cuda"/"cpu"),
    ``gpu_name``, ``vram_gb``, ``sm_major``, ``precision`` and ``dtype``.
    Only CUDA device 0 is inspected. When its compute capability major is
    >= 8, TF32 is enabled for matmul and cuDNN (a global side effect) and
    bf16 is chosen; other GPUs get fp16; CPU-only hosts get fp32.
    """
    info = {
        "cpu_cores": os.cpu_count() or 4,
        "ram_gb": psutil.virtual_memory().total / 1024**3,
    }
    if not torch.cuda.is_available():
        info.update({
            "device": "cpu", "gpu_name": "CPU", "vram_gb": 0,
            "sm_major": 0, "precision": "fp32", "dtype": torch.float32,
        })
        return info

    props = torch.cuda.get_device_properties(0)
    has_sm80 = props.major >= 8
    if has_sm80:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    info.update({
        "device": "cuda",
        "gpu_name": props.name,
        "vram_gb": props.total_memory / 1024**3,
        "sm_major": props.major,
        "precision": "bf16" if has_sm80 else "fp16",
        "dtype": torch.bfloat16 if has_sm80 else torch.float16,
    })
    return info
|
|
|
|
|
def probe_max_batch(model, device, dtype, seq_len, vocab_size, grad_accum_sim=4,
                    safety=0.70, max_batch=512):
    """Binary-search the largest micro-batch that survives a full training step.

    Each candidate size runs ``grad_accum_sim`` forward/backward passes on
    random token ids plus one AdamW step, so optimizer state is allocated as
    in real training. A CUDA out-of-memory error shrinks the search range;
    any other RuntimeError propagates.

    The model's weights are snapshotted to CPU first and restored afterwards
    (also on error), because the probe's optimizer steps on random data would
    otherwise corrupt them. Parameter gradients are cleared on return.

    Args:
        model: Module called as ``model(x, x, loss_mask=..., return_logits=False)``
            that returns ``(logits, loss)``.
        device: Device for the random probe batches.
        dtype: Autocast dtype for the (CUDA) forward pass.
        seq_len: Sequence length of each probe sample.
        vocab_size: Exclusive upper bound for random token ids.
        grad_accum_sim: Micro-steps accumulated before each optimizer step.
        safety: Fraction of the largest working size that is returned.
        max_batch: Upper bound of the search.

    Returns:
        ``max(1, int(best * safety))`` where ``best`` is the largest size that
        succeeded (1 if none did).
    """
    snapshot = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
    tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lo, hi, best = 1, max_batch, 1
    try:
        while lo <= hi:
            mid = (lo + hi) // 2
            x = mask = loss = None
            oom = False
            torch.cuda.empty_cache(); gc.collect()
            tmp_opt.zero_grad(set_to_none=True)
            try:
                for _ in range(grad_accum_sim):
                    x = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                    mask = torch.ones_like(x)
                    with autocast(device_type="cuda", dtype=dtype):
                        _, loss = model(x, x, loss_mask=mask, return_logits=False)
                    (loss / grad_accum_sim).backward()
                    x = mask = loss = None
                tmp_opt.step()
            except (torch.cuda.OutOfMemoryError, RuntimeError) as err:
                is_oom = (isinstance(err, torch.cuda.OutOfMemoryError)
                          or "out of memory" in str(err).lower())
                if not is_oom:
                    raise
                oom = True
            # The exception (and its traceback frames) is released once the
            # except clause ends; drop tensor refs too, then free cached memory.
            x = mask = loss = None
            tmp_opt.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            if oom:
                hi = mid - 1
            else:
                best, lo = mid, mid + 1
    finally:
        tmp_opt.zero_grad(set_to_none=True)
        del tmp_opt
        model.load_state_dict(snapshot)
        del snapshot
        torch.cuda.empty_cache(); gc.collect()

    safe = max(1, int(best * safety))
    print(f" Probe: max_batch={best}, using {safe} ({safety:.0%} safety)")
    return safe
|
|
|
|
|
|
|
|
|
def cosine_lr(step, warmup, total, lr_max, lr_min):
    """Learning rate for ``step`` under linear warmup followed by cosine decay.

    * ``step < warmup``: linear ramp reaching ``lr_max`` at ``step == warmup - 1``.
    * ``warmup <= step <= total``: cosine decay from ``lr_max`` to ``lr_min``.
    * ``step > total``: held at ``lr_min``. Without the clamp the cosine
      would wrap around and the rate would climb back towards ``lr_max``.
    """
    if step < warmup:
        return lr_max * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return lr_min + 0.5 * (1 + math.cos(math.pi * progress)) * (lr_max - lr_min)
|
|
|
|
|
|
|
|
|
def load_sft_config(config_path):
    """Read the SFT YAML config and flatten it into one dict.

    Top-level keys with defaults: auto_config, hf_model_repo, hf_model_file,
    hf_dataset_repo, pretrained_ckpt, train_json, val_json, out_dir,
    tokenizer_dir and eval_prompts. The ``model``, ``train``, ``optimizer``,
    ``batch``, ``dataloader`` and ``hardware`` sections are required, and a
    missing required key raises ``KeyError``. ``train.max_tokens`` defaults
    to 0 and ``optimizer.betas`` is converted to a tuple.
    """
    with open(config_path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)

    top_level_defaults = (
        ("auto_config", True),
        ("hf_model_repo", "ASTERIZER/LUNA-100M"),
        ("hf_model_file", "latest.pt"),
        ("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
        ("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
        ("train_json", "Base/Datasets/sft_clean/train.json"),
        ("val_json", "Base/Datasets/sft_clean/val.json"),
        ("out_dir", "Base/out/sft/luna_100m_sft"),
        ("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
    )
    cfg = {key: raw.get(key, default) for key, default in top_level_defaults}

    model_sec = raw["model"]
    for key in ("vocab_size", "seq_len", "n_layer", "n_embd", "n_head"):
        cfg[key] = model_sec[key]

    train_sec = raw["train"]
    cfg["epochs"] = train_sec["epochs"]
    cfg["max_tokens"] = train_sec.get("max_tokens", 0)
    for key in ("lr_warmup_steps", "save_interval", "log_interval",
                "eval_interval", "max_norm"):
        cfg[key] = train_sec[key]

    optim_sec = raw["optimizer"]
    cfg["lr"] = optim_sec["lr"]
    cfg["min_lr"] = optim_sec["min_lr"]
    cfg["weight_decay"] = optim_sec["weight_decay"]
    cfg["betas"] = tuple(optim_sec["betas"])
    cfg["eps"] = optim_sec["eps"]

    batch_sec = raw["batch"]
    for key in ("global_batch", "micro_batch", "grad_accum"):
        cfg[key] = batch_sec[key]

    loader_sec = raw["dataloader"]
    for key in ("num_workers", "pin_memory"):
        cfg[key] = loader_sec[key]

    cfg["precision"] = raw["hardware"]["precision"]
    cfg["eval_prompts"] = raw.get("eval_prompts", [])
    return cfg
|
|
|
|
|
|
|
|
|
# Horizontal rule used to frame the console banner and summary.
SEP = 72 * "="
|
|
|
def _unwrap_model(model):
    """Return the underlying module if ``model`` was wrapped by torch.compile."""
    return model._orig_mod if hasattr(model, "_orig_mod") else model


def _ensure_pretrained_ckpt(cfg):
    """Return the pretrained checkpoint path, downloading it from the HF Hub if missing."""
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if ckpt_path.exists() or not cfg.get("hf_model_repo"):
        return ckpt_path

    print("\n Pretrained checkpoint not found locally.")
    print(f" Downloading from HuggingFace: {cfg['hf_model_repo']} ({cfg['hf_model_file']})")
    from huggingface_hub import hf_hub_download
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    downloaded = Path(hf_hub_download(
        repo_id=cfg["hf_model_repo"],
        filename=cfg["hf_model_file"],
        local_dir=str(ckpt_path.parent),
        token=os.environ.get("HF_TOKEN"),
    ))
    # The repo file name may differ from the configured local name.
    if not ckpt_path.exists() and downloaded.exists():
        ckpt_path = downloaded
    print(f" Downloaded to: {ckpt_path}")
    return ckpt_path


def _load_pretrained(model, ckpt_path):
    """Load pretrained weights into ``model``; return the (CPU) state dict, or None if absent."""
    if not ckpt_path.exists():
        print(f"\n WARNING: No pretrained checkpoint at {ckpt_path}")
        print(" Training from scratch (not recommended for SFT)!")
        return None

    print(f"\n Loading pretrained checkpoint: {ckpt_path}")
    # Map to CPU: load_state_dict copies into the on-device parameters, so a
    # second GPU copy of the weights would only waste memory.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    state = ckpt["model"] if "model" in ckpt else ckpt
    model.load_state_dict(state, strict=True)
    pretrain_step = ckpt.get("step", "?")
    pretrain_tokens = ckpt.get("tokens_seen", 0)
    print(f" Pretrained @ step {pretrain_step}, tokens seen: {pretrain_tokens:,}")
    return state


def _ensure_sft_data(cfg):
    """Download train.json (and val.json when configured) from the HF dataset repo if train data is missing."""
    train_path = Path(cfg["train_json"])
    val_path = Path(cfg["val_json"]) if cfg["val_json"] else None
    if train_path.exists() or not cfg.get("hf_dataset_repo"):
        return

    print("\n SFT dataset not found locally.")
    print(f" Downloading from HuggingFace: {cfg['hf_dataset_repo']}")
    from huggingface_hub import hf_hub_download
    wanted = [("train.json", train_path)]
    if val_path:
        wanted.append(("val.json", val_path))
    for repo_file, local_path in wanted:
        local_path.parent.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id=cfg["hf_dataset_repo"],
            repo_type="dataset",
            filename=repo_file,
            local_dir=str(local_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded {repo_file}")


@torch.no_grad()
def _validation_loss(model, val_loader, device, dtype, max_batches=50):
    """Mean masked loss over at most ``max_batches`` validation batches; leaves the model in train mode."""
    model.eval()
    loss_sum, count = 0.0, 0
    for val_ids, val_mask in val_loader:
        val_ids = val_ids.to(device, non_blocking=True)
        val_mask = val_mask.to(device, non_blocking=True)
        with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
            _, vl = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
        loss_sum += vl.item()
        count += 1
        if count >= max_batches:
            break
    model.train()
    return loss_sum / max(count, 1)


def sft_train(cfg):
    """Supervised fine-tuning of LUNA-100M driven by the flat ``cfg`` dict.

    Probes the hardware and precision, then loads the tokenizer and builds the
    model. It loads (or downloads) the pretrained checkpoint and the JSON SFT
    data. The micro-batch is sized by a CUDA probe when ``auto_config`` is set,
    otherwise taken from ``cfg``. Training uses gradient accumulation, cosine
    LR with warmup, periodic checkpoints, validation loss and prompt evals.
    Training resumes from ``<out_dir>/latest.pt`` when that file exists.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    if device.type == "cuda":
        torch.cuda.empty_cache(); gc.collect()

    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
        cfg["precision"] = hw["precision"]
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16,
                 "fp32": torch.float32}.get(cfg["precision"], torch.float32)

    print(SEP)
    print(" LUNA 100M - SFT Fine-Tuning")
    print(SEP)
    print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" Pretrained : {cfg['pretrained_ckpt']}")

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    print(f" Tokenizer : {cfg['tokenizer_dir']} (vocab={tokenizer.vocab_size})")

    print("\n Building LUNA-100M...")
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    print(f" Parameters: {model.num_params:,} (unique)")

    pretrained_state = _load_pretrained(model, _ensure_pretrained_ckpt(cfg))

    _ensure_sft_data(cfg)
    print(f"\n Train data: {cfg['train_json']}")
    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    print(f" Train entries: {len(train_dataset):,}")

    val_dataset = None
    if cfg["val_json"] and Path(cfg["val_json"]).exists():
        val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"])
        print(f" Val entries: {len(val_dataset):,}")

    if cfg["auto_config"] and device.type == "cuda":
        print("\n Probing max micro_batch_size...")
        # The probe takes optimizer steps on random data, so restore the weights
        # afterwards. Without a pretrained checkpoint, snapshot the current
        # weights so there is always something to restore.
        restore_state = pretrained_state
        if restore_state is None:
            restore_state = {k: v.detach().to("cpu", copy=True)
                             for k, v in model.state_dict().items()}
        max_mbs = probe_max_batch(model, device, dtype, cfg["seq_len"], cfg["vocab_size"])
        model.load_state_dict(restore_state, strict=True)
        del restore_state
        torch.cuda.empty_cache(); gc.collect()
        grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs))
    else:
        max_mbs = cfg["micro_batch"]
        grad_accum = cfg["grad_accum"]
    effective_batch = max_mbs * grad_accum
    del pretrained_state

    print(f" micro_batch={max_mbs}, grad_accum={grad_accum}, effective={effective_batch}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=max_mbs,
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )

    val_loader = None
    if val_dataset:
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max_mbs, shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"], drop_last=False,
        )

    try:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"], fused=True,
        )
    except (TypeError, RuntimeError):
        # TypeError: torch without the `fused` kwarg. RuntimeError: fused
        # kernels unsupported for these params (e.g. CPU on older torch).
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"],
        )

    use_scaler = dtype == torch.float16
    scaler = GradScaler(enabled=use_scaler)

    steps_per_epoch = len(train_loader) // grad_accum
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = min(cfg["lr_warmup_steps"], total_steps // 5)
    # Only micro-batches forming complete accumulation windows are used; a
    # remainder at the end of an epoch would otherwise leave gradients behind
    # that leak into the next epoch's first optimizer step.
    micro_per_epoch = steps_per_epoch * grad_accum

    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n Epochs : {cfg['epochs']}")
    print(f" Steps/epoch : {steps_per_epoch:,}")
    print(f" Total steps : {total_steps:,}")
    print(f" Warmup steps : {warmup_steps}")
    print(f" LR : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}")
    print(f" Save every : {cfg['save_interval']} steps")
    print(f" Eval every : {cfg['eval_interval']} steps")
    print(f" Eval prompts : {len(cfg['eval_prompts'])}")
    print(f" Out dir : {out_dir}")
    print(SEP)

    start_step = 0
    sft_ckpt_path = out_dir / "latest.pt"
    if sft_ckpt_path.exists():
        print(f"\n Resuming SFT from {sft_ckpt_path}...")
        sft_ckpt = torch.load(sft_ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(sft_ckpt["model"])
        # The end-of-training latest.pt carries no optimizer state.
        if "optimizer" in sft_ckpt:
            optimizer.load_state_dict(sft_ckpt["optimizer"])
        start_step = sft_ckpt["step"]
        print(f" Resumed at SFT step {start_step}")
        del sft_ckpt

    if cfg["eval_prompts"] and start_step == 0:
        print("\n Running initial eval (before SFT)...")
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)

    model.train()
    run_t0 = time.perf_counter()
    step = start_step
    best_val_loss = float("inf")
    t0 = run_t0
    window_loss = None

    print(f"\n Starting SFT training (step {start_step} -> {total_steps})...")

    for epoch in range(cfg["epochs"]):
        epoch_first_step = epoch * steps_per_epoch
        if epoch_first_step + steps_per_epoch <= start_step:
            continue  # epoch fully completed before the resume point
        if step >= total_steps:
            break

        for micro_step, (input_ids, loss_mask) in enumerate(train_loader):
            if micro_step >= micro_per_epoch:
                break  # incomplete accumulation window at epoch end
            global_step = epoch_first_step + micro_step // grad_accum
            if global_step < start_step:
                continue  # already trained before resuming: no forward/backward
            if global_step >= total_steps:
                break

            if micro_step % grad_accum == 0:
                # Start of an accumulation window: time and average the whole window.
                t0 = time.perf_counter()
                window_loss = torch.zeros((), device=device)

            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum

            scaler.scale(loss).backward()
            window_loss += loss.detach()

            if (micro_step + 1) % grad_accum != 0:
                continue

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"])

            lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
            for pg in optimizer.param_groups:
                pg["lr"] = lr_now

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            if device.type == "cuda":
                torch.cuda.synchronize()

            dt = time.perf_counter() - t0
            step += 1
            # Sum of (loss_i / grad_accum) over the window = mean micro-batch loss.
            step_loss = window_loss.item()

            if step % cfg["log_interval"] == 0 or step <= 3:
                tokens_step = effective_batch * cfg["seq_len"]
                tps = tokens_step / dt if dt > 0 else 0
                vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                eta_h = (total_steps - step) * dt / 3600
                print(f" step {step:6d}/{total_steps} | epoch {epoch+1}/{cfg['epochs']} | "
                      f"loss {step_loss:.4f} | lr {lr_now:.2e} | "
                      f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h")

            if step % cfg["save_interval"] == 0 or step == total_steps:
                raw_model = _unwrap_model(model)
                step_dir = out_dir / f"step-{step:06d}"
                step_dir.mkdir(parents=True, exist_ok=True)
                torch.save(raw_model.state_dict(), step_dir / "model.pth")
                torch.save({
                    "step": step,
                    "model": raw_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "sft_loss": step_loss,
                }, out_dir / "latest.pt")
                print(f" Saved -> {step_dir}")

            if step % cfg["eval_interval"] == 0 or step == total_steps:
                if val_loader:
                    avg_val = _validation_loss(model, val_loader, device, dtype)
                    print(f" Val loss: {avg_val:.4f}")
                    if avg_val < best_val_loss:
                        best_val_loss = avg_val
                        torch.save(_unwrap_model(model).state_dict(), out_dir / "best_model.pth")
                        print(" New best! Saved best_model.pth")

                if cfg["eval_prompts"]:
                    run_eval_prompts(model, tokenizer, cfg["eval_prompts"],
                                     device, step, out_dir)

    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    raw_model = _unwrap_model(model)
    torch.save(raw_model.state_dict(), final_dir / "model.pth")
    torch.save({
        "step": step,
        "model": raw_model.state_dict(),
        "sft_complete": True,
    }, out_dir / "latest.pt")

    import shutil
    tok_src = Path(cfg["tokenizer_dir"])
    if tok_src.exists():
        shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True)

    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" SFT Complete! {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line overrides for the SFT run.

    Every option except ``--config`` defaults to None, meaning "keep the value
    from the YAML config". ``--auto_config`` accepts 1/true/yes
    (case-insensitive) as True and anything else as False.
    """
    parser = argparse.ArgumentParser(description="LUNA 100M β SFT Fine-Tuning")
    parser.add_argument("--config", default="sft_config.yaml")
    for flag in ("--pretrained_ckpt", "--train_json", "--val_json", "--out_dir"):
        parser.add_argument(flag, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    for flag in ("--micro_batch", "--global_batch", "--save_interval", "--eval_interval"):
        parser.add_argument(flag, type=int, default=None)
    parser.add_argument(
        "--auto_config",
        type=lambda text: text.lower() in ("1", "true", "yes"),
        default=None,
    )
    return parser.parse_args()
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    cfg = load_sft_config(args.config)
    # CLI flags override YAML values; unset flags (None) and --config are ignored.
    cfg.update({
        key: val for key, val in vars(args).items()
        if key != "config" and val is not None
    })
    sft_train(cfg)
|
|
|