# LUNA-Training / sft_train.py
# Uploaded with huggingface_hub by ASTERIZER — Hub revision 0122e75 (verified)
"""
LUNA 100M — SFT Fine-Tuning Script
====================================
Fine-tunes the pretrained LUNA-100M on instruction-following (SFT) data.
Features:
- Loads pretrained checkpoint (latest.pt from pretraining)
- JSON-based SFT dataset (instruction/input/output format)
- Prompt masking: loss computed only on the output portion
- Checkpoint eval: runs identity + knowledge prompts after each save
- Cosine LR with warmup
- Auto hardware detection (same as train.py)
Usage:
python sft_train.py # uses sft_config.yaml
python sft_train.py --config sft_config.yaml # explicit config
python sft_train.py --train_json /data/train.json # override data path
"""
import os
import gc
import sys
import math
import time
import json
import argparse
import yaml
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from pathlib import Path
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# ─── Model (identical to train.py) ───────────────────────────────────────────
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) tables.

    Caches cos/sin tables for positions 0..max_seq_len-1 as buffers so
    they follow the module across device moves.
    """
    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE spectrum: 10000^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        # Duplicate the angle table so it spans the full rotary dimension.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first seq_len positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x):
    """Rotate the last dim by 90 degrees in pairs: (a, b) -> (-b, a)."""
    half = x.size(-1) // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary(x, cos, sin):
    """Apply rotary position rotation to x using per-position tables.

    cos/sin have shape (T, rot_dim); they broadcast over the leading
    (batch, head) dims of x.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    return x * cos_b + rotate_half(x) * sin_b
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first rotary_pct fraction of each head's channels is rotated;
    the remainder passes through unrotated (GPT-NeoX-style partial RoPE).
    """
    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # Fused QKV projection, reshaped to (3, B, n_head, T, head_dim).
        qkv = self.c_attn(x)
        qkv = qkv.reshape(B, T, 3, self.n_head, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        cos, sin = self.rotary(T)
        # Rotate only the leading rot_dim channels of each head.
        q_rot = apply_rotary(q[..., :self.rot_dim], cos, sin)
        k_rot = apply_rotary(k[..., :self.rot_dim], cos, sin)
        q = torch.cat((q_rot, q[..., self.rot_dim:]), dim=-1)
        k = torch.cat((k_rot, k[..., self.rot_dim:]), dim=-1)
        # Fused attention kernel with the causal mask handled internally.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(attn_out)
class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, project back."""
    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        hidden = self.fc(x)
        hidden = self.gelu(hidden)
        return self.proj(hidden)
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each residual."""
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied embeddings.

    forward() optionally computes the shifted next-token cross-entropy loss;
    when ``loss_mask`` is given, only positions with mask == 1 contribute
    (used by SFT to restrict the loss to response tokens).
    """
    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # tied input/output embeddings
        # FIX: _init_weights was defined but never invoked, so from-scratch
        # runs silently used PyTorch's default init. Apply the intended
        # N(0, 0.02) init here; checkpoint loading overwrites these values,
        # so pretrained runs are unaffected.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()

    def forward(self, idx, targets=None, loss_mask=None, return_logits=True):
        """Return (logits, loss); loss is None when targets are not given.

        Args:
            idx: (B, T) token ids.
            targets: optional (B, T) ids for next-token loss (usually == idx).
            loss_mask: optional (B, T) 0/1 mask; loss only where mask == 1.
            return_logits: set False to skip returning logits during training.
        """
        x = self.wte(idx)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Standard causal shift: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_targets = targets[:, 1:].contiguous()
            if loss_mask is not None:
                shift_mask = loss_mask[:, 1:].contiguous()
                # Only compute loss on output tokens
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                flat_targets = shift_targets.view(-1)
                flat_mask = shift_mask.view(-1).float()
                per_token_loss = F.cross_entropy(flat_logits, flat_targets, reduction='none')
                # clamp(min=1) guards the all-zero-mask case (loss becomes 0).
                loss = (per_token_loss * flat_mask).sum() / flat_mask.sum().clamp(min=1)
            else:
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_targets.view(-1)
                )
        if not return_logits:
            logits = None
        return logits, loss

    @property
    def num_params(self):
        # Unique parameter count: the tied lm_head/wte weight is counted once.
        return sum(p.numel() for p in self.parameters()) - self.wte.weight.numel()
