Buckets:

cmpatino's picture
download
raw
16.3 kB
"""
train_gpt_simple.py
This file descends from the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt).
It was prepared as a simplified version of the speedrun for use in neural net optimization research.
"""
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
from pathlib import Path
import torch
from torch import Tensor, nn
from torch.optim import AdamW
import torch.nn.functional as F
import torch.distributed as dist
########################################
# Dataloader #
########################################
def _load_data_shard(file: Path):
    """Load one .bin token shard into a pinned-memory uint16 tensor.

    Shard format: a 256-int32 header (magic, version, token count) followed
    by the raw uint16 token stream.
    """
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    # Pinned host memory enables fast async host->GPU copies downstream.
    tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
    with file.open("rb", buffering=0) as f:
        f.seek(256 * 4)  # skip past the 256-int32 header
        nbytes = f.readinto(tokens.numpy()) # fill the buffer in place, avoiding a bytes->array copy
    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens
def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len=1024):
    """Yield (inputs, targets) CUDA batches, sharded across distributed ranks.

    Every rank walks the same shard files; each yields its own contiguous
    local window of `batch_size // world_size` tokens per step. Targets are
    the inputs shifted left by one token.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    file_iter = iter(files)
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        # Move on to the next shard once the current one cannot supply a full
        # global batch (+1 token of lookahead for the shifted targets).
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(file_iter)), 0
        start = pos + rank * local_batch_size
        buf = tokens[start:start + local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
        pos += batch_size
        yield inputs.view(-1, seq_len), targets.view(-1, seq_len)
########################################
# Architecture #
########################################
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gains = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.rms_norm(x, (x.size(-1),), weight=self.gains.type_as(x))
class Linear(nn.Linear):
    """nn.Linear variant that casts weight and bias to the input's dtype."""
    def __init__(self, in_features, out_features):
        # Bias is always enabled for the linear layers in this model.
        super().__init__(in_features, out_features, bias=True)
    def forward(self, x):
        weight = self.weight.type_as(x)
        bias = self.bias.type_as(x)
        return F.linear(x, weight, bias)
