ml-intern-explorers/efficient-optimizer-collab / artifacts /psgd_kron_baseline_cmpatino-1 /train_gpt_simple.py
"""
train_gpt_simple.py

This file descends from the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt).
It was prepared as a simplified version of the speedrun for use in neural net optimization research.
"""
| import os | |
| import sys | |
| with open(sys.argv[0]) as f: | |
| code = f.read() # read the code of this file ASAP, for logging | |
| import uuid | |
| import time | |
| from pathlib import Path | |
| import string | |
| import random | |
| import numpy as np | |
| import torch | |
| from torch import Tensor, nn | |
| from torch.optim import AdamW | |
| import torch.nn.functional as F | |
| import torch.distributed as dist | |
| ######################################## | |
| # Dataloader # | |
| ######################################## | |
def _load_data_shard(file: Path):
    """Read one token shard (.bin file) into a pinned uint16 CPU tensor."""
    HEADER_INT32S = 256  # fixed-size header of 256 int32 values
    header = torch.from_file(str(file), False, HEADER_INT32S, dtype=torch.int32)
    magic, version, claimed = header[0], header[1], int(header[2])
    assert magic == 20240520, "magic number mismatch in the data .bin file"
    assert version == 1, "unsupported version"
    # Pinned memory enables fast async host->device copies later on.
    out = torch.empty(claimed, dtype=torch.uint16, pin_memory=True)
    with file.open("rb", buffering=0) as f:
        f.seek(HEADER_INT32S * 4)  # skip past the header
        read = f.readinto(out.numpy())  # fill the tensor in place, no bytes->array copy
    assert read == 2 * claimed, "number of tokens read does not match header"
    return out
def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len=1024):
    """Yield (inputs, targets) CUDA batches for this rank, streaming over shard files."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shards = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    shard_iter = iter(shards)
    tokens = _load_data_shard(next(shard_iter))
    pos = 0
    while True:
        # Advance to the next shard once the current one cannot cover a full batch.
        if pos + batch_size + 1 >= len(tokens):
            tokens = _load_data_shard(next(shard_iter))
            pos = 0
        # Each rank reads its own contiguous slice; +1 token for the shifted targets.
        window = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = window[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = window[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
        pos += batch_size
        yield inputs.view(-1, seq_len), targets.view(-1, seq_len)
| ######################################## | |
| # Architecture # | |
| ######################################## | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim): | |
| super().__init__() | |
| self.gains = nn.Parameter(torch.ones(dim)) | |
| def forward(self, x): | |
| return F.rms_norm(x, (x.size(-1),), weight=self.gains.type_as(x)) | |
class Linear(nn.Linear):
    """nn.Linear (with bias) that casts its weight and bias to the input's dtype."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=True)

    def forward(self, x):
        w = self.weight.type_as(x)
        b = self.bias.type_as(x)
        return F.linear(x, w, b)