# ─── SFT Dataset ──────────────────────────────────────────────────────────────
class SFTDataset(torch.utils.data.Dataset):
    """
    JSON-backed instruction-tuning dataset with prompt masking.

    Each entry is {"instruction": ..., "input": ..., "output": ...} and is
    rendered with an Alpaca-style template:
        ### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}<|endoftext|>
    The loss mask is 0 over prompt and padding tokens and 1 over the
    response tokens (EOS included).
    """
    def __init__(self, json_path, tokenizer, max_len=1024):
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.eos_id = tokenizer.eos_token_id or 0

    def __len__(self):
        return len(self.data)

    def _format_prompt(self, entry):
        inst = entry.get("instruction", "").strip()
        inp = entry.get("input", "").strip()
        out = entry.get("output", "").strip()
        if not inst:
            # input-only format
            header = f"### Input:\n{inp}\n\n### Response:\n"
        elif inp:
            header = f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n### Response:\n"
        else:
            header = f"### Instruction:\n{inst}\n\n### Response:\n"
        return header, out

    def __getitem__(self, idx):
        prompt, response = self._format_prompt(self.data[idx])
        prompt_ids = self.tokenizer.encode(prompt)
        token_ids = prompt_ids + self.tokenizer.encode(response) + [self.eos_id]
        if len(token_ids) > self.max_len:
            # Truncate, force an EOS at the end, clip the prompt boundary.
            token_ids = token_ids[:self.max_len]
            token_ids[-1] = self.eos_id
            prompt_len = min(len(prompt_ids), self.max_len)
        else:
            prompt_len = len(prompt_ids)
        # 0 over the prompt, 1 over the response portion.
        mask = [0] * prompt_len + [1] * (len(token_ids) - prompt_len)
        # Right-pad with EOS; padded positions carry no loss.
        pad = self.max_len - len(token_ids)
        token_ids = token_ids + [self.eos_id] * pad
        mask = mask + [0] * pad
        return (torch.tensor(token_ids, dtype=torch.long),
                torch.tensor(mask, dtype=torch.long))
# ─── Generation (for eval) ───────────────────────────────────────────────────
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, device="cpu"):
    # Autoregressive sampling with temperature, top-k and nucleus (top-p)
    # filtering. Returns only the newly generated token ids (prompt stripped).
    # NOTE(review): batch size 1 is assumed — the top-p branch indexes [0].
    model.eval()
    ids = input_ids.clone().to(device)
    for _ in range(max_new):
        # Clip the running context to the model's maximum block size.
        ctx = ids[:, -model.block_size:]
        logits, _ = model(ctx)
        # Last-position logits; temperature floor avoids division by zero.
        logits = logits[:, -1, :] / max(temperature, 1e-8)
        if top_k > 0:
            # Mask everything below the k-th largest logit.
            vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < vals[:, -1:]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        if top_p < 1.0:
            # Nucleus sampling: keep the smallest prefix of sorted tokens
            # whose cumulative probability reaches top_p, then renormalize.
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            mask = cum - sorted_probs > top_p
            sorted_probs[mask] = 0.0
            sorted_probs /= sorted_probs.sum()
            next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
        else:
            next_token = torch.multinomial(probs[0], 1)
        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        if next_token.item() == 0:  # EOS
            # NOTE(review): EOS is hard-coded as token id 0 — presumably
            # matching the Pythia tokenizer; verify for other tokenizers.
            break
    model.train()
    return ids[0, input_ids.size(1):]