class Rotary(nn.Module):
    """Rotary positional embedding with half-truncated frequency spectrum.

    Only the first dim//4 channel pairs carry nonzero frequency; the other
    dim//4 pairs get zero frequency and pass through unrotated (the
    "half-truncate RoPE" trick, with base frequency 1/1024).
    """
    def __init__(self, dim: int):
        super().__init__()
        # Geometric spectrum from 1 down to 1/1024, zero-padded to dim//2.
        freqs = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        self.register_buffer("angular_freq", torch.cat([freqs, freqs.new_zeros(dim//4)]))
    def forward(self, x_BTHD: Tensor):
        # x_BTHD is (batch, time, heads, head_dim); pairs (x1[i], x2[i]) are
        # rotated by position-dependent angles, computed in float32.
        positions = torch.arange(x_BTHD.size(1), dtype=torch.float32, device=x_BTHD.device)
        angles = torch.outer(positions, self.angular_freq)[None, :, None, :]
        cos, sin = angles.cos(), angles.sin()
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        rotated = torch.cat((x1 * cos + x2 * sin, x2 * cos - x1 * sin), dim=-1)
        return rotated.type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with QK RMS-norm and rotary embeddings."""
    def __init__(self, dim: int, head_dim=128):
        super().__init__()
        self.num_heads = dim // head_dim
        self.head_dim = head_dim
        hdim = self.num_heads * self.head_dim
        # Separate projections for queries, keys and values.
        self.q = Linear(dim, hdim)
        self.k = Linear(dim, hdim)
        self.v = Linear(dim, hdim)
        self.proj = Linear(hdim, dim)
        self.rotary = Rotary(head_dim)
    def forward(self, x: Tensor):
        batch, seq_len = x.size(0), x.size(1)
        heads_shape = (batch, seq_len, self.num_heads, self.head_dim)
        q = self.q(x).view(heads_shape)
        k = self.k(x).view(heads_shape)
        v = self.v(x).view(heads_shape)
        # QK-norm keeps attention logits well-scaled; rotary encodes position.
        q = self.rotary(F.rms_norm(q, (q.size(-1),)))
        k = self.rotary(F.rms_norm(k, (k.size(-1),)))
        # SDPA wants (batch, heads, time, head_dim); scale=0.12 is a tuned
        # constant replacing the default 1/sqrt(head_dim).
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            scale=0.12, is_causal=True,
        ).transpose(1, 2)
        y = y.contiguous().view(batch, seq_len, self.num_heads * self.head_dim)
        return self.proj(y)
class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, ReLU-squared, project back."""
    def __init__(self, dim: int):
        super().__init__()
        hidden = 4 * dim
        self.fc = Linear(dim, hidden)
        self.proj = Linear(hidden, dim)
    def forward(self, x: Tensor):
        # ReLU-squared activation: relu(x)**2.
        return self.proj(self.fc(x).relu().square())
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""
    def __init__(self, dim: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim)
        self.mlp = MLP(dim)
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
    def forward(self, x: Tensor):
        # Pre-norm residual wiring: normalize the input to each sublayer,
        # then add the sublayer output back onto the stream.
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class GPT(nn.Module):
    """Decoder-only transformer; forward returns the summed cross-entropy loss."""
    def __init__(self, vocab_size: int, num_layers: int, model_dim: int):
        super().__init__()
        # The (large) embedding table is kept in bfloat16 to save memory.
        self.embed = nn.Embedding(vocab_size, model_dim).bfloat16()
        self.blocks = nn.ModuleList(Block(model_dim) for _ in range(num_layers))
        self.proj = Linear(model_dim, vocab_size)
        self.norm1 = RMSNorm(model_dim)
        self.norm2 = RMSNorm(model_dim)
    def forward(self, inputs: Tensor, targets: Tensor):
        x = self.norm1(self.embed(inputs))
        for block in self.blocks:
            x = block(x)
        logits = self.proj(self.norm2(x)).float()
        # Soft-cap logits to (-15, 15) via x -> 15*x/sqrt(x^2 + 15^2) for stability.
        logits = 15 * logits * (logits.square() + 15**2).rsqrt()
        flat_logits = logits.view(targets.numel(), -1)
        return F.cross_entropy(flat_logits, targets.view(-1), reduction="sum")
########################################
# Optimizer #
########################################
def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    """Approximately orthogonalize G via Newton-Schulz iteration in bfloat16.

    Pushes the singular values of G toward 1, approximating the U V^T factor
    of its SVD. Works on batched matrices (any ndim >= 2).
    """
    assert G.ndim >= 2
    X = G.bfloat16()
    # Work with the wide orientation so X @ X.mT is the smaller Gram matrix.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    # Normalize so the spectral norm is at most 1 (Frobenius norm bounds it).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Quintic polynomial iteration s -> a*s + b*s^3 + c*s^5 on the singular
    # values; not tuned for wallclock speed.
    a, b, c = 2, -1.5, 0.5
    for _ in range(12):
        gram = X @ X.mT
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X
    return X.mT if transposed else X
@torch.compile
def muon_update(grad, momentum, mu=0.95, nesterov=True):
    # Muon update: momentum-smoothed gradient, orthogonalized via Newton-Schulz.
    # NOTE: both `momentum` and (when nesterov) `grad` are mutated IN PLACE via
    # lerp_ — the caller's momentum buffer advances and the gradient is clobbered.
    momentum.lerp_(grad, 1 - mu)
    # Nesterov-style lookahead: blend the raw gradient toward the new momentum.
    update = grad.lerp_(momentum, mu) if nesterov else momentum
    update = zeropower_via_newtonschulz5(update)
    # Scale tall matrices up so the update RMS is roughly shape-independent.
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update
class Muon(torch.optim.Optimizer):
    """Distributed Muon optimizer: momentum + Newton-Schulz orthogonalization.

    Work is sharded round-robin: rank r updates parameter base_i + r of each
    group of world_size parameters, then the updated parameters are
    re-synchronized across all ranks with all_gather.
    NOTE(review): defined here but not instantiated by the training loop below.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, mu=0.95):
        # Expects a flat list of Parameters (not dict-style param groups).
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        # Sort by size so parameters processed together have similar cost.
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        defaults = dict(lr=lr, weight_decay=weight_decay, mu=mu)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for group in self.param_groups:
            params = group["params"]
            # Pad the list to a multiple of world_size; the dummy tensors only
            # ever participate as all_gather sources for out-of-range ranks.
            params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
            for base_i in range(0, len(params), world_size):
                # Each rank updates its own parameter of this group of world_size...
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum"], mu=group["mu"])
                    # Decoupled weight decay, then the orthogonalized step.
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])
                # ...then every rank broadcasts its updated parameter to all.
                dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
class Lion(torch.optim.Optimizer):
    """Lion optimizer: takes the sign of an EMA-interpolated update.

    Per parameter, each step does:
        p <- p * (1 - lr * weight_decay)                     (decoupled decay)
        p <- p - lr * sign(lerp(exp_avg, grad, 1 - beta1))   (sign step)
        exp_avg <- lerp(exp_avg, grad, 1 - beta2)            (EMA update)
    """
    def __init__(self, params, lr=2e-4, betas=(0.9, 0.99), weight_decay=0.0):
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            decay = group["weight_decay"]
            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if decay:
                    # Decoupled weight decay, applied directly to the weights.
                    p.mul_(1 - lr * decay)
                state = self.state[p]
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]
                # Sign-of-interpolation step (uses out-of-place lerp so the
                # momentum buffer is untouched here)...
                p.add_(exp_avg.lerp(grad, 1 - beta1).sign(), alpha=-lr)
                # ...then the slower EMA update of the momentum buffer.
                exp_avg.lerp_(grad, 1 - beta2)
########################################
# Setup #
########################################
# torchrun sets these env variables
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
# this code can be run equivalently with 1, 2, 4, or 8 gpus.
assert 8 % dist.get_world_size() == 0
# logging setup: only rank 0 creates and writes the logfile
if dist.get_rank() == 0:
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{uuid.uuid4()}.txt"
    print(logfile)
def print0(s, console=False, log=True):
    # Rank-0-only print: optionally to stdout, optionally appended to the logfile.
    if dist.get_rank() == 0:
        if console:
            print(s)
        if log:
            with open(logfile, "a") as f:
                print(s, file=f)
# we begin by logging this file itself
print0(code)
print0("="*100)
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
       + f" on {torch.cuda.get_device_name(device)} with world_size {dist.get_world_size()}")
print0("="*100)
val_tokens = 20 * 524288   # tokens consumed by each validation pass
batch_size = 8 * 64 * 1024 # global tokens per training step
mbs = 64                   # microbatch size: sequences per forward/backward
# one fixed validation batch, loaded up front
val_inputs, val_targets = next(distributed_data_generator("data/fineweb10B/fineweb_val_*.bin", val_tokens))
model = GPT(vocab_size=50304, num_layers=12, model_dim=768).cuda()
model.compile(dynamic=False)
# optional CLI arg: number of independent training trials to run
num_trials = int(sys.argv[-1]) if len(sys.argv) > 1 else 1
for _ in range(num_trials):
    ########################################
    #        Init & Optim Hyperparams      #
    ########################################
    # we want to minimize this while still reaching 3.28 val loss
    train_steps = 5750
    # (re)initialize model parameters from scratch for each trial
    for name, p in model.named_parameters():
        if name.endswith("weight"):
            if "proj" in name:
                p.data.zero_()  # zero-init all output projections
            elif "embed" in name:
                p.data.normal_() # default torch init
            else:
                p.data.normal_(std=0.33**0.5 / p.size(-1)**0.5) # scaled init, ~sqrt(0.33/fan_in)
        elif name.endswith("bias"):
            p.data.zero_()
        elif name.endswith("gains"):
            p.data.normal_(mean=1, std=0)  # std=0 -> exactly 1 (identity RMSNorm gains)
        else:
            raise Exception(f"Uninitialized parameter: {name}")
    # create the optimizer(s): AdamW for embed/head/1-D params, Lion for block matrices
    optimizer1 = AdamW([dict(params=[model.embed.weight], lr=0.3),
                        dict(params=[model.proj.weight], lr=1/320),
                        dict(params=[p for p in model.parameters() if p.ndim < 2], lr=0.01)],
                       betas=(0.8, 0.95), eps=1e-10, weight_decay=0, fused=True)
    optimizer2 = Lion([p for p in model.blocks.parameters() if p.ndim >= 2],
                      lr=0.0002, betas=(0.9, 0.99), weight_decay=0.1)
    optimizers = [optimizer1, optimizer2]
    # every model parameter must be covered by exactly this optimizer split
    assert set(p for opt in optimizers for group in opt.param_groups
               for p in group["params"]) == set(model.parameters())
    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
            # only the Lion-managed matrices get an lr warmup
            group["warmup_steps"] = 250 if opt is optimizer2 else 0
    # learning rate schedule: stable then decay
    def set_hparams(step, cooldown_frac=0.7):
        progress = step / train_steps
        assert 0 <= progress < 1
        if progress < 1 - cooldown_frac:
            eta = 1.0  # constant-lr phase
        else:
            eta = (1 - progress) / cooldown_frac  # linear cooldown to 0
        for opt in optimizers:
            for group in opt.param_groups:
                warmup_steps = group.get("warmup_steps", 0)
                warmup = min(1.0, (step + 1) / warmup_steps) if warmup_steps > 0 else 1.0
                group["lr"] = group["initial_lr"] * warmup * eta
    ########################################
    #        Training and Validation       #
    ########################################
    train_loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size)
    # ensure all ranks start this trial from identical weights
    for p in model.parameters():
        dist.broadcast(p.detach(), 0)
    # start the clock
    training_time = 0
    last_val_step = 0
    dist.barrier()
    t0 = time.perf_counter()
    for step in range(train_steps + 1):
        # --------------- VALIDATION SECTION -----------------
        if step == train_steps or step % 125 == 0:
            # stop the clock: validation time is excluded from training_time
            dist.barrier()
            time_since_last_val = time.perf_counter() - t0
            step_avg = time_since_last_val / (step - last_val_step) if step > 0 else float("nan")
            last_val_step = step
            training_time += time_since_last_val
            model.eval()
            val_loss = 0
            with torch.no_grad():
                assert len(val_inputs) % mbs == 0
                for i in range(len(val_inputs) // mbs):
                    val_loss += model(val_inputs[i*mbs:(i+1)*mbs], val_targets[i*mbs:(i+1)*mbs])
            # sum the per-token losses across ranks, then average per token
            dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
            val_loss /= val_tokens
            print0(f"step:{step}/{train_steps} val_loss:{val_loss:.5f} train_time:{training_time:.3f}s"
                   + f" step_avg:{1000*step_avg:.2f}ms", console=True)
            model.train()
            # start the clock again
            dist.barrier()
            t0 = time.perf_counter()
        if step == train_steps:
            break
        # --------------- TRAINING SECTION -----------------
        inputs, targets = next(train_loader)
        # accumulate across microbatches in case we are running with fewer than 8 gpus
        assert len(inputs) % mbs == 0
        for i in range(len(inputs) // mbs):
            model(inputs[i*mbs:(i+1)*mbs], targets[i*mbs:(i+1)*mbs]).backward()
        # sum gradients across ranks (manual data-parallel reduction)
        for name, p in model.named_parameters():
            assert p.grad is not None, name
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        # set optimization hyperparameters and take a step
        set_hparams(step)
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        approx_training_time = training_time + (time.perf_counter() - t0)
        print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time:.3f}s"
               + f" step_avg:{1000*approx_training_time/(step + 1):.2f}ms", console=True, log=False)
dist.destroy_process_group()
====================================================================================================
Running PyTorch 2.11.0+cu128 compiled for CUDA 12.8 on NVIDIA H100 80GB HBM3 with world_size 2
====================================================================================================
step:0/5750 val_loss:10.82585 train_time:0.000s step_avg:nanms
step:125/5750 val_loss:5.36578 train_time:71.209s step_avg:569.67ms
step:250/5750 val_loss:4.82762 train_time:140.066s step_avg:550.86ms
step:375/5750 val_loss:4.46776 train_time:208.647s step_avg:548.64ms
step:500/5750 val_loss:4.20396 train_time:277.237s step_avg:548.72ms
step:625/5750 val_loss:4.04735 train_time:345.949s step_avg:549.70ms
step:750/5750 val_loss:3.94606 train_time:414.688s step_avg:549.91ms
step:875/5750 val_loss:3.87099 train_time:483.363s step_avg:549.40ms
step:1000/5750 val_loss:3.80722 train_time:552.032s step_avg:549.35ms

Xet Storage Details

Size:
16.3 kB
·
Xet hash:
bf5792b9a193b022a57d08c9e26d01b0f5417b6014b7de54cc4b521c2457e698

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.