# LUNA-Training / train.py
# Uploaded by ASTERIZER with huggingface_hub (revision 2be87ed, verified)
"""
LUNA 100M β€” Config-Driven Dynamic Training Script
==================================================
Reads train_config.yaml for all hyperparameters.
auto_config: true -> hardware probed; batch/lr/workers set automatically
auto_config: false -> every value in config used exactly as-is
Usage:
python train.py # uses train_config.yaml defaults
python train.py --config train_config.yaml # explicit config path
python train.py --data_path /mnt/data/litdata_final # override data path only
python train.py --max_tokens 10000000 # short smoke-test run
"""
import argparse
import bisect
import gc
import json
import math
import os
import sys
import time
from pathlib import Path

import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.amp import autocast, GradScaler
# Reduce CUDA memory fragmentation
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# ─── Model ────────────────────────────────────────────────────────────────────
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embeddings (RoPE).

    Cos/sin tables for every position up to ``max_seq_len`` are built once
    at construction, so the forward pass is just a slice.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE inverse frequencies: 10000^(-2i/dim)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables truncated to ``seq_len`` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x):
    """Swap the two halves of the last dim, negating the second half."""
    half = x.size(-1) // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)
def apply_rotary(x, cos, sin):
    """Rotate ``x`` (batch, heads, seq, dim) by the (seq, dim) cos/sin tables."""
    c = cos[None, None]
    s = sin[None, None]
    # Inlined rotate-half: swap halves of the last dim, negate the second
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * c + rotated * s
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first ``rotary_pct`` fraction of each head dimension receives
    RoPE; the remainder passes through unrotated (GPT-NeoX style).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        batch, seq, width = x.size()
        # Single matmul projects to q/k/v, then split into heads
        proj = self.c_attn(x).reshape(batch, seq, 3, self.n_head, self.head_dim)
        q, k, v = proj.permute(2, 0, 3, 1, 4).unbind(0)
        cos, sin = self.rotary(seq)
        rd = self.rot_dim
        # Rotate only the leading rot_dim channels of each head
        q = torch.cat([apply_rotary(q[..., :rd], cos, sin), q[..., rd:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :rd], cos, sin), k[..., rd:]], dim=-1)
        # Fused SDPA kernel applies the causal mask internally
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(merged)
class MLP(nn.Module):
    """GPT-style feed-forward block: expand 4x, GELU, project back down."""

    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        hidden = self.fc(x)
        hidden = self.gelu(hidden)
        return self.proj(hidden)
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied embeddings."""

    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the token embedding matrix
        self.lm_head.weight = self.wte.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # GPT-2 style init: N(0, 0.02) weights, zero biases
        if isinstance(m, (nn.Linear, nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()

    def forward(self, idx, targets=None, return_logits=True):
        """Run the LM. Returns (logits, loss); loss is None without targets,
        logits is None when return_logits=False (saves memory in training)."""
        h = self.wte(idx)
        for blk in self.blocks:
            h = blk(h)
        h = self.ln_f(h)
        logits = self.lm_head(h)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        if not return_logits:
            logits = None
        return logits, loss

    @property
    def num_params(self):
        # Tied embedding/head weight is counted once, then subtracted
        return sum(p.numel() for p in self.parameters()) - self.wte.weight.numel()
# ─── Dataset ──────────────────────────────────────────────────────────────────
class LitDataDataset(torch.utils.data.Dataset):
    """Map-style dataset over LitData-style binary chunk files.

    Each chunk file holds a small header followed by a flat stream of int32
    token ids. Fixed (block_size + 1)-token windows are carved out of the
    stream; ``__getitem__`` returns the (input, target) pair shifted by one
    token.

    Fixes vs. original: chunk lookup in ``__getitem__`` now uses
    ``bisect.bisect_right`` (O(log n)) instead of a linear scan over all
    chunks; an unused ``import struct, numpy`` in ``__init__`` was removed.
    """

    def __init__(self, data_path: str, block_size: int = 1024):
        self.block_size = block_size
        self.data_path = Path(data_path)
        # index.json lists every chunk with its token count ("dim") + filename
        with open(self.data_path / "index.json") as f:
            idx = json.load(f)
        self.chunks_meta = idx["chunks"]
        # Cumulative block counts enable binary search in __getitem__
        self._cum_blocks = []
        total = 0
        for c in self.chunks_meta:
            total += c["dim"] // (block_size + 1)
            self._cum_blocks.append(total)
        self.total_blocks = total
        # Small FIFO cache of decoded chunks (each chunk ~67 MB of tokens)
        self._chunk_cache = {}

    def _load_chunk(self, chunk_idx: int):
        """Read, decode, and cache one chunk file; returns a 1-D int32 tensor."""
        if chunk_idx in self._chunk_cache:
            return self._chunk_cache[chunk_idx]
        import struct
        import numpy as np
        meta = self.chunks_meta[chunk_idx]
        with open(self.data_path / meta["filename"], "rb") as f:
            raw = f.read()
        # Header: leading uint32 item count; total header = (num_items + 2) * 4 bytes
        num_items = struct.unpack_from("<I", raw, 0)[0]
        header_bytes = (num_items + 2) * 4
        tokens = torch.from_numpy(
            np.frombuffer(raw[header_bytes:], dtype=np.int32).copy()
        )
        # FIFO eviction keeps at most 4 chunks resident
        if len(self._chunk_cache) >= 4:
            del self._chunk_cache[next(iter(self._chunk_cache))]
        self._chunk_cache[chunk_idx] = tokens
        return tokens

    def __len__(self):
        return self.total_blocks

    def __getitem__(self, idx):
        # Binary search for the owning chunk (first cum > idx), replacing
        # the original O(num_chunks) linear scan.
        chunk_idx = bisect.bisect_right(self._cum_blocks, idx)
        prev = self._cum_blocks[chunk_idx - 1] if chunk_idx > 0 else 0
        tokens = self._load_chunk(chunk_idx)
        start = (idx - prev) * (self.block_size + 1)
        window = tokens[start:start + self.block_size + 1]
        if len(window) < self.block_size + 1:
            # Final partial window: right-pad with zeros
            padded = torch.zeros(self.block_size + 1, dtype=torch.int32)
            padded[:len(window)] = window
            window = padded
        window = window.long()
        return window[:self.block_size], window[1:self.block_size + 1]
# ─── Hardware Detection ────────────────────────────────────────────────────────
def probe_hardware():
    """Inspect the local machine and return a capability dict.

    Keys: cpu_cores, ram_gb, device, gpu_name, vram_gb, sm_major,
    precision, dtype. On Ampere+ GPUs (SM >= 8) also enables TF32 matmul.
    """
    info = {
        "cpu_cores": os.cpu_count() or 4,
        "ram_gb": psutil.virtual_memory().total / 1024**3,
    }
    if not torch.cuda.is_available():
        info.update(
            device="cpu",
            gpu_name="CPU",
            vram_gb=0,
            sm_major=0,
            precision="fp32",
            dtype=torch.float32,
        )
        return info
    props = torch.cuda.get_device_properties(0)
    info.update(
        device="cuda",
        gpu_name=props.name,
        vram_gb=props.total_memory / 1024**3,
        sm_major=props.major,
    )
    if props.major >= 8:
        # Ampere or newer: TF32 and bf16 are both supported
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        info["precision"], info["dtype"] = "bf16", torch.bfloat16
    else:
        info["precision"], info["dtype"] = "fp16", torch.float16
    return info
def probe_max_batch(model, device, dtype, seq_len, vocab_size, max_search=4096, grad_accum_sim=4):
    """Binary search for the largest micro-batch that trains without OOM.

    Simulates ``grad_accum_sim`` forward+backward passes plus an optimizer
    step per probe so memory pressure matches the real training pattern,
    then applies a 0.70 safety factor to the best batch found.

    Fixes vs. original: the bare ``except:`` clauses (which would swallow
    KeyboardInterrupt/SystemExit) are gone -- tensors are pre-bound to None
    so cleanup never needs a speculative ``del``; the two duplicated OOM
    handlers are merged, since torch.cuda.OutOfMemoryError subclasses
    RuntimeError.

    Returns the safe micro-batch size (always >= 1).
    """
    tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

    def _reset():
        # Drop probe gradients and return freed blocks to the allocator.
        tmp_opt.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()

    lo, hi, best = 1, max_search, 1
    while lo <= hi:
        mid = (lo + hi) // 2
        x = t = loss = None  # bound up-front so the except path is safe
        try:
            torch.cuda.empty_cache()
            gc.collect()
            tmp_opt.zero_grad(set_to_none=True)
            # Simulate grad_accum micro-batches (real training pattern)
            for _ in range(grad_accum_sim):
                x = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                t = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                with autocast(device_type="cuda", dtype=dtype):
                    _, loss = model(x, t, return_logits=False)
                    loss = loss / grad_accum_sim
                loss.backward()
                x = t = loss = None  # release before the next micro-batch
            tmp_opt.step()
            best = mid
            lo = mid + 1
            _reset()
        except RuntimeError as e:
            # OutOfMemoryError is a RuntimeError subclass; some backends
            # raise a plain RuntimeError carrying "out of memory" instead.
            if not isinstance(e, torch.cuda.OutOfMemoryError) and \
                    "out of memory" not in str(e).lower():
                raise
            x = t = loss = None
            _reset()
            hi = mid - 1
    del tmp_opt
    torch.cuda.empty_cache()
    gc.collect()
    safe = max(1, int(best * 0.70))
    print(f" Probe found max_batch={best}, using {safe} (70% safety, tested with {grad_accum_sim} accum steps)")
    return safe
# ─── LR Schedule ──────────────────────────────────────────────────────────────
def cosine_lr(step, warmup, total, lr_max, lr_min):
    """Linear warmup to ``lr_max``, then cosine decay down to ``lr_min``."""
    if step < warmup:
        return lr_max * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + cosine * (lr_max - lr_min)
# ─── Config Loading ───────────────────────────────────────────────────────────
def load_config(config_path: str) -> dict:
    """Load YAML config and return flat namespace dict.

    Top-level keys get defaults; every sectioned key (model/train/optimizer/
    batch/dataloader/hardware) is required and raises KeyError if absent.
    """
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    model_sec = raw["model"]
    train_sec = raw["train"]
    opt_sec = raw["optimizer"]
    batch_sec = raw["batch"]
    dl_sec = raw["dataloader"]
    hw_sec = raw["hardware"]
    cfg = {
        # top-level (defaulted)
        "auto_config": raw.get("auto_config", True),
        "data_path": raw.get("data_path", "Base/data/litdata_pretrain_final"),
        "out_dir": raw.get("out_dir", "out/pretrain/luna-100m"),
        "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
    }
    # model architecture
    for key in ("vocab_size", "seq_len", "n_layer", "n_embd", "n_head"):
        cfg[key] = model_sec[key]
    # training loop
    for key in ("max_tokens", "lr_warmup_steps", "save_interval",
                "log_interval", "max_norm"):
        cfg[key] = train_sec[key]
    # optimizer (betas converted to a tuple for torch.optim)
    for key in ("lr", "min_lr", "weight_decay", "eps"):
        cfg[key] = opt_sec[key]
    cfg["betas"] = tuple(opt_sec["betas"])
    # batch sizing
    for key in ("global_batch", "micro_batch", "grad_accum"):
        cfg[key] = batch_sec[key]
    # dataloader
    cfg["num_workers"] = dl_sec["num_workers"]
    cfg["pin_memory"] = dl_sec["pin_memory"]
    # hardware
    cfg["precision"] = hw_sec["precision"]
    cfg["compile"] = hw_sec["compile"]
    return cfg
def apply_cli_overrides(cfg: dict, cli_args: argparse.Namespace) -> dict:
    """CLI args override config values (only if explicitly provided)."""
    # argparse default=None means "not provided"; --config is the file
    # path itself, never a config value.
    overrides = {
        key: val
        for key, val in vars(cli_args).items()
        if key != "config" and val is not None
    }
    cfg.update(overrides)
    return cfg
def resolve_auto(cfg: dict, hw: dict) -> dict:
    """
    When auto_config=True: override batch, workers, lr-warmup, pin_memory,
    precision from real hardware. Never touches model arch or max_tokens.
    Returns updated cfg plus injected hw info.

    Fix vs. original: removed the dead ``if cfg["num_workers"] == -1``
    branch -- num_workers is the min of non-negative values and can never
    be -1.
    """
    if not cfg["auto_config"]:
        print(" [CONFIG] auto_config=false -- using manual values as-is")
        cfg.update({"_hw": hw})
        return cfg
    print(" [CONFIG] auto_config=true -- tuning settings to this hardware")
    # Precision follows the probed GPU capability (bf16 / fp16 / fp32)
    cfg["precision"] = hw["precision"]
    cfg["_dtype"] = hw["dtype"]
    # Workers: half the cores, capped by RAM -- each worker caches up to
    # 4 chunks x ~67MB (~268MB); allow 25% of RAM for worker caches.
    auto_workers = hw["cpu_cores"] // 2
    max_by_ram = max(0, int(hw["ram_gb"] * 0.25 * 1024 / 268))
    cfg["num_workers"] = min(auto_workers, max_by_ram, hw["cpu_cores"])
    # Pin memory only pays off with a CUDA device and RAM headroom
    cfg["pin_memory"] = hw["ram_gb"] > 16 and hw["device"] == "cuda"
    # LR warmup: 5% of total steps (will be computed again in train())
    cfg["_auto_warmup"] = True  # flag: recompute once total_steps is known
    # LR scaling: sqrt(global_batch / 120) relative to base lr
    base_global = 120
    scale = math.sqrt(cfg["global_batch"] / base_global)
    cfg["lr"] = cfg["lr"] * scale
    cfg["min_lr"] = cfg["min_lr"] * scale
    cfg["_hw"] = hw
    return cfg
# ─── Training ─────────────────────────────────────────────────────────────────
SEP = "=" * 72


def train(cfg: dict):
    """Execute the full pretraining run described by ``cfg``.

    Requires ``cfg["_hw"]`` to have been injected by resolve_auto().
    Handles GPU cleanup, batch-size probing (auto mode), dataloading,
    AMP training with gradient accumulation, cosine LR scheduling,
    periodic checkpointing, and resume from ``out_dir/latest.pt``.

    Fix vs. original: removed two dead locals (``compiled_model`` and
    ``fused_ok``) that were assigned but never read.
    """
    hw = cfg["_hw"]
    device = torch.device(hw["device"])
    # Clean GPU before anything -- kill leftover allocations from prior runs
    if device.type == "cuda":
        torch.cuda.empty_cache()
        gc.collect()
        free_gb = (torch.cuda.get_device_properties(0).total_memory
                   - torch.cuda.memory_allocated()) / 1024**3
        print(f" GPU free before model load: {free_gb:.1f} GB")
    # Pick precision dtype: auto mode trusts the hardware probe; manual
    # mode maps the config string (unknown strings fall back to fp32).
    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16,
                 "fp32": torch.float32}.get(cfg["precision"], torch.float32)
    print(SEP)
    print(" LUNA 100M - Training")
    print(SEP)
    mode = "AUTO" if cfg["auto_config"] else "MANUAL"
    print(f" Config mode : {mode}")
    print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" Workers : {cfg['num_workers']} pin_memory={cfg['pin_memory']}")
    # -- Model ---------------------------------------------------------------
    print(f"\n Building LUNA-100M...")
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    # torch.compile disabled: causes CUDA graph / OOM issues with tied
    # embeddings at this model size. Raw PyTorch + SDPA is already fast.
    print(" torch.compile: disabled (not needed for 100M params)")
    print(f" Parameters: {model.num_params:,} (unique)")
    # -- Batch sizing --------------------------------------------------------
    if cfg["auto_config"] and device.type == "cuda":
        print(f"\n Probing max micro_batch_size (VRAM search)...")
        # Probe using the actual model -- no second copy wasting VRAM
        max_mbs = probe_max_batch(
            model, device, dtype, cfg["seq_len"], cfg["vocab_size"]
        )
        # Re-init model weights after probe (probe dirties optimizer state)
        model.apply(model._init_weights)
        torch.cuda.empty_cache(); gc.collect()
        # grad_accum to hit global_batch
        grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs))
        effective_batch = max_mbs * grad_accum
        print(f" AUTO -> micro_batch={max_mbs}, grad_accum={grad_accum}, "
              f"effective_batch={effective_batch}")
    else:
        max_mbs = cfg["micro_batch"]
        grad_accum = cfg["grad_accum"]
        effective_batch = max_mbs * grad_accum
        print(f"\n MANUAL -> micro_batch={max_mbs}, grad_accum={grad_accum}, "
              f"effective_batch={effective_batch}")
    tokens_per_step = effective_batch * cfg["seq_len"]
    print(f" Tokens/step : {tokens_per_step:,}")
    # -- Dataset -------------------------------------------------------------
    print(f"\n Dataset: {cfg['data_path']}")
    dataset = LitDataDataset(cfg["data_path"], block_size=cfg["seq_len"])
    print(f" Blocks : {len(dataset):,} ({len(dataset) * cfg['seq_len']:,} tokens)")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=max_mbs,
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    # -- Optimiser -----------------------------------------------------------
    # Try the fused CUDA AdamW first; fall back when this torch build's
    # AdamW does not accept the `fused` kwarg.
    try:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg["lr"], weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"],
            fused=True,
        )
    except TypeError:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg["lr"], weight_decay=cfg["weight_decay"],
            betas=cfg["betas"], eps=cfg["eps"],
        )
    # Loss scaling is only needed for fp16; bf16/fp32 run with it disabled
    use_scaler = dtype == torch.float16
    scaler = GradScaler(enabled=use_scaler)
    # -- Schedule ------------------------------------------------------------
    total_steps = max(1, cfg["max_tokens"] // tokens_per_step)
    if cfg["auto_config"] and cfg.get("_auto_warmup"):
        # Auto warmup: ~5% of the run, clamped to [50, 500] steps
        warmup_steps = max(50, min(500, total_steps // 20))
    else:
        warmup_steps = min(cfg["lr_warmup_steps"], total_steps)
    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n max_tokens : {cfg['max_tokens']:,}")
    print(f" total_steps : {total_steps:,}")
    print(f" warmup_steps : {warmup_steps}")
    print(f" lr : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}")
    print(f" save every : {cfg['save_interval']} steps")
    print(f" out_dir : {out_dir}")
    print(SEP)
    # -- Resume --------------------------------------------------------------
    start_step = 0
    ckpt_path = out_dir / "latest.pt"
    if ckpt_path.exists():
        print(f"\n Resuming from {ckpt_path}...")
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"]
        print(f" Resumed at step {start_step}")
    # -- Loop ----------------------------------------------------------------
    model.train()
    data_iter = iter(loader)

    def get_batch():
        # Endless batches: restart the loader when an epoch ends
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    run_t0 = time.perf_counter()
    tokens_seen = start_step * tokens_per_step
    step = start_step
    print(f"\n Starting training (step {start_step} -> {total_steps})...")
    while step < total_steps:
        t0 = time.perf_counter()
        lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
        for pg in optimizer.param_groups:
            pg["lr"] = lr_now
        optimizer.zero_grad(set_to_none=True)
        total_loss = 0.0
        # Accumulate gradients over grad_accum micro-batches
        for _ in range(grad_accum):
            x, t = get_batch()
            x = x.to(device, non_blocking=True)
            t = t.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(x, t, return_logits=False)
                loss = loss / grad_accum
            scaler.scale(loss).backward()
            total_loss += loss.item()
        # Unscale before clipping so max_norm applies to true gradient values
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"])
        scaler.step(optimizer)
        scaler.update()
        if device.type == "cuda":
            torch.cuda.synchronize()  # accurate per-step timing
        dt = time.perf_counter() - t0
        step += 1
        tokens_seen += tokens_per_step
        if step % cfg["log_interval"] == 0 or step <= 2:
            tps = tokens_per_step / dt
            steps_left = total_steps - step
            eta_h = steps_left * dt / 3600
            vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
            print(f" step {step:6d}/{total_steps} | loss {total_loss:.4f} | "
                  f"lr {lr_now:.2e} | {tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h")
        if step % cfg["save_interval"] == 0 or step == total_steps:
            # _orig_mod unwraps a torch.compile wrapper if one is ever present
            raw = model._orig_mod if hasattr(model, "_orig_mod") else model
            step_dir = out_dir / f"step-{step:08d}"
            step_dir.mkdir(parents=True, exist_ok=True)
            torch.save(raw.state_dict(), step_dir / "lit_model.pth")
            torch.save({"step": step, "model": raw.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "tokens_seen": tokens_seen},
                       out_dir / "latest.pt")
            print(f" Saved -> {step_dir}")
    # -- Final ---------------------------------------------------------------
    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    raw = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save(raw.state_dict(), final_dir / "lit_model.pth")
    import shutil
    tok_src = Path(cfg["tokenizer_dir"])
    if tok_src.exists():
        # Ship the tokenizer files alongside the final weights
        shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True)
    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" Done! {total_h:.2f} h -> {final_dir}")
    print(SEP)
# ─── Entry point ──────────────────────────────────────────────────────────────
def parse_args():
    """Build and parse the CLI; override flags default to None = use config."""
    parser = argparse.ArgumentParser(description="LUNA 100M Trainer")
    parser.add_argument("--config", type=str, default="train_config.yaml",
                        help="Path to train_config.yaml")
    # CLI overrides (all optional - omit to use config value)
    override_flags = (
        ("--data_path", str),
        ("--out_dir", str),
        ("--max_tokens", int),
        ("--micro_batch", int),
        ("--global_batch", int),
        ("--lr", float),
        ("--num_workers", int),
        ("--save_interval", int),
        ("--log_interval", int),
    )
    for flag, flag_type in override_flags:
        parser.add_argument(flag, type=flag_type, default=None)
    parser.add_argument("--auto_config",
                        type=lambda x: x.lower() in ("1", "true", "yes"),
                        default=None, help="Override auto_config (true/false)")
    return parser.parse_args()
if __name__ == "__main__":
    # End-to-end pipeline: CLI -> YAML config -> CLI overrides ->
    # hardware probe -> auto-config resolution -> training run.
    args = parse_args()
    cfg = load_config(args.config)
    # CLI flags win over YAML values (only when explicitly provided)
    cfg = apply_cli_overrides(cfg, args)
    hw = probe_hardware()
    # Injects cfg["_hw"]; tunes batch/lr/workers when auto_config=true
    cfg = resolve_auto(cfg, hw)
    train(cfg)