def run_eval_prompts(model, tokenizer, prompts, device, step, out_dir):
    """Run eval prompts and print + log results."""
    model.eval()
    sep = "─" * 60
    print(f"\n{sep}")
    print(f" EVAL @ step {step}")
    print(sep)
    results = []
    for question in prompts:
        # Wrap the raw question in the SFT instruction template.
        formatted = f"### Instruction:\n{question}\n\n### Response:\n"
        ids = tokenizer.encode(formatted, return_tensors="pt").to(device)
        out_ids = generate(model, ids, max_new=150, temperature=0.7, device=device)
        answer = tokenizer.decode(out_ids.tolist(), skip_special_tokens=True).strip()
        print(f" Q: {question}")
        print(f" A: {answer[:200]}")
        print()
        results.append({"prompt": question, "response": answer[:500]})
    print(sep)
    # Persist the eval transcript next to the checkpoints.
    eval_dir = Path(out_dir) / "evals"
    eval_dir.mkdir(parents=True, exist_ok=True)
    log_path = eval_dir / f"eval_step_{step:06d}.json"
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    model.train()
    return results
# ─── Hardware Detection (same as train.py) ────────────────────────────────────
def probe_hardware():
    """Detect CPU/GPU resources and pick a training precision.

    Returns a dict with device, memory sizes, and the autocast dtype:
    bf16 on Ampere+ (SM >= 8), fp16 on older GPUs, fp32 on CPU.
    """
    hw = {
        "cpu_cores": os.cpu_count() or 4,
        "ram_gb": psutil.virtual_memory().total / 1024**3,
    }
    if not torch.cuda.is_available():
        hw.update({
            "device": "cpu", "gpu_name": "CPU", "vram_gb": 0,
            "sm_major": 0, "precision": "fp32", "dtype": torch.float32,
        })
        return hw
    props = torch.cuda.get_device_properties(0)
    hw["device"] = "cuda"
    hw["gpu_name"] = props.name
    hw["vram_gb"] = props.total_memory / 1024**3
    hw["sm_major"] = props.major
    if props.major >= 8:
        # Ampere or newer: enable TF32 matmuls and prefer bf16 autocast.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        hw["precision"], hw["dtype"] = "bf16", torch.bfloat16
    else:
        hw["precision"], hw["dtype"] = "fp16", torch.float16
    return hw
def probe_max_batch(model, device, dtype, seq_len, vocab_size, grad_accum_sim=4):
    """Binary-search the largest micro-batch that survives fwd+bwd+step.

    Simulates ``grad_accum_sim`` accumulation micro-steps plus one optimizer
    step at each candidate size, then returns 70% of the largest size that
    did not OOM (safety margin against fragmentation during real training).
    Requires a CUDA device.
    """
    tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lo, hi, best = 1, 512, 1
    while lo <= hi:
        mid = (lo + hi) // 2
        try:
            torch.cuda.empty_cache(); gc.collect()
            tmp_opt.zero_grad(set_to_none=True)
            for _ in range(grad_accum_sim):
                x = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                mask = torch.ones_like(x)
                with autocast(device_type="cuda", dtype=dtype):
                    _, loss = model(x, x, loss_mask=mask, return_logits=False)
                    loss = loss / grad_accum_sim
                loss.backward()
                del x, mask, loss
            tmp_opt.step()
            tmp_opt.zero_grad(set_to_none=True)
            best = mid; lo = mid + 1
            torch.cuda.empty_cache()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            if "out of memory" in str(e).lower() or isinstance(e, torch.cuda.OutOfMemoryError):
                # FIX: was a bare `except: pass`. Only NameError (incl.
                # UnboundLocalError, when the OOM hit before a tensor was
                # bound) is expected here; anything else should surface.
                try:
                    del x, mask, loss
                except NameError:
                    pass
                torch.cuda.empty_cache()
                tmp_opt.zero_grad(set_to_none=True)
                hi = mid - 1
            else:
                raise
    del tmp_opt; torch.cuda.empty_cache(); gc.collect()
    safe = max(1, int(best * 0.70))
    print(f" Probe: max_batch={best}, using {safe} (70% safety)")
    return safe