class Rotary(nn.Module):
    """Half-truncated rotary position embedding (upper half of frequencies zeroed)."""

    def __init__(self, dim: int):
        super().__init__()
        # Frequencies decay geometrically from 1 down to 1/1024 (base freq tuning);
        # the upper half of the rotated dims gets zero frequency ("half-truncate" RoPE).
        base = torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
        freqs = (1 / 1024) ** base
        self.register_buffer("angular_freq", torch.cat([freqs, freqs.new_zeros(dim // 4)]))

    def forward(self, x_BTHD: Tensor):
        seq_len = x_BTHD.size(1)
        pos = torch.arange(seq_len, dtype=torch.float32, device=x_BTHD.device)
        theta = torch.outer(pos, self.angular_freq)[None, :, None, :]
        # Rotation is applied in float32 for accuracy, then cast back.
        first, second = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        rotated_first = first * theta.cos() + second * theta.sin()
        rotated_second = second * theta.cos() - first * theta.sin()
        return torch.cat((rotated_first, rotated_second), 3).type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with QK RMS-norm and rotary embeddings."""

    def __init__(self, dim: int, head_dim=128):
        super().__init__()
        self.num_heads = dim // head_dim
        self.head_dim = head_dim
        hdim = self.num_heads * self.head_dim
        self.q = Linear(dim, hdim)
        self.k = Linear(dim, hdim)
        self.v = Linear(dim, hdim)
        self.proj = Linear(hdim, dim)
        self.rotary = Rotary(head_dim)

    def forward(self, x: Tensor):
        batch, seq = x.size(0), x.size(1)

        def split_heads(t):
            return t.view(batch, seq, self.num_heads, self.head_dim)

        q = split_heads(self.q(x))
        k = split_heads(self.k(x))
        v = split_heads(self.v(x))
        # Normalize queries/keys before rotary (QK-norm stabilizes attention logits).
        q = self.rotary(F.rms_norm(q, (q.size(-1),)))
        k = self.rotary(F.rms_norm(k, (k.size(-1),)))
        # Fixed scale 0.12 rather than the default 1/sqrt(head_dim).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            scale=0.12, is_causal=True,
        ).transpose(1, 2)
        out = out.contiguous().view(batch, seq, self.num_heads * self.head_dim)
        return self.proj(out)
class MLP(nn.Module):
    """Feed-forward block: 4x expansion with a squared-ReLU nonlinearity."""

    def __init__(self, dim: int):
        super().__init__()
        hidden = 4 * dim
        self.fc = Linear(dim, hidden)
        self.proj = Linear(hidden, dim)

    def forward(self, x: Tensor):
        return self.proj(self.fc(x).relu().square())
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim)
        self.mlp = MLP(dim)
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x: Tensor):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class GPT(nn.Module):
    """Decoder-only transformer; forward returns the summed cross-entropy loss."""

    def __init__(self, vocab_size: int, num_layers: int, model_dim: int):
        super().__init__()
        # Embedding table is kept in bfloat16 to save memory/bandwidth.
        self.embed = nn.Embedding(vocab_size, model_dim).bfloat16()
        self.blocks = nn.ModuleList(Block(model_dim) for _ in range(num_layers))
        self.proj = Linear(model_dim, vocab_size)
        self.norm1 = RMSNorm(model_dim)
        self.norm2 = RMSNorm(model_dim)

    def forward(self, inputs: Tensor, targets: Tensor):
        h = self.norm1(self.embed(inputs))
        for block in self.blocks:
            h = block(h)
        logits = self.proj(self.norm2(h)).float()
        # Smoothly soft-cap logits into (-15, 15) to stabilize the loss.
        logits = 15 * logits * (logits.square() + 15**2).rsqrt()
        return F.cross_entropy(logits.view(targets.numel(), -1), targets.view(-1), reduction="sum")
| ######################################## | |
| # Optimizer # | |
| ######################################## | |
def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    """Approximately orthogonalize G via Newton-Schulz iterations in bfloat16.

    Works on the transpose when G is tall so the Gram matrix stays small;
    returns a tensor with the same shape as G whose singular values are ~1.
    """
    assert G.ndim >= 2
    tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if tall:
        X = X.mT
    # Normalize so the spectral norm is at most 1 (required for convergence).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    a, b, c = 2, -1.5, 0.5
    for _ in range(12):  # not optimizing for wallclock speed
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.mT if tall else X
def muon_update(grad, momentum, mu=0.95, nesterov=True):
    """One Muon update: momentum-average the gradient, then orthogonalize it.

    NOTE: mutates both arguments in place -- `momentum` is the persistent
    buffer, and with nesterov=True `grad` itself is overwritten by lerp_.
    Requires grad.ndim >= 2 (asserted by the Newton-Schulz step).
    """
    momentum.lerp_(grad, 1 - mu)
    update = grad.lerp_(momentum, mu) if nesterov else momentum
    update = zeropower_via_newtonschulz5(update)
    # Scale tall matrices up so update RMS is roughly shape-independent.
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update
class Muon(torch.optim.Optimizer):
    """Distributed Muon: momentum + Newton-Schulz orthogonalized updates.

    Parameters are assigned round-robin to ranks: each rank computes updates
    for its share of params, and results are synchronized with all_gather.
    Requires torch.distributed to be initialized; only >=2D params are valid
    (the orthogonalization step asserts ndim >= 2).
    """
    def __init__(self, params, lr=0.02, weight_decay=0, mu=0.95):
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        # Sort largest-first so ranks get similarly-sized work in each round.
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        defaults = dict(lr=lr, weight_decay=weight_decay, mu=mu)
        super().__init__(params, defaults)

    def step(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for group in self.param_groups:
            params = group["params"]
            # Pad the param list so each all_gather round has world_size entries.
            # NOTE(review): padding reuses the smallest param's shape, and
            # all_gather assumes same-shape tensors within a round -- confirm
            # that params in a round always share a shape.
            params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
            for base_i in range(0, len(params), world_size):
                if base_i + rank < len(params):
                    # This rank owns param base_i + rank for this round.
                    p = params[base_i + rank]
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum"], mu=group["mu"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])  # decoupled weight decay
                    p.add_(update, alpha=-group["lr"])
                # Every rank contributes its updated param (or a dummy) so all
                # ranks end the round with identical parameters.
                dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
| ######################################## | |
| # PSGD Kron # | |
| ######################################## | |
class ProbScheduler:
    """Anneals the preconditioner-update probability over training.

    Holds the probability at `max_prob` for `flat_start` steps, then decays
    it exponentially at rate `decay`, clamped to [min_prob, max_prob].
    """
    def __init__(self, max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
        def as_f32(v):
            return torch.tensor(v, dtype=torch.float32)

        self.max_prob = as_f32(max_prob)
        self.min_prob = as_f32(min_prob)
        self.decay = as_f32(decay)
        self.flat_start = as_f32(flat_start)
        # Compile the schedule when possible; silently fall back to eager.
        self._compiled = False
        try:
            self._compiled_schedule = torch.compile(self._schedule_fn)
            self._compiled = True
        except Exception:
            pass

    def _schedule_fn(self, n):
        prob = self.max_prob * torch.exp(-self.decay * (n - self.flat_start))
        prob.clamp_(min=self.min_prob, max=self.max_prob)
        return prob

    def __call__(self, n):
        return self._compiled_schedule(n) if self._compiled else self._schedule_fn(n)

    def __reduce__(self):
        # Pickle by re-constructing from scalars (the compiled fn is not picklable).
        args = (self.max_prob.item(), self.min_prob.item(),
                self.decay.item(), self.flat_start.item())
        return (self.__class__, args)
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    """Build the default preconditioner update-probability schedule.

    PSGD benefits from frequent preconditioner updates early in training; once
    the preconditioner is learned, updates can become rare. The returned
    schedule holds the probability at `max_prob` for `flat_start` steps, then
    anneals exponentially down to `min_prob` (reached around step 4000 with
    the defaults, which work well for most models and training regimes).
    """
    return ProbScheduler(max_prob, min_prob, decay, flat_start)
class PSGDKron(torch.optim.Optimizer):
    """PSGD Kron optimizer with layer-wise pipeline parallelism.

    Each rank preconditions its round-robin share of the parameters and the
    resulting updates are exchanged with async all_gather on a side stream,
    overlapping communication of round k-1 with compute of round k.

    Args:
        params: Parameters to optimize
        lr: Learning rate
        b1: Momentum
        weight_decay: Weight decay
        preconditioner_update_probability: Prob of updating preconditioner (default: anneals 1.0->0.03 by 4000 steps)
        max_size_triangular: Max size for triangular preconditioner
        min_ndim_triangular: Min dims needed for triangular preconditioners
        memory_save_mode: Memory saving mode:
            None: All triangular preconditioners
            'smart_one_diag': Large outlier dims use diagonal
            'one_diag': Largest dim uses diagonal
            'all_diag': All diagonal preconditioners
        precond_lr: Preconditioner learning rate (default: 0.1)
        precond_init_scale: Initial preconditioner scale (default: 1.0)
        merge_dims: Whether to combine dims to make grad tensor a matrix
        partition_grads: Whether to partition gradients
        block_size: Partition size for gradients
        clip_update_rms: Clip update RMS at 1.1
        dtype: Data type for params/grads
        rank: Worker rank for pipeline
        world_size: Total workers for pipeline
    """
    def __init__(
        self,
        params,
        lr=0.0003,
        b1=0.9,
        weight_decay=0.0,
        preconditioner_update_probability=None,
        max_size_triangular=8192,
        min_ndim_triangular=2,
        memory_save_mode=None,
        precond_lr=0.1,
        precond_init_scale=1.0,
        merge_dims=True,
        partition_grads=False,
        block_size=1024,
        clip_update_rms=True,
        dtype=torch.float32,
        rank=0,
        world_size=1,
    ):
        self.rank = rank
        self.world_size = world_size
        if preconditioner_update_probability is None:
            preconditioner_update_probability = precond_update_prob_schedule()
        params = [*params]
        # Group params by numel so each group can share one flat gather buffer
        # (all_gather_into_tensor requires equal-sized contributions).
        sizes = {p.numel() for p in params}
        def create_update_buffer(size: int):
            b = torch.empty(world_size, size, dtype=dtype, device="cuda")
            return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
        param_groups = [
            {"params": [p for p in params if p.numel() == size], **create_update_buffer(size)}
            for size in sizes
        ]
        defaults = dict(
            lr=lr,
            b1=b1,
            weight_decay=weight_decay,
            preconditioner_update_probability=preconditioner_update_probability,
            max_size_triangular=max_size_triangular,
            min_ndim_triangular=min_ndim_triangular,
            memory_save_mode=memory_save_mode,
            precond_lr=precond_lr,
            precond_init_scale=precond_init_scale,
            merge_dims=merge_dims,
            partition_grads=partition_grads,
            block_size=block_size,
            clip_update_rms=clip_update_rms,
            dtype=dtype,
        )
        super().__init__(param_groups, defaults)
        # Smallest positive normal of `dtype`; used to guard divisions.
        self._tiny = torch.tensor(torch.finfo(dtype).tiny, dtype=dtype, device="cuda")
        self._prob_step = torch.tensor(0, dtype=torch.int32)
        self._update_counter = torch.tensor(0, dtype=torch.int32)
        self.rng = random.Random(42)  # seeded so all ranks agree on `balance` rolls
        self.dtype = dtype
        self.comm_stream = torch.cuda.Stream()  # side stream for async all_gather

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Decide (deterministically across ranks) whether to update the
        # preconditioner this step: fire every ~1/prob steps via a counter.
        update_prob = self.defaults["preconditioner_update_probability"]
        if callable(update_prob):
            update_prob = update_prob(self._prob_step.to(dtype=torch.float32))
        self._prob_step += 1
        self._update_counter += 1
        do_update = self._update_counter >= 1 / update_prob
        if do_update:
            self._update_counter = torch.tensor(0, dtype=torch.int32)
        # Occasionally rebalance Q factor norms (numerical hygiene).
        balance = do_update and self.rng.random() < 0.01
        for group in self.param_groups:
            prev_handle = None
            prev_pworld = None
            num_params = len(group["params"])
            for base_i in range(0, num_params, self.world_size):
                param_idx = base_i + self.rank
                if param_idx < num_params:
                    p = group["params"][param_idx]
                    # Default contribution is zeros (e.g. when p.grad is None).
                    g = torch.zeros(group["update_buffer"].size(1), dtype=self.dtype, device="cuda")
                    if p.grad is not None:
                        grads = p.grad.to(self.dtype)
                        state = self.state[p]
                        # merge smaller dims: reshape >2D grads to the most
                        # "square" matrix so the Kron factors stay balanced
                        if grads.dim() > 2 and group.get('merge_dims', True):
                            if "merged_shape" not in state:
                                shape1 = [np.prod(grads.shape[:-1]), grads.shape[-1]]
                                shape2 = [grads.shape[0], np.prod(grads.shape[1:])]
                                shape = shape1 if np.diff(shape1) <= np.diff(shape2) else shape2
                                state["merged_shape"] = shape
                            grads = grads.view(*state["merged_shape"])
                        # partition grads into blocks if requested
                        grads_list = (
                            state["partitioner"].partition(grads)
                            if group["partition_grads"]
                            and "partitioner" in state
                            else [grads]
                        )
                        if group["partition_grads"] and "partitioner" not in state:
                            state["partitioner"] = BlockPartitioner(
                                grads.shape,
                                group["block_size"],
                                _get_dim_diag(
                                    group["memory_save_mode"],
                                    grads.shape,
                                    group["max_size_triangular"],
                                    group["min_ndim_triangular"],
                                )
                            )
                            grads_list = state["partitioner"].partition(grads)
                        precond_grads = []
                        for i, grad in enumerate(grads_list):
                            # Lazy per-block state init: step count, momentum
                            # buffer, Kron factors Q, and einsum expressions.
                            if f"step_{i}" not in state:
                                state[f"step_{i}"] = 0
                                state[f"momentum_buffer_{i}"] = torch.zeros_like(grad, dtype=self.dtype)
                                state[f"Q_{i}"], state[f"exprs_{i}"] = _init_Q_exprs(
                                    grad,
                                    group["precond_init_scale"],
                                    _get_dim_diag(
                                        group["memory_save_mode"],
                                        grad.shape,
                                        group["max_size_triangular"],
                                        group["min_ndim_triangular"],
                                    ),
                                    self.dtype,
                                )
                            state[f"step_{i}"] += 1
                            debiased_momentum = _update_momentum(
                                state[f"momentum_buffer_{i}"],
                                grad,
                                torch.tensor(group["b1"], dtype=self.dtype, device="cuda"),
                                torch.tensor(state[f"step_{i}"], dtype=self.dtype, device="cuda"),
                            )
                            if grad.dim() > 1 and balance:
                                _balance_Q(state[f"Q_{i}"])
                            if do_update:
                                _update_precond(
                                    state[f"Q_{i}"],
                                    state[f"exprs_{i}"],
                                    debiased_momentum,
                                    torch.tensor(group["precond_lr"], dtype=self.dtype, device="cuda"),
                                    self._tiny,
                                )
                            precond_grads.append(
                                _precond_grad(state[f"Q_{i}"], state[f"exprs_{i}"], debiased_momentum)
                            )
                        g = (
                            state["partitioner"].merge_partitions(precond_grads)
                            if group["partition_grads"]
                            else precond_grads[0]
                        )
                        if group["clip_update_rms"]:
                            _clip_update_rms(g)
                        g = g.flatten()
                else:
                    # This rank has no param this round; contribute zeros so
                    # the collective still has world_size equal-sized inputs.
                    g = torch.zeros(group["update_buffer"].size(1), dtype=self.dtype, device="cuda")
                # Pipeline: launch this round's gather, then apply the
                # previous round's gathered updates while it is in flight.
                with torch.cuda.stream(self.comm_stream):
                    handle = dist.all_gather_into_tensor(group["update_buffer"], g, async_op=True)
                    if prev_handle is not None:
                        prev_handle.wait()
                        views = group["update_buffer_views"][: len(prev_pworld)]
                        updates = [v.view_as(pw) for v, pw in zip(views, prev_pworld)]
                        _update_params(
                            prev_pworld,
                            updates,
                            torch.tensor(group["weight_decay"], dtype=self.dtype, device="cuda"),
                            torch.tensor(group["lr"], dtype=self.dtype, device="cuda"),
                        )
                    prev_handle = handle
                    prev_pworld = group["params"][base_i : min(base_i + self.world_size, num_params)]
            # Drain the last in-flight gather for this group.
            if prev_handle is not None:
                prev_handle.wait()
                views = group["update_buffer_views"][: len(prev_pworld)]
                updates = [v.view_as(pw) for v, pw in zip(views, prev_pworld)]
                _update_params(
                    prev_pworld,
                    updates,
                    torch.tensor(group["weight_decay"], dtype=self.dtype, device="cuda"),
                    torch.tensor(group["lr"], dtype=self.dtype, device="cuda"),
                )
        return loss
| def _update_momentum(momentum_buffer, grad, beta, step): | |
| momentum_buffer.mul_(beta).add_(grad, alpha=1 - beta) | |
| return momentum_buffer.div(1 - beta**step) | |
| def _get_dim_diag(memory_save_mode, shape, max_size, min_ndim): | |
| if memory_save_mode is None: | |
| dim_diag = [False for _ in shape] | |
| elif memory_save_mode == "smart_one_diag": # Thanks to @ClashLuke heavyball repo | |
| rev_sorted_dims = np.argsort(shape)[::-1] | |
| dim_diag = [False for _ in shape] | |
| sorted_shape = sorted(shape) | |
| if len(shape) > 1 and sorted_shape[-1] > sorted_shape[-2]: | |
| dim_diag[rev_sorted_dims[0]] = True | |
| elif memory_save_mode == "one_diag": | |
| rev_sorted_dims = np.argsort(shape)[::-1] | |
| dim_diag = [i == rev_sorted_dims[0] for i in range(len(shape))] | |
| elif memory_save_mode == "all_diag": | |
| dim_diag = [True for _ in shape] | |
| else: | |
| raise ValueError( | |
| f"Invalid memory_save_mode: {memory_save_mode}, must be one of " | |
| "[None, 'smart_one_diag', 'one_diag', 'all_diag']" | |
| ) | |
| if len(shape) < min_ndim: | |
| return [True for _ in shape] | |
| for i in range(len(shape)): | |
| size = shape[i] | |
| if size == 1 or size > max_size: | |
| dim_diag[i] = True | |
| return dim_diag | |
def _init_Q_exprs(t, scale, dim_diag, dtype):
    """Initialize preconditioner Q and reusable einsum expressions.

    Returns [Q, (exprA, exprGs, exprP)] where Q holds one Kronecker factor per
    dim of `t` (a vector for diagonal factors, a matrix for triangular ones),
    exprA applies all factors to a tensor, exprGs computes the per-factor
    gradient contractions, and exprP applies the full preconditioner Q^T Q.
    """
    # Lowercase letters index tensor dims; +13 and +26 offsets give the
    # "primed" index sets for factor rows/columns (hence the 13-dim limit).
    letters = string.ascii_lowercase + string.ascii_uppercase
    shape = t.shape
    if len(shape) == 0:  # scalar
        Q = [scale * torch.ones_like(t, dtype=dtype)]
        exprA = ",->"
        exprGs = [",->"]
        exprP = ",,->"
    else:
        if len(shape) > 13:
            raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
        # Spread the overall scale evenly across the Kronecker factors.
        scale = torch.tensor(scale ** (1 / len(shape)), dtype=dtype, device=t.device)
        Q = []
        piece1A, piece2A, piece3A = ([], "", "")
        exprGs = []
        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
        for i, dim_d in enumerate(dim_diag):
            if dim_d:
                # use diagonal matrix as preconditioner for this dim
                Q.append(scale * torch.ones(shape[i], dtype=dtype, device=t.device))
                piece1A.append(letters[i])
                piece2A = piece2A + letters[i]
                piece3A = piece3A + letters[i]
                piece1 = "".join(
                    [(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]
                )
                subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
                exprGs.append(subscripts)
                piece1P.append(letters[i + 13])
                piece2P.append(letters[i + 13])
                piece3P = piece3P + letters[i + 13]
                piece4P = piece4P + letters[i + 13]
            else:
                # use triangular matrix as preconditioner for this dim
                Q.append(scale * torch.eye(shape[i], dtype=dtype, device=t.device))
                piece1A.append(letters[i] + letters[i + 13])
                piece2A = piece2A + letters[i + 13]
                piece3A = piece3A + letters[i]
                piece1 = "".join(
                    [(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]
                )
                piece2 = "".join(
                    [(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))]
                )
                subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
                exprGs.append(subscripts)
                a, b, c = (letters[i], letters[i + 13], letters[i + 26])
                piece1P.append(a + b)
                piece2P.append(a + c)
                piece3P = piece3P + c
                piece4P = piece4P + b
        exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
        exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
    exprGs = tuple(exprGs)
    return [Q, (exprA, exprGs, exprP)]
| def _balance_Q(Q_in): | |
| norms = torch.stack([q.norm(float("inf")) for q in Q_in]) | |
| geometric_mean = norms.log().mean().exp() | |
| norms = geometric_mean / norms | |
| torch._foreach_mul_(Q_in, list(norms)) | |
| def _lb(A: Tensor, max_abs: Tensor): | |
| """Cheap lower bound for the spectral norm of A.""" | |
| A /= max_abs | |
| a0 = torch.einsum("ij,ij->j", A, A) | |
| i = torch.argmax(a0) | |
| x = torch.index_select(A, 1, i).flatten().contiguous() | |
| x = torch.einsum("i,ij->j", x, A) | |
| x /= x.norm() | |
| x = torch.einsum("j,kj->k", x, A) | |
| x = x.norm() | |
| x *= max_abs | |
| return x | |
| def _solve_triangular_right(X: Tensor, A: Tensor): | |
| """X @ inv(A)""" | |
| orig_dtype = A.dtype | |
| # roughly same complexity as a matmul | |
| return ( | |
| torch.linalg.solve_triangular( | |
| A.float(), | |
| X.reshape(-1, X.size(-1)).float(), | |
| upper=True, | |
| left=False, | |
| unitriangular=False, | |
| ) | |
| .to(dtype=orig_dtype) | |
| .reshape_as(X) | |
| ) | |
def _calc_A_and_conjB(exprA, G, Q):
    """Calculate A and conjB for the preconditioner update.

    A is G with all Kronecker factors of Q contracted in (via exprA); conjB is
    the noise tensor V solved against each factor of Q from the right.
    NOTE: mutates G in place by adding a small amount of noise proportional to
    sqrt(eps) * mean(|G|), which keeps the update well-conditioned when G ~ 0.
    """
    order = G.dim()
    V = torch.randn_like(G)
    eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=G.dtype, device=G.device)
    G += eps.sqrt() * G.abs().mean() * V
    # Cycle the dims so each factor acts on the trailing axis in turn.
    conjB = V.permute(*range(1, order), 0)
    for i, q in enumerate(Q):
        # Diagonal factors divide elementwise; dense factors need a triangular solve.
        conjB = conjB / q if q.dim() < 2 else _solve_triangular_right(conjB, q)
        if i < order - 1:
            conjB = torch.transpose(conjB, i, order - 1)
    A = torch.einsum(exprA, *Q, G)
    return A, conjB
def _update_precond(Q, exprs, G, step, tiny):
    """Update Kronecker product preconditioner Q with pair (V, G).

    Thanks to @ClashLuke heavyball repo for many of the optimizations in this
    function. Mutates each factor q of Q in place; `step` is the
    preconditioner learning rate and `tiny` guards divisions by ~0.
    """
    exprA, exprGs, _ = exprs
    A, conjB = _calc_A_and_conjB(exprA, G, Q)
    for q, exprG in zip(Q, exprGs):
        term1 = torch.einsum(exprG, A, A)
        term2 = torch.einsum(exprG, conjB, conjB)
        # term1 becomes the update direction, term2 the normalizer.
        term1, term2 = term1 - term2, term1 + term2
        term1 *= step
        norm = term2.norm(float("inf"))
        if q.dim() < 2:
            # Diagonal factor: elementwise relative update.
            term1 *= q / norm.clamp_(min=tiny)
        else:
            # Triangular factor: keep the upper-triangular structure and
            # normalize by a cheap spectral-norm lower bound of term2.
            torch.triu(term1, out=term1)
            term1 /= torch.where(norm > 0, _lb(term2, norm), norm).clamp_(tiny)
            term1 = torch.mm(term1, q)
        q.sub_(term1)
| def _precond_grad(Q, exprs, G): | |
| """Precondition gradient G with preconditioner Q.""" | |
| return torch.einsum(exprs[-1], *Q, *Q, G) | |
| def _clip_update_rms(g): | |
| g.mul_( | |
| torch.minimum( | |
| torch.tensor(1.0, dtype=g.dtype, device=g.device), | |
| 1.1 / g.square().mean().sqrt().add(1e-12), | |
| ) | |
| ) | |
| def _update_params(params_world, updates, weight_decay, lr): | |
| if weight_decay > 0: | |
| torch._foreach_add_(updates, params_world, alpha=weight_decay) | |
| torch._foreach_add_(params_world, updates, alpha=-lr) | |
class BlockPartitioner:
    """Splits a tensor into fixed-size blocks along non-diagonal dims, and merges back.

    Modified from distributed_shampoo.
    https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
    """
    def __init__(self, param_shape, block_size, dim_diag):
        assert len(dim_diag) == len(param_shape), "dim_diag must have same length as param_shape"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._shape = param_shape
        self._split_indices = []
        self._split_dims = []
        for axis, (size, is_diag) in enumerate(zip(param_shape, dim_diag)):
            # Diagonal dims are already cheap; only split large dense dims.
            if not is_diag and 0 < block_size < size:
                num_cuts = (size - 1) // block_size
                if num_cuts > 0:
                    self._split_indices.append([block_size * (c + 1) for c in range(num_cuts)])
                    self._split_dims.append(axis)
        if self._split_indices:
            self._total_blocks = np.prod([len(cuts) + 1 for cuts in self._split_indices])
        else:
            self._total_blocks = 1

    def partition(self, tensor):
        """Split `tensor` into blocks; ordering matches merge_partitions."""
        assert tensor.shape == self._shape
        pieces = [tensor]
        for axis, cuts in zip(self._split_dims, self._split_indices):
            pieces = [
                part
                for piece in pieces
                for part in torch.tensor_split(piece, cuts, dim=axis)
            ]
        return pieces

    def merge_partitions(self, partitions):
        """Inverse of partition: reassemble the blocks into the original tensor."""
        pieces = list(partitions)
        for axis, cuts in zip(reversed(self._split_dims), reversed(self._split_indices)):
            group = len(cuts) + 1
            pieces = [
                torch.cat(pieces[i : i + group], dim=axis)
                for i in range(0, len(pieces), group)
            ]
        return pieces[0]
########################################
#                Setup                 #
########################################
# torchrun sets these env variables
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
# this code can be run equivalently with 1, 2, 4, or 8 gpus.
assert 8 % dist.get_world_size() == 0
# logging setup
if dist.get_rank() == 0:
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{uuid.uuid4()}.txt"
    print(logfile)

def print0(s, console=False, log=True):
    """Rank 0 appends `s` to the logfile (and optionally prints); other ranks no-op."""
    if dist.get_rank() == 0:
        if console:
            print(s)
        if log:
            with open(logfile, "a") as f:
                print(s, file=f)

# we begin by logging this file itself
print0(code)
print0("="*100)
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
       + f" on {torch.cuda.get_device_name(device)} with world_size {dist.get_world_size()}")
print0("="*100)
val_tokens = 20 * 524288  # number of tokens used for each validation pass
batch_size = 8 * 64 * 1024  # global tokens per training step
mbs = 64  # microbatch size (sequences per forward/backward)
# Validation data is loaded once up front as a single large batch.
val_inputs, val_targets = next(distributed_data_generator("data/fineweb10B/fineweb_val_*.bin", val_tokens))
model = GPT(vocab_size=50304, num_layers=12, model_dim=768).cuda()
model.compile(dynamic=False)
# Optional trailing CLI arg repeats the whole training run num_trials times.
num_trials = int(sys.argv[-1]) if len(sys.argv) > 1 else 1
for _ in range(num_trials):
    ########################################
    #        Init & Optim Hyperparams      #
    ########################################
    # we want to minimize this while still reaching 3.28 val loss
    train_steps = 5750
    # initialize model parameters (re-done each trial for independent runs)
    for name, p in model.named_parameters():
        if name.endswith("weight"):
            if "proj" in name:
                p.data.zero_()  # output projections start at zero
            elif "embed" in name:
                p.data.normal_()  # default torch init
            else:
                p.data.normal_(std=0.33**0.5 / p.size(-1)**0.5)  # default torch init
        elif name.endswith("bias"):
            p.data.zero_()
        elif name.endswith("gains"):
            p.data.normal_(mean=1, std=0)
        else:
            raise Exception(f"Uninitialized parameter: {name}")
    # create the optimizer(s): AdamW for embeddings/head/1D params,
    # PSGD Kron for the >=2D transformer block weights.
    optimizer1 = AdamW([dict(params=[model.embed.weight], lr=0.3),
                        dict(params=[model.proj.weight], lr=1/320),
                        dict(params=[p for p in model.parameters() if p.ndim < 2], lr=0.01)],
                       betas=(0.8, 0.95), eps=1e-10, weight_decay=0, fused=True)
    optimizer2 = PSGDKron([p for p in model.blocks.parameters() if p.ndim >= 2],
                          lr=0.0005, weight_decay=0.625, b1=0.9,
                          rank=dist.get_rank(), world_size=dist.get_world_size(),
                          memory_save_mode="one_diag", precond_lr=0.1,
                          dtype=torch.float32)
    optimizers = [optimizer1, optimizer2]
    # every parameter must be covered by exactly one optimizer
    assert set(p for opt in optimizers for group in opt.param_groups
               for p in group["params"]) == set(model.parameters())
    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
            group["warmup_steps"] = 250 if opt is optimizer2 else 0
    # learning rate schedule: stable then decay
    def set_hparams(step, cooldown_frac=0.7):
        # eta is 1.0 for the first (1 - cooldown_frac) of training, then
        # decays linearly to 0; a per-group warmup multiplies on top.
        progress = step / train_steps
        assert 0 <= progress < 1
        if progress < 1 - cooldown_frac:
            eta = 1.0
        else:
            eta = (1 - progress) / cooldown_frac
        for opt in optimizers:
            for group in opt.param_groups:
                warmup_steps = group.get("warmup_steps", 0)
                warmup = min(1.0, (step + 1) / warmup_steps) if warmup_steps > 0 else 1.0
                group["lr"] = group["initial_lr"] * warmup * eta
    ########################################
    #        Training and Validation       #
    ########################################
    train_loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size)
    # sync initial weights across ranks
    for p in model.parameters():
        dist.broadcast(p.detach(), 0)
    # start the clock
    training_time = 0
    last_val_step = 0
    dist.barrier()
    t0 = time.perf_counter()
    for step in range(train_steps + 1):
        # --------------- VALIDATION SECTION -----------------
        if step == train_steps or step % 125 == 0:
            # stop the clock (validation time is excluded from training_time)
            dist.barrier()
            time_since_last_val = time.perf_counter() - t0
            step_avg = time_since_last_val / (step - last_val_step) if step > 0 else float("nan")
            last_val_step = step
            training_time += time_since_last_val
            model.eval()
            val_loss = 0
            with torch.no_grad():
                assert len(val_inputs) % mbs == 0
                for i in range(len(val_inputs) // mbs):
                    val_loss += model(val_inputs[i*mbs:(i+1)*mbs], val_targets[i*mbs:(i+1)*mbs])
            dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
            val_loss /= val_tokens  # summed CE -> mean loss per token
            print0(f"step:{step}/{train_steps} val_loss:{val_loss:.5f} train_time:{training_time:.3f}s"
                   + f" step_avg:{1000*step_avg:.2f}ms", console=True)
            model.train()
            # start the clock again
            dist.barrier()
            t0 = time.perf_counter()
        if step == train_steps:
            break
        # --------------- TRAINING SECTION -----------------
        inputs, targets = next(train_loader)
        # accumulate across microbatches in case we are running with fewer than 8 gpus
        assert len(inputs) % mbs == 0
        for i in range(len(inputs) // mbs):
            model(inputs[i*mbs:(i+1)*mbs], targets[i*mbs:(i+1)*mbs]).backward()
        for name, p in model.named_parameters():
            assert p.grad is not None, name
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        # set optimization hyperparameters and take a step
        set_hparams(step)
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        approx_training_time = training_time + (time.perf_counter() - t0)
        print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time:.3f}s"
               + f" step_avg:{1000*approx_training_time/(step + 1):.2f}ms", console=True, log=False)
dist.destroy_process_group()
Xet Storage Details
- Size:
- 36 kB
- Xet hash:
- ac179f3952714112fa5fc2b77a3b46b02982b70abf03adf78440578d96861140
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.