| """ | |
| train_gpt_simple.py | |
| This file descends from the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt). | |
| It was prepared as a simplified version of the speedrun for use in neural net optimization research. | |
| """ | |
| import os | |
| import sys | |
| with open(sys.argv[0]) as f: | |
| code = f.read() # read the code of this file ASAP, for logging | |
| import uuid | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from torch import Tensor, nn | |
| from torch.optim import AdamW | |
| import torch.nn.functional as F | |
| import torch.distributed as dist | |
| ######################################## | |
| # Dataloader # | |
| ######################################## | |
| def _load_data_shard(file: Path): | |
| header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 | |
| assert header[0] == 20240520, "magic number mismatch in the data .bin file" | |
| assert header[1] == 1, "unsupported version" | |
| num_tokens = int(header[2]) # number of tokens (claimed) | |
| with file.open("rb", buffering=0) as f: | |
| tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) | |
| f.seek(256 * 4) | |
| nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy | |
| assert nbytes == 2 * num_tokens, "number of tokens read does not match header" | |
| return tokens | |
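
# Each .bin shard is a 256-int32 header followed by uint16 token ids:
#   header[0] = 20240520 (magic), header[1] = 1 (version), header[2] = token count.
# For reference, a compatible shard could be produced along these lines (an
# untested sketch, assuming `tokens` is a 1-D tensor of token ids):
#   header = torch.zeros(256, dtype=torch.int32)
#   header[0], header[1], header[2] = 20240520, 1, len(tokens)
#   with open("shard.bin", "wb") as f:
#       f.write(header.numpy().tobytes())
#       f.write(tokens.to(torch.uint16).numpy().tobytes())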

def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len=1024):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    file_iter = iter(files)
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(file_iter)), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
        pos += batch_size
        yield inputs.view(-1, seq_len), targets.view(-1, seq_len)
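
# Each rank reads a disjoint local_batch_size slice of the global batch, plus one
# extra token so targets can be the inputs shifted by one. For example, with
# batch_size=8 and world_size=2 (local_batch_size=4), rank 0 reads tokens[pos:pos+5]
# and rank 1 reads tokens[pos+4:pos+9] before pos advances by 8.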

########################################
#             Architecture             #
########################################

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gains = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return F.rms_norm(x, (x.size(-1),), weight=self.gains.type_as(x))

class Linear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=True)
    def forward(self, x):
        return F.linear(x, self.weight.type_as(x), self.bias.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # half-truncate RoPE (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        self.register_buffer("angular_freq", torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]))
    def forward(self, x_BTHD: Tensor):
        pos = torch.arange(x_BTHD.size(1), dtype=torch.float32, device=x_BTHD.device)
        theta = torch.outer(pos, self.angular_freq)[None, :, None, :]
        cos, sin = theta.cos(), theta.sin()
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)
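
# RoPE rotates each feature pair (x1, x2) by a position-dependent angle
# theta = pos * freq. Here "half-truncate" means half of the pairs get zero
# frequency (cos=1, sin=0) and pass through unrotated, so only half of each
# head's dimensions carry positional information.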

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, head_dim=128):
        super().__init__()
        self.num_heads = dim // head_dim
        self.head_dim = head_dim
        hdim = self.num_heads * self.head_dim
        self.q = Linear(dim, hdim)
        self.k = Linear(dim, hdim)
        self.v = Linear(dim, hdim)
        self.proj = Linear(hdim, dim)
        self.rotary = Rotary(head_dim)
    def forward(self, x: Tensor):
        B, T = x.size(0), x.size(1)
        q = self.q(x).view(B, T, self.num_heads, self.head_dim)
        k = self.k(x).view(B, T, self.num_heads, self.head_dim)
        v = self.v(x).view(B, T, self.num_heads, self.head_dim)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
        q, k = self.rotary(q), self.rotary(k)
        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2),
                                           v.transpose(1, 2), scale=0.12, is_causal=True).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim)
        y = self.proj(y)
        return y
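
# QK-norm: queries and keys are RMS-normalized per head before RoPE, which bounds
# the q.k dot products. The fixed attention scale 0.12 (rather than the standard
# 1/sqrt(head_dim) ~= 0.088 for head_dim=128) is a tuned constant carried over
# from the speedrun this file descends from.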

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        self.fc = Linear(dim, hdim)
        self.proj = Linear(hdim, dim)
    def forward(self, x: Tensor):
        x = self.fc(x)
        x = x.relu().square()
        x = self.proj(x)
        return x
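
# The activation is squared ReLU, max(x, 0)^2, used in place of GELU; it was
# reported to match or beat GELU for transformer LMs in Primer (So et al., 2021).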

class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim)
        self.mlp = MLP(dim)
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
    def forward(self, x: Tensor):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim).bfloat16()
        self.blocks = nn.ModuleList([Block(model_dim) for _ in range(num_layers)])
        self.proj = Linear(model_dim, vocab_size)
        self.norm1 = RMSNorm(model_dim)
        self.norm2 = RMSNorm(model_dim)
    def forward(self, inputs: Tensor, targets: Tensor):
        x = self.norm1(self.embed(inputs))
        for block in self.blocks:
            x = block(x)
        logits = self.proj(self.norm2(x)).float()
        logits = 15 * logits * (logits.square() + 15**2).rsqrt()
        return F.cross_entropy(logits.view(targets.numel(), -1), targets.view(-1), reduction="sum")
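
# The second-to-last line applies a smooth logit cap, z -> 15 * z / sqrt(z^2 + 15^2):
# approximately the identity for |z| << 15 and saturating at +/-15, which keeps the
# logits bounded. Note the returned loss is a sum over tokens, not a mean.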

########################################
#               Optimizer              #
########################################

def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    assert G.ndim >= 2
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations, not optimizing for wallclock speed
    a, b, c = 2, -1.5, 0.5
    for _ in range(12):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
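
# Each iteration computes X <- aX + (bA + cA^2)X with A = XX^T, which applies the
# odd quintic p(x) = 2x - 1.5x^3 + 0.5x^5 to the singular values of X; p(1) = 1,
# so repeated application drives the singular values toward 1 and X toward the
# orthogonal factor UV^T of G's SVD. These mild coefficients need more iterations
# (12 here) than the speedrun's aggressively tuned ones, as the comment above notes.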

@torch.compile
def muon_update(grad, momentum, mu=0.95, nesterov=True):
    momentum.lerp_(grad, 1 - mu)
    update = grad.lerp_(momentum, mu) if nesterov else momentum
    update = zeropower_via_newtonschulz5(update)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update
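
# Muon in one line: orthogonalize the momentum-averaged gradient. The final
# sqrt(max(1, rows/cols)) factor rescales the orthogonalized update so its RMS
# is consistent across matrices of different aspect ratios.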

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, weight_decay=0, mu=0.95):
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        defaults = dict(lr=lr, weight_decay=weight_decay, mu=mu)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
            for base_i in range(0, len(params), world_size):
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum"], mu=group["mu"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
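
# Muon's work is sharded round-robin: rank r updates parameters r, r+world_size, ...,
# and each all_gather copies the freshly updated parameters back to every rank.
# params_pad exists only so the gathered slice always has world_size entries. This
# assumes gradients are already identical on all ranks (they are all-reduced in the
# training loop before step() is called).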

class Lion(torch.optim.Optimizer):
    def __init__(self, params, lr=2e-4, betas=(0.9, 0.99), weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                if group["weight_decay"]:
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]
                update = exp_avg.lerp(grad, 1 - beta1)
                p.add_(update.sign(), alpha=-group["lr"])
                exp_avg.lerp_(grad, 1 - beta2)
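
# Lion (Chen et al., 2023, "Symbolic Discovery of Optimization Algorithms"):
#   update = sign(beta1 * m + (1 - beta1) * g), then m <- beta2 * m + (1 - beta2) * g.
# Every coordinate moves by exactly +/-lr, and weight decay is applied in the
# decoupled, AdamW-style fashion before the update.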

########################################
#                 Setup                #
########################################

# torchrun sets these env variables
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
# this code can be run equivalently with 1, 2, 4, or 8 gpus.
assert 8 % dist.get_world_size() == 0

# logging setup
if dist.get_rank() == 0:
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{uuid.uuid4()}.txt"
    print(logfile)
def print0(s, console=False, log=True):
    if dist.get_rank() == 0:
        if console:
            print(s)
        if log:
            with open(logfile, "a") as f:
                print(s, file=f)
# we begin by logging this file itself
print0(code)
print0("="*100)
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
       + f" on {torch.cuda.get_device_name(device)} with world_size {dist.get_world_size()}")
print0("="*100)

val_tokens = 20 * 524288
batch_size = 8 * 64 * 1024
mbs = 64
val_inputs, val_targets = next(distributed_data_generator("data/fineweb10B/fineweb_val_*.bin", val_tokens))
model = GPT(vocab_size=50304, num_layers=12, model_dim=768).cuda()
model.compile(dynamic=False)
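
# GPT-2-small-scale config: 12 layers, model_dim 768 (6 heads of head_dim 128);
# vocab_size 50304 is 50257 (the GPT-2 tokenizer) rounded up to a multiple of 128
# for more efficient matmuls. batch_size is 524288 tokens per step; mbs is the
# microbatch size in sequences, so one microbatch is 64 * 1024 = 65536 tokens.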

num_trials = int(sys.argv[-1]) if len(sys.argv) > 1 else 1
for _ in range(num_trials):

    ########################################
    #       Init & Optim Hyperparams       #
    ########################################

    # we want to minimize this while still reaching 3.28 val loss
    train_steps = 5750

    # initialize model parameters
    for name, p in model.named_parameters():
        if name.endswith("weight"):
            if "proj" in name:
                p.data.zero_()
            elif "embed" in name:
                p.data.normal_() # default torch init for nn.Embedding
            else:
                p.data.normal_(std=0.33**0.5 / p.size(-1)**0.5) # matches the variance of torch's default nn.Linear init
        elif name.endswith("bias"):
            p.data.zero_()
        elif name.endswith("gains"):
            p.data.normal_(mean=1, std=0) # i.e. fill with 1
        else:
            raise Exception(f"Uninitialized parameter: {name}")

    # create the optimizer(s)
    optimizer1 = AdamW([dict(params=[model.embed.weight], lr=0.3),
                        dict(params=[model.proj.weight], lr=1/320),
                        dict(params=[p for p in model.parameters() if p.ndim < 2], lr=0.01)],
                       betas=(0.8, 0.95), eps=1e-10, weight_decay=0, fused=True)
    optimizer2 = Lion([p for p in model.blocks.parameters() if p.ndim >= 2],
                      lr=0.0002, betas=(0.9, 0.99), weight_decay=0.1)
    optimizers = [optimizer1, optimizer2]
    assert set(p for opt in optimizers for group in opt.param_groups
               for p in group["params"]) == set(model.parameters())
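
    # AdamW covers the embedding, the output projection, and every vector parameter
    # (biases and RMSNorm gains); Lion covers the 2-D matrices inside the blocks.
    # The assert checks that the optimizers together cover every parameter. Note the
    # Muon optimizer defined above is unused in this particular configuration.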

    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
            group["warmup_steps"] = 250 if opt is optimizer2 else 0

    # learning rate schedule: stable then decay
    def set_hparams(step, cooldown_frac=0.7):
        progress = step / train_steps
        assert 0 <= progress < 1
        if progress < 1 - cooldown_frac:
            eta = 1.0
        else:
            eta = (1 - progress) / cooldown_frac
        for opt in optimizers:
            for group in opt.param_groups:
                warmup_steps = group.get("warmup_steps", 0)
                warmup = min(1.0, (step + 1) / warmup_steps) if warmup_steps > 0 else 1.0
                group["lr"] = group["initial_lr"] * warmup * eta

    ########################################
    #        Training and Validation       #
    ########################################

    train_loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size)
    for p in model.parameters():
        dist.broadcast(p.detach(), 0)

    # start the clock
    training_time = 0
    last_val_step = 0
    dist.barrier()
    t0 = time.perf_counter()

    for step in range(train_steps + 1):

        # --------------- VALIDATION SECTION -----------------
        if step == train_steps or step % 125 == 0:
            # stop the clock
            dist.barrier()
            time_since_last_val = time.perf_counter() - t0
            step_avg = time_since_last_val / (step - last_val_step) if step > 0 else float("nan")
            last_val_step = step
            training_time += time_since_last_val
            model.eval()
            val_loss = 0
            with torch.no_grad():
                assert len(val_inputs) % mbs == 0
                for i in range(len(val_inputs) // mbs):
                    val_loss += model(val_inputs[i*mbs:(i+1)*mbs], val_targets[i*mbs:(i+1)*mbs])
            dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
            val_loss /= val_tokens
            print0(f"step:{step}/{train_steps} val_loss:{val_loss:.5f} train_time:{training_time:.3f}s"
                   + f" step_avg:{1000*step_avg:.2f}ms", console=True)
            model.train()
            # start the clock again
            dist.barrier()
            t0 = time.perf_counter()
        if step == train_steps:
            break

        # --------------- TRAINING SECTION -----------------
        inputs, targets = next(train_loader)
        # accumulate across microbatches in case we are running with fewer than 8 gpus
        assert len(inputs) % mbs == 0
        for i in range(len(inputs) // mbs):
            model(inputs[i*mbs:(i+1)*mbs], targets[i*mbs:(i+1)*mbs]).backward()
        for name, p in model.named_parameters():
            assert p.grad is not None, name
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
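        # gradients are summed (not averaged) across microbatches and ranks, and the
        # loss itself is a per-token sum, so the gradient scale grows with the global
        # batch; the learning rates above are presumably tuned for this convention.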
        # set optimization hyperparameters and take a step
        set_hparams(step)
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        approx_training_time = training_time + (time.perf_counter() - t0)
        print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time:.3f}s"
               + f" step_avg:{1000*approx_training_time/(step + 1):.2f}ms", console=True, log=False)

dist.destroy_process_group()
====================================================================================================
Running PyTorch 2.11.0+cu128 compiled for CUDA 12.8 on NVIDIA H100 80GB HBM3 with world_size 2
====================================================================================================
step:0/5750 val_loss:10.82585 train_time:0.000s step_avg:nanms
step:125/5750 val_loss:5.36578 train_time:71.209s step_avg:569.67ms
step:250/5750 val_loss:4.82762 train_time:140.066s step_avg:550.86ms
step:375/5750 val_loss:4.46776 train_time:208.647s step_avg:548.64ms
step:500/5750 val_loss:4.20396 train_time:277.237s step_avg:548.72ms
step:625/5750 val_loss:4.04735 train_time:345.949s step_avg:549.70ms
step:750/5750 val_loss:3.94606 train_time:414.688s step_avg:549.91ms
step:875/5750 val_loss:3.87099 train_time:483.363s step_avg:549.40ms
step:1000/5750 val_loss:3.80722 train_time:552.032s step_avg:549.35ms