# ─── LR Schedule ──────────────────────────────────────────────────────────────
def cosine_lr(step, warmup, total, lr_max, lr_min):
    """Linear warmup to lr_max, then cosine decay down to lr_min."""
    if step < warmup:
        # Ramp from lr_max/warmup (step 0) up to lr_max (step warmup-1).
        return lr_max * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + cosine * (lr_max - lr_min)
# ─── Config ───────────────────────────────────────────────────────────────────
def load_sft_config(config_path):
    """Load the SFT YAML config and flatten it into a single dict.

    Optional top-level keys fall back to defaults; the model/train/
    optimizer/batch/dataloader/hardware sections are required and raise
    KeyError when absent.
    """
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    model_sec = raw["model"]
    train_sec = raw["train"]
    opt_sec = raw["optimizer"]
    batch_sec = raw["batch"]
    loader_sec = raw["dataloader"]
    cfg = {
        "auto_config": raw.get("auto_config", True),
        "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
        "hf_model_file": raw.get("hf_model_file", "latest.pt"),
        "hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
        "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
        "train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
        "val_json": raw.get("val_json", "Base/Datasets/sft_clean/val.json"),
        "out_dir": raw.get("out_dir", "Base/out/sft/luna_100m_sft"),
        "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
        # model architecture
        "vocab_size": model_sec["vocab_size"],
        "seq_len": model_sec["seq_len"],
        "n_layer": model_sec["n_layer"],
        "n_embd": model_sec["n_embd"],
        "n_head": model_sec["n_head"],
        # training schedule
        "epochs": train_sec["epochs"],
        "max_tokens": train_sec.get("max_tokens", 0),
        "lr_warmup_steps": train_sec["lr_warmup_steps"],
        "save_interval": train_sec["save_interval"],
        "log_interval": train_sec["log_interval"],
        "eval_interval": train_sec["eval_interval"],
        "max_norm": train_sec["max_norm"],
        # optimizer
        "lr": opt_sec["lr"],
        "min_lr": opt_sec["min_lr"],
        "weight_decay": opt_sec["weight_decay"],
        "betas": tuple(opt_sec["betas"]),
        "eps": opt_sec["eps"],
        # batch sizing
        "global_batch": batch_sec["global_batch"],
        "micro_batch": batch_sec["micro_batch"],
        "grad_accum": batch_sec["grad_accum"],
        # dataloader
        "num_workers": loader_sec["num_workers"],
        "pin_memory": loader_sec["pin_memory"],
        # hardware
        "precision": raw["hardware"]["precision"],
        # prompts shown at each checkpoint eval
        "eval_prompts": raw.get("eval_prompts", []),
    }
    return cfg
# ─── Training ─────────────────────────────────────────────────────────────────
SEP = "=" * 72  # visual separator line for console banners
def sft_train(cfg):
    """Run the full SFT fine-tuning loop described by the flat config dict.

    Stages: hardware/precision probe -> tokenizer + model build ->
    pretrained checkpoint load (with HF auto-download) -> dataset load
    (with HF auto-download) -> batch-size probe -> optimizer/schedule setup
    -> resume-aware training loop with periodic save/eval -> final export.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    if device.type == "cuda":
        torch.cuda.empty_cache(); gc.collect()
    # Precision: auto mode follows the hardware probe; manual mode maps the
    # configured string, defaulting to fp32 on unknown values.
    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
        cfg["precision"] = hw["precision"]
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16,
                 "fp32": torch.float32}.get(cfg["precision"], torch.float32)
    print(SEP)
    print(" LUNA 100M - SFT Fine-Tuning")
    print(SEP)
    print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" Pretrained : {cfg['pretrained_ckpt']}")
    # ── Tokenizer ─────────────────────────────────────────────────────────
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    print(f" Tokenizer : {cfg['tokenizer_dir']} (vocab={tokenizer.vocab_size})")
    # ── Model ─────────────────────────────────────────────────────────────
    print(f"\n Building LUNA-100M...")
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    print(f" Parameters: {model.num_params:,} (unique)")
    # ── Load pretrained weights (auto-download from the Hub if missing) ───
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if not ckpt_path.exists() and cfg.get("hf_model_repo"):
        # Auto-download from HuggingFace model repo
        print(f"\n Pretrained checkpoint not found locally.")
        print(f" Downloading from HuggingFace: {cfg['hf_model_repo']} ({cfg['hf_model_file']})")
        from huggingface_hub import hf_hub_download
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        downloaded = hf_hub_download(
            repo_id=cfg["hf_model_repo"],
            filename=cfg["hf_model_file"],
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        # hf_hub_download may place the file under a different name/path.
        downloaded_path = Path(downloaded)
        if not ckpt_path.exists() and downloaded_path.exists():
            ckpt_path = downloaded_path
        print(f" Downloaded to: {ckpt_path}")
    if ckpt_path.exists():
        print(f"\n Loading pretrained checkpoint: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        # Accept either a raw state_dict or a wrapped {"model": ...} dict.
        state = ckpt["model"] if "model" in ckpt else ckpt
        model.load_state_dict(state, strict=True)
        pretrain_step = ckpt.get("step", "?")
        pretrain_tokens = ckpt.get("tokens_seen", 0)
        print(f" Pretrained @ step {pretrain_step}, tokens seen: {pretrain_tokens:,}")
        # Do NOT load optimizer state — we start fresh for SFT
    else:
        print(f"\n WARNING: No pretrained checkpoint at {ckpt_path}")
        print(f" Training from scratch (not recommended for SFT)!")
    # ── Dataset (auto-download from HF if missing) ────────────────────────
    train_path = Path(cfg["train_json"])
    val_path = Path(cfg["val_json"]) if cfg["val_json"] else None
    if not train_path.exists() and cfg.get("hf_dataset_repo"):
        print(f"\n SFT dataset not found locally.")
        print(f" Downloading from HuggingFace: {cfg['hf_dataset_repo']}")
        from huggingface_hub import hf_hub_download
        train_path.parent.mkdir(parents=True, exist_ok=True)
        hf_hub_download(
            repo_id=cfg["hf_dataset_repo"],
            repo_type="dataset",
            filename="train.json",
            local_dir=str(train_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded train.json")
        if val_path:
            hf_hub_download(
                repo_id=cfg["hf_dataset_repo"],
                repo_type="dataset",
                filename="val.json",
                local_dir=str(val_path.parent),
                token=os.environ.get("HF_TOKEN"),
            )
            print(f" Downloaded val.json")
    print(f"\n Train data: {cfg['train_json']}")
    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    print(f" Train entries: {len(train_dataset):,}")
    val_dataset = None
    if cfg["val_json"] and Path(cfg["val_json"]).exists():
        val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"])
        print(f" Val entries: {len(val_dataset):,}")
    # ── Batch sizing ──────────────────────────────────────────────────────
    if cfg["auto_config"] and device.type == "cuda":
        print(f"\n Probing max micro_batch_size...")
        max_mbs = probe_max_batch(model, device, dtype, cfg["seq_len"], cfg["vocab_size"])
        # NOTE(review): `state` is only bound when a pretrained checkpoint
        # was loaded above — in the from-scratch path this line raises
        # NameError. Confirm whether from-scratch + auto_config is supported.
        model.load_state_dict(state, strict=True)  # reinit after probe
        torch.cuda.empty_cache(); gc.collect()
        grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs))
        effective_batch = max_mbs * grad_accum
    else:
        max_mbs = cfg["micro_batch"]
        grad_accum = cfg["grad_accum"]
        effective_batch = max_mbs * grad_accum
    print(f" micro_batch={max_mbs}, grad_accum={grad_accum}, effective={effective_batch}")
    # ── DataLoader ────────────────────────────────────────────────────────
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=max_mbs,
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    val_loader = None
    if val_dataset:
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max_mbs, shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"], drop_last=False,
        )
    # ── Optimizer (fused AdamW when the installed torch supports it) ──────
    try:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"], fused=True,
        )
    except TypeError:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"],
        )
    # GradScaler only matters for fp16; bf16/fp32 run with it disabled.
    use_scaler = dtype == torch.float16
    scaler = GradScaler(enabled=use_scaler)
    # ── Schedule ──────────────────────────────────────────────────────────
    steps_per_epoch = len(train_loader) // grad_accum
    total_steps = steps_per_epoch * cfg["epochs"]
    # Warmup is capped at 20% of the run for short fine-tunes.
    warmup_steps = min(cfg["lr_warmup_steps"], total_steps // 5)
    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n Epochs : {cfg['epochs']}")
    print(f" Steps/epoch : {steps_per_epoch:,}")
    print(f" Total steps : {total_steps:,}")
    print(f" Warmup steps : {warmup_steps}")
    print(f" LR : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}")
    print(f" Save every : {cfg['save_interval']} steps")
    print(f" Eval every : {cfg['eval_interval']} steps")
    print(f" Eval prompts : {len(cfg['eval_prompts'])}")
    print(f" Out dir : {out_dir}")
    print(SEP)
    # ── Resume SFT ────────────────────────────────────────────────────────
    start_step = 0
    sft_ckpt_path = out_dir / "latest.pt"
    if sft_ckpt_path.exists():
        print(f"\n Resuming SFT from {sft_ckpt_path}...")
        sft_ckpt = torch.load(sft_ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(sft_ckpt["model"])
        optimizer.load_state_dict(sft_ckpt["optimizer"])
        start_step = sft_ckpt["step"]
        print(f" Resumed at SFT step {start_step}")
    # ── Eval at start (baseline before any SFT updates) ───────────────────
    if cfg["eval_prompts"] and start_step == 0:
        print("\n Running initial eval (before SFT)...")
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)
    # ── Training loop ─────────────────────────────────────────────────────
    model.train()
    run_t0 = time.perf_counter()
    step = start_step
    best_val_loss = float("inf")
    print(f"\n Starting SFT training (step {start_step} -> {total_steps})...")
    for epoch in range(cfg["epochs"]):
        data_iter = iter(train_loader)
        micro_step = 0
        for batch_idx, (input_ids, loss_mask) in enumerate(data_iter):
            # Skip already-done steps on resume
            # NOTE(review): this only short-circuits the LAST micro-batch of
            # each already-done accumulation group; the earlier micro-batches
            # still run forward/backward below and their grads are never
            # cleared before the first real optimizer step — confirm whether
            # exact resume behavior is intended here.
            current_global_step = epoch * steps_per_epoch + (micro_step // grad_accum)
            if current_global_step < start_step and (micro_step % grad_accum == grad_accum - 1):
                micro_step += 1
                continue
            if current_global_step >= total_steps:
                break
            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)
            t0 = time.perf_counter()
            # Accumulation step: loss is pre-divided so grads sum correctly.
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum
            scaler.scale(loss).backward()
            micro_step += 1
            # Optimizer step after grad_accum micro-batches
            if micro_step % grad_accum == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"])
                # LR schedule
                lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_now
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                if device.type == "cuda":
                    torch.cuda.synchronize()
                dt = time.perf_counter() - t0
                step += 1
                # ── Log ───────────────────────────────────────────────────
                if step % cfg["log_interval"] == 0 or step <= 3:
                    tokens_step = effective_batch * cfg["seq_len"]
                    tps = tokens_step / dt if dt > 0 else 0
                    vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                    eta_h = (total_steps - step) * dt / 3600
                    print(f" step {step:6d}/{total_steps} | epoch {epoch+1}/{cfg['epochs']} | "
                          f"loss {loss.item()*grad_accum:.4f} | lr {lr_now:.2e} | "
                          f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h")
                # ── Save checkpoint ───────────────────────────────────────
                if step % cfg["save_interval"] == 0 or step == total_steps:
                    # Unwrap torch.compile'd models before serializing.
                    raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
                    step_dir = out_dir / f"step-{step:06d}"
                    step_dir.mkdir(parents=True, exist_ok=True)
                    torch.save(raw_model.state_dict(), step_dir / "model.pth")
                    torch.save({
                        "step": step,
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "sft_loss": loss.item() * grad_accum,
                    }, out_dir / "latest.pt")
                    print(f" Saved -> {step_dir}")
                # ── Eval ──────────────────────────────────────────────────
                if step % cfg["eval_interval"] == 0 or step == total_steps:
                    # Validation loss
                    if val_loader:
                        model.eval()
                        val_loss_sum = 0.0
                        val_count = 0
                        with torch.no_grad():
                            for val_ids, val_mask in val_loader:
                                val_ids = val_ids.to(device, non_blocking=True)
                                val_mask = val_mask.to(device, non_blocking=True)
                                with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                                    _, vl = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
                                val_loss_sum += vl.item()
                                val_count += 1
                                if val_count >= 50:  # cap eval to 50 batches
                                    break
                        avg_val = val_loss_sum / max(val_count, 1)
                        print(f" Val loss: {avg_val:.4f}")
                        if avg_val < best_val_loss:
                            best_val_loss = avg_val
                            raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
                            torch.save(raw_model.state_dict(), out_dir / "best_model.pth")
                            print(f" New best! Saved best_model.pth")
                        model.train()
                    # Run eval prompts
                    if cfg["eval_prompts"]:
                        run_eval_prompts(model, tokenizer, cfg["eval_prompts"],
                                         device, step, out_dir)
    # ── Final export: weights + tokenizer copied next to each other ───────
    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save(raw_model.state_dict(), final_dir / "model.pth")
    torch.save({
        "step": step,
        "model": raw_model.state_dict(),
        "sft_complete": True,
    }, out_dir / "latest.pt")
    # Copy tokenizer
    import shutil
    tok_src = Path(cfg["tokenizer_dir"])
    if tok_src.exists():
        shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True)
    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" SFT Complete! {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)
# ─── Entry ────────────────────────────────────────────────────────────────────
def parse_args():
    """Parse CLI flags; any provided flag overrides the YAML config value."""
    # FIX: the description string contained mojibake ("β€”", a UTF-8 em dash
    # decoded as Latin-1); restored the intended "—".
    p = argparse.ArgumentParser(description="LUNA 100M — SFT Fine-Tuning")
    p.add_argument("--config", default="sft_config.yaml")
    p.add_argument("--pretrained_ckpt", default=None)
    p.add_argument("--train_json", default=None)
    p.add_argument("--val_json", default=None)
    p.add_argument("--out_dir", default=None)
    p.add_argument("--epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--micro_batch", type=int, default=None)
    p.add_argument("--global_batch", type=int, default=None)
    p.add_argument("--save_interval", type=int, default=None)
    p.add_argument("--eval_interval", type=int, default=None)
    # Accepts "1"/"true"/"yes" (case-insensitive) as True, anything else False.
    p.add_argument("--auto_config", type=lambda x: x.lower() in ("1", "true", "yes"),
                   default=None)
    return p.parse_args()
if __name__ == "__main__":
    args = parse_args()
    cfg = load_sft_config(args.config)
    # CLI flags (when supplied) take precedence over YAML values.
    overrides = {key: val for key, val in vars(args).items()
                 if key != "config" and val is not None}
    cfg.update(overrides)
    sft_train(cfg)