Buckets:

cmpatino's picture
download
raw
36 kB
"""
train_gpt_simple.py
This file descends from the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt).
It was prepared as a simplified version of the speedrun for use in neural net optimization research.
"""
import os
import sys
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
from pathlib import Path
import string
import random
import numpy as np
import torch
from torch import Tensor, nn
from torch.optim import AdamW
import torch.nn.functional as F
import torch.distributed as dist
########################################
# Dataloader #
########################################
def _load_data_shard(file: Path):
header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
num_tokens = int(header[2]) # number of tokens (claimed)
with file.open("rb", buffering=0) as f:
tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy
assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
return tokens
def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len=1024):
    """Yield (inputs, targets) CUDA batches, sharded across distributed ranks.

    `batch_size` is a global token count; each rank takes a contiguous slice
    of batch_size // world_size tokens per step. Streams through the shard
    files matching `filename_pattern` in sorted order, loading the next shard
    whenever the current one cannot fill another batch. Targets are inputs
    shifted by one token; both are reshaped to (-1, seq_len).
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shard_paths = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    shards = iter(shard_paths)
    tokens, pos = _load_data_shard(next(shards)), 0
    while True:
        # move on to the next shard when this one can't fill a full batch (+1 for the shifted target)
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(shards)), 0
        window = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = window[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = window[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
        pos += batch_size
        yield inputs.view(-1, seq_len), targets.view(-1, seq_len)
########################################
# Architecture #
########################################
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gains = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.rms_norm(x, (x.size(-1),), weight=self.gains.type_as(x))
class Linear(nn.Linear):
    """nn.Linear variant that casts weight/bias to the input dtype at call time."""
    def __init__(self, in_features, out_features):
        # bias is always enabled in this codebase
        super().__init__(in_features, out_features, bias=True)

    def forward(self, x):
        w = self.weight.type_as(x)
        b = self.bias.type_as(x)
        return F.linear(x, w, b)
class Rotary(nn.Module):
    """Half-truncated rotary position embedding with base-frequency tuning.

    Only the first dim//4 frequency channels rotate; the remaining dim//4
    channels get zero frequency, i.e. they pass through unrotated.
    Expects input shaped (batch, time, heads, head_dim).
    """
    def __init__(self, dim: int):
        super().__init__()
        # half-truncate RoPE (w/ base freq tuning)
        freqs = (1 / 1024) ** torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
        self.register_buffer("angular_freq", torch.cat([freqs, freqs.new_zeros(dim // 4)]))

    def forward(self, x_BTHD: Tensor):
        seq_len = x_BTHD.size(1)
        pos = torch.arange(seq_len, dtype=torch.float32, device=x_BTHD.device)
        theta = torch.outer(pos, self.angular_freq)[None, :, None, :]
        cos, sin = theta.cos(), theta.sin()
        # rotate pairs (x1, x2) -> (x1 cos + x2 sin, x2 cos - x1 sin) in fp32
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        rotated = torch.cat((x1 * cos + x2 * sin, x2 * cos - x1 * sin), 3)
        return rotated.type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with QK RMS-norm and rotary embeddings.

    Uses a fixed attention scale of 0.12 instead of 1/sqrt(head_dim).
    """
    def __init__(self, dim: int, head_dim=128):
        super().__init__()
        self.num_heads = dim // head_dim
        self.head_dim = head_dim
        hdim = self.num_heads * self.head_dim
        self.q = Linear(dim, hdim)
        self.k = Linear(dim, hdim)
        self.v = Linear(dim, hdim)
        self.proj = Linear(hdim, dim)
        self.rotary = Rotary(head_dim)

    def forward(self, x: Tensor):
        batch, seq = x.size(0), x.size(1)
        per_head = (batch, seq, self.num_heads, self.head_dim)
        q = self.q(x).view(per_head)
        k = self.k(x).view(per_head)
        v = self.v(x).view(per_head)
        # normalize q/k before rotary (QK-norm stabilizes attention logits)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        q, k = self.rotary(q), self.rotary(k)
        attended = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            scale=0.12, is_causal=True,
        )
        out = attended.transpose(1, 2).contiguous().view(batch, seq, self.num_heads * self.head_dim)
        return self.proj(out)
class MLP(nn.Module):
    """Position-wise feed-forward block with squared-ReLU activation."""
    def __init__(self, dim: int):
        super().__init__()
        hidden = 4 * dim  # standard 4x expansion
        self.fc = Linear(dim, hidden)
        self.proj = Linear(hidden, dim)

    def forward(self, x: Tensor):
        # relu(x)^2 activation, as in the NanoGPT speedrun
        h = self.fc(x)
        return self.proj(h.relu().square())
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""
    def __init__(self, dim: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim)
        self.mlp = MLP(dim)
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x: Tensor):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class GPT(nn.Module):
    """Minimal GPT: embedding -> transformer blocks -> soft-capped logits.

    forward() returns the SUM of cross-entropy losses over all target tokens
    (callers divide by the token count themselves).
    """
    def __init__(self, vocab_size: int, num_layers: int, model_dim: int):
        super().__init__()
        # embedding table kept in bfloat16 to save memory/bandwidth
        self.embed = nn.Embedding(vocab_size, model_dim).bfloat16()
        self.blocks = nn.ModuleList(Block(model_dim) for _ in range(num_layers))
        self.proj = Linear(model_dim, vocab_size)
        self.norm1 = RMSNorm(model_dim)
        self.norm2 = RMSNorm(model_dim)

    def forward(self, inputs: Tensor, targets: Tensor):
        x = self.norm1(self.embed(inputs))
        for block in self.blocks:
            x = block(x)
        logits = self.proj(self.norm2(x)).float()
        # rational soft-cap: squashes logits smoothly into (-15, 15)
        logits = 15 * logits * (logits.square() + 15**2).rsqrt()
        return F.cross_entropy(logits.view(targets.numel(), -1), targets.view(-1), reduction="sum")
########################################
# Optimizer #
########################################
def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration.

    Runs 12 fixed-coefficient iterations in bfloat16. The iteration drives
    the singular values of G toward 1 while keeping the singular vectors,
    i.e. it approximates U V^T from the SVD of G. Works on batched
    matrices (any ndim >= 2); not tuned for wallclock speed.
    """
    assert G.ndim >= 2
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT  # operate on the wide orientation
    # ensure spectral norm is at most 1 so the iteration converges
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    coef_a, coef_b, coef_c = 2, -1.5, 0.5
    for _ in range(12):
        A = X @ X.mT
        X = coef_a * X + (coef_b * A + coef_c * A @ A) @ X
    if transposed:
        X = X.mT
    return X
@torch.compile
def muon_update(grad, momentum, mu=0.95, nesterov=True):
    """One Muon update: momentum-smooth the gradient, then orthogonalize it.

    NOTE: mutates `momentum` (EMA update) and, when nesterov=True, also
    mutates `grad` in place (lerp toward the momentum buffer).
    """
    momentum.lerp_(grad, 1 - mu)
    smoothed = grad.lerp_(momentum, mu) if nesterov else momentum
    ortho = zeropower_via_newtonschulz5(smoothed)
    # upscale updates for tall matrices so update RMS is shape-consistent
    ortho *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
    return ortho
class Muon(torch.optim.Optimizer):
    """Muon: momentum SGD whose updates are orthogonalized via Newton-Schulz.

    Work is sharded round-robin across ranks: each rank computes updates for
    every world_size-th parameter, then an all_gather broadcasts the freshly
    updated parameters to every rank.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, mu=0.95):
        # expects a flat, non-empty list of torch.nn.Parameter
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        # sort largest-first so each rank's slice of work is roughly balanced
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        defaults = dict(lr=lr, weight_decay=weight_decay, mu=mu)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for group in self.param_groups:
            params = group["params"]
            # pad to a multiple of world_size with dummy tensors so the
            # all_gather below always has a full slice to write into
            # (NOTE: when len(params) % world_size == 0 this appends a whole
            # extra group of dummies — harmless, they are never gathered)
            params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
            for base_i in range(0, len(params), world_size):
                if base_i + rank < len(params):
                    # this rank owns params[base_i + rank]
                    p = params[base_i + rank]
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum"], mu=group["mu"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])  # decoupled weight decay
                    p.add_(update, alpha=-group["lr"])
                # collective: every rank contributes its updated param (or a dummy)
                # and receives the slice updated by the other ranks
                dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
########################################
# PSGD Kron #
########################################
class ProbScheduler:
    """Anneals the preconditioner-update probability.

    Exponential decay with a flat start: the probability stays at `max_prob`
    for `flat_start` steps (the pre-clamp value exceeds max_prob there), then
    decays exponentially at rate `decay`, clamped to [min_prob, max_prob].
    """
    def __init__(self, max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
        self.max_prob = torch.tensor(max_prob, dtype=torch.float32)
        self.min_prob = torch.tensor(min_prob, dtype=torch.float32)
        self.decay = torch.tensor(decay, dtype=torch.float32)
        self.flat_start = torch.tensor(flat_start, dtype=torch.float32)
        self._compiled = False
        try:
            self._compiled_schedule = torch.compile(self._schedule_fn)
            self._compiled = True
        except Exception:
            pass  # fall back to eager if torch.compile is unavailable

    def _schedule_fn(self, n):
        # exponential decay measured from flat_start; the clamp yields both
        # the flat initial region and the min_prob floor
        prob = self.max_prob * torch.exp(-self.decay * (n - self.flat_start))
        prob.clamp_(min=self.min_prob, max=self.max_prob)
        return prob

    def __call__(self, n):
        fn = self._compiled_schedule if self._compiled else self._schedule_fn
        return fn(n)

    def __reduce__(self):
        # pickle via constructor args (the compiled wrapper isn't picklable)
        args = (self.max_prob.item(), self.min_prob.item(),
                self.decay.item(), self.flat_start.item())
        return (self.__class__, args)
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    """Build the default preconditioner-update probability schedule.

    PSGD benefits from frequent preconditioner updates early in training;
    once the preconditioner is learned, the update probability can drop low.
    With the defaults, the probability holds at 1.0 for 500 steps, then
    anneals exponentially down to `min_prob` by roughly step 4000. These
    defaults work well for most models and training regimes.
    """
    return ProbScheduler(max_prob, min_prob, decay, flat_start)
class PSGDKron(torch.optim.Optimizer):
    """PSGD Kron optimizer with layer-wise pipeline parallelism.

    Args:
        params: Parameters to optimize
        lr: Learning rate
        b1: Momentum
        weight_decay: Weight decay
        preconditioner_update_probability: Prob of updating preconditioner (default: anneals 1.0->0.03 by 4000 steps)
        max_size_triangular: Max size for triangular preconditioner
        min_ndim_triangular: Min dims needed for triangular preconditioners
        memory_save_mode: Memory saving mode:
            None: All triangular preconditioners
            'smart_one_diag': Large outlier dims use diagonal
            'one_diag': Largest dim uses diagonal
            'all_diag': All diagonal preconditioners
        precond_lr: Preconditioner learning rate (default: 0.1)
        precond_init_scale: Initial preconditioner scale (default: 1.0)
        merge_dims: Whether to combine dims to make grad tensor a matrix
        partition_grads: Whether to partition gradients
        block_size: Partition size for gradients
        clip_update_rms: Clip update RMS at 1.1
        dtype: Data type for params/grads
        rank: Worker rank for pipeline
        world_size: Total workers for pipeline
    """
    def __init__(
        self,
        params,
        lr=0.0003,
        b1=0.9,
        weight_decay=0.0,
        preconditioner_update_probability=None,
        max_size_triangular=8192,
        min_ndim_triangular=2,
        memory_save_mode=None,
        precond_lr=0.1,
        precond_init_scale=1.0,
        merge_dims=True,
        partition_grads=False,
        block_size=1024,
        clip_update_rms=True,
        dtype=torch.float32,
        rank=0,
        world_size=1,
    ):
        self.rank = rank
        self.world_size = world_size
        if preconditioner_update_probability is None:
            preconditioner_update_probability = precond_update_prob_schedule()
        params = [*params]
        # bucket params by element count: params of equal numel share one flat
        # all-gather buffer with one row per rank
        sizes = {p.numel() for p in params}
        def create_update_buffer(size: int):
            b = torch.empty(world_size, size, dtype=dtype, device="cuda")
            return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
        param_groups = [
            {"params": [p for p in params if p.numel() == size], **create_update_buffer(size)}
            for size in sizes
        ]
        defaults = dict(
            lr=lr,
            b1=b1,
            weight_decay=weight_decay,
            preconditioner_update_probability=preconditioner_update_probability,
            max_size_triangular=max_size_triangular,
            min_ndim_triangular=min_ndim_triangular,
            memory_save_mode=memory_save_mode,
            precond_lr=precond_lr,
            precond_init_scale=precond_init_scale,
            merge_dims=merge_dims,
            partition_grads=partition_grads,
            block_size=block_size,
            clip_update_rms=clip_update_rms,
            dtype=dtype,
        )
        super().__init__(param_groups, defaults)
        # smallest positive normal of `dtype`, used as a divide-by-zero guard
        self._tiny = torch.tensor(torch.finfo(dtype).tiny, dtype=dtype, device="cuda")
        self._prob_step = torch.tensor(0, dtype=torch.int32)  # total steps, drives the prob schedule
        self._update_counter = torch.tensor(0, dtype=torch.int32)  # steps since last precond update
        # fixed seed so every rank draws the same balance decisions
        self.rng = random.Random(42)
        self.dtype = dtype
        # separate stream lets all-gathers overlap with update computation
        self.comm_stream = torch.cuda.Stream()

    @torch.no_grad()
    def step(self, closure=None):
        """One optimization step. Parameters are sharded round-robin across
        ranks; each rank computes preconditioned updates for its shard and
        the results are exchanged with async all-gathers, pipelined so the
        previous slice's communication overlaps the next slice's compute."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # decide (deterministically, from the counter) whether to update the
        # preconditioner this step: roughly once every 1/update_prob steps
        update_prob = self.defaults["preconditioner_update_probability"]
        if callable(update_prob):
            update_prob = update_prob(self._prob_step.to(dtype=torch.float32))
        self._prob_step += 1
        self._update_counter += 1
        do_update = self._update_counter >= 1 / update_prob
        if do_update:
            self._update_counter = torch.tensor(0, dtype=torch.int32)
        # rarely (1% of update steps) rebalance the Kron factor norms;
        # identical RNG seeding keeps all ranks in agreement
        balance = do_update and self.rng.random() < 0.01
        for group in self.param_groups:
            prev_handle = None
            prev_pworld = None
            num_params = len(group["params"])
            for base_i in range(0, num_params, self.world_size):
                param_idx = base_i + self.rank
                if param_idx < num_params:
                    # this rank owns group["params"][param_idx]
                    p = group["params"][param_idx]
                    g = torch.zeros(group["update_buffer"].size(1), dtype=self.dtype, device="cuda")
                    if p.grad is not None:
                        grads = p.grad.to(self.dtype)
                        state = self.state[p]
                        # merge smaller dims
                        if grads.dim() > 2 and group.get('merge_dims', True):
                            if "merged_shape" not in state:
                                # choose whichever 2-D reshape is closest to square
                                shape1 = [np.prod(grads.shape[:-1]), grads.shape[-1]]
                                shape2 = [grads.shape[0], np.prod(grads.shape[1:])]
                                shape = shape1 if np.diff(shape1) <= np.diff(shape2) else shape2
                                state["merged_shape"] = shape
                            grads = grads.view(*state["merged_shape"])
                        # partition grads
                        grads_list = (
                            state["partitioner"].partition(grads)
                            if group["partition_grads"]
                            and "partitioner" in state
                            else [grads]
                        )
                        if group["partition_grads"] and "partitioner" not in state:
                            # lazily build the partitioner on first sight of this param
                            state["partitioner"] = BlockPartitioner(
                                grads.shape,
                                group["block_size"],
                                _get_dim_diag(
                                    group["memory_save_mode"],
                                    grads.shape,
                                    group["max_size_triangular"],
                                    group["min_ndim_triangular"],
                                )
                            )
                            grads_list = state["partitioner"].partition(grads)
                        precond_grads = []
                        for i, grad in enumerate(grads_list):
                            # lazy per-partition state init: step count, momentum
                            # buffer, Kron factors Q and their einsum expressions
                            if f"step_{i}" not in state:
                                state[f"step_{i}"] = 0
                                state[f"momentum_buffer_{i}"] = torch.zeros_like(grad, dtype=self.dtype)
                                state[f"Q_{i}"], state[f"exprs_{i}"] = _init_Q_exprs(
                                    grad,
                                    group["precond_init_scale"],
                                    _get_dim_diag(
                                        group["memory_save_mode"],
                                        grad.shape,
                                        group["max_size_triangular"],
                                        group["min_ndim_triangular"],
                                    ),
                                    self.dtype,
                                )
                            state[f"step_{i}"] += 1
                            debiased_momentum = _update_momentum(
                                state[f"momentum_buffer_{i}"],
                                grad,
                                torch.tensor(group["b1"], dtype=self.dtype, device="cuda"),
                                torch.tensor(state[f"step_{i}"], dtype=self.dtype, device="cuda"),
                            )
                            if grad.dim() > 1 and balance:
                                _balance_Q(state[f"Q_{i}"])
                            if do_update:
                                _update_precond(
                                    state[f"Q_{i}"],
                                    state[f"exprs_{i}"],
                                    debiased_momentum,
                                    torch.tensor(group["precond_lr"], dtype=self.dtype, device="cuda"),
                                    self._tiny,
                                )
                            precond_grads.append(
                                _precond_grad(state[f"Q_{i}"], state[f"exprs_{i}"], debiased_momentum)
                            )
                        g = (
                            state["partitioner"].merge_partitions(precond_grads)
                            if group["partition_grads"]
                            else precond_grads[0]
                        )
                        if group["clip_update_rms"]:
                            _clip_update_rms(g)
                        g = g.flatten()
                else:
                    # no param left for this rank in this slice: contribute zeros
                    g = torch.zeros(group["update_buffer"].size(1), dtype=self.dtype, device="cuda")
                with torch.cuda.stream(self.comm_stream):
                    # launch this slice's all-gather, then apply the PREVIOUS
                    # slice's updates while the new one is in flight
                    handle = dist.all_gather_into_tensor(group["update_buffer"], g, async_op=True)
                    if prev_handle is not None:
                        prev_handle.wait()
                        views = group["update_buffer_views"][: len(prev_pworld)]
                        updates = [v.view_as(pw) for v, pw in zip(views, prev_pworld)]
                        _update_params(
                            prev_pworld,
                            updates,
                            torch.tensor(group["weight_decay"], dtype=self.dtype, device="cuda"),
                            torch.tensor(group["lr"], dtype=self.dtype, device="cuda"),
                        )
                    prev_handle = handle
                    prev_pworld = group["params"][base_i : min(base_i + self.world_size, num_params)]
            # drain the last in-flight slice for this group
            if prev_handle is not None:
                prev_handle.wait()
                views = group["update_buffer_views"][: len(prev_pworld)]
                updates = [v.view_as(pw) for v, pw in zip(views, prev_pworld)]
                _update_params(
                    prev_pworld,
                    updates,
                    torch.tensor(group["weight_decay"], dtype=self.dtype, device="cuda"),
                    torch.tensor(group["lr"], dtype=self.dtype, device="cuda"),
                )
        return loss
@torch.compile
def _update_momentum(momentum_buffer, grad, beta, step):
momentum_buffer.mul_(beta).add_(grad, alpha=1 - beta)
return momentum_buffer.div(1 - beta**step)
def _get_dim_diag(memory_save_mode, shape, max_size, min_ndim):
if memory_save_mode is None:
dim_diag = [False for _ in shape]
elif memory_save_mode == "smart_one_diag": # Thanks to @ClashLuke heavyball repo
rev_sorted_dims = np.argsort(shape)[::-1]
dim_diag = [False for _ in shape]
sorted_shape = sorted(shape)
if len(shape) > 1 and sorted_shape[-1] > sorted_shape[-2]:
dim_diag[rev_sorted_dims[0]] = True
elif memory_save_mode == "one_diag":
rev_sorted_dims = np.argsort(shape)[::-1]
dim_diag = [i == rev_sorted_dims[0] for i in range(len(shape))]
elif memory_save_mode == "all_diag":
dim_diag = [True for _ in shape]
else:
raise ValueError(
f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
"[None, 'smart_one_diag', 'one_diag', 'all_diag']"
)
if len(shape) < min_ndim:
return [True for _ in shape]
for i in range(len(shape)):
size = shape[i]
if size == 1 or size > max_size:
dim_diag[i] = True
return dim_diag
def _init_Q_exprs(t, scale, dim_diag, dtype):
    """Initialize preconditioner Q and reusable einsum expressions.

    For each dim of `t`, Q gets either a diagonal factor (a 1-D ones vector)
    or a triangular factor (a 2-D identity), per `dim_diag`. Also builds the
    einsum subscript strings reused on every step:
      exprA  - apply all Q factors to a tensor (Q @ G),
      exprGs - per-factor contractions of two tensors over all other dims,
      exprP  - apply the full preconditioner Q^T Q to a gradient.
    Lowercase letters index the tensor dims, letters[i+13] and letters[i+26]
    are per-dim contraction aliases — hence the 13-dim limit.

    Returns [Q, (exprA, exprGs, exprP)].
    """
    letters = string.ascii_lowercase + string.ascii_uppercase
    shape = t.shape
    if len(shape) == 0:  # scalar: degenerate einsum strings with no indices
        Q = [scale * torch.ones_like(t, dtype=dtype)]
        exprA = ",->"
        exprGs = [",->"]
        exprP = ",,->"
    else:
        if len(shape) > 13:
            raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
        # spread the overall scale evenly across the Kron factors
        scale = torch.tensor(scale ** (1 / len(shape)), dtype=dtype, device=t.device)
        Q = []
        piece1A, piece2A, piece3A = ([], "", "")
        exprGs = []
        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
        for i, dim_d in enumerate(dim_diag):
            if dim_d:
                # use diagonal matrix as preconditioner for this dim
                Q.append(scale * torch.ones(shape[i], dtype=dtype, device=t.device))
                piece1A.append(letters[i])
                piece2A = piece2A + letters[i]
                piece3A = piece3A + letters[i]
                # contract both operands over every dim except i
                piece1 = "".join(
                    [(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]
                )
                subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
                exprGs.append(subscripts)
                piece1P.append(letters[i + 13])
                piece2P.append(letters[i + 13])
                piece3P = piece3P + letters[i + 13]
                piece4P = piece4P + letters[i + 13]
            else:
                # use triangular matrix as preconditioner for this dim
                Q.append(scale * torch.eye(shape[i], dtype=dtype, device=t.device))
                piece1A.append(letters[i] + letters[i + 13])
                piece2A = piece2A + letters[i + 13]
                piece3A = piece3A + letters[i]
                # two distinct aliases for dim i -> the result is a matrix
                piece1 = "".join(
                    [(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]
                )
                piece2 = "".join(
                    [(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))]
                )
                subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
                exprGs.append(subscripts)
                a, b, c = (letters[i], letters[i + 13], letters[i + 26])
                piece1P.append(a + b)
                piece2P.append(a + c)
                piece3P = piece3P + c
                piece4P = piece4P + b
        exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
        exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
    exprGs = tuple(exprGs)
    return [Q, (exprA, exprGs, exprP)]
@torch.compile
def _balance_Q(Q_in):
norms = torch.stack([q.norm(float("inf")) for q in Q_in])
geometric_mean = norms.log().mean().exp()
norms = geometric_mean / norms
torch._foreach_mul_(Q_in, list(norms))
def _lb(A: Tensor, max_abs: Tensor):
"""Cheap lower bound for the spectral norm of A."""
A /= max_abs
a0 = torch.einsum("ij,ij->j", A, A)
i = torch.argmax(a0)
x = torch.index_select(A, 1, i).flatten().contiguous()
x = torch.einsum("i,ij->j", x, A)
x /= x.norm()
x = torch.einsum("j,kj->k", x, A)
x = x.norm()
x *= max_abs
return x
def _solve_triangular_right(X: Tensor, A: Tensor):
"""X @ inv(A)"""
orig_dtype = A.dtype
# roughly same complexity as a matmul
return (
torch.linalg.solve_triangular(
A.float(),
X.reshape(-1, X.size(-1)).float(),
upper=True,
left=False,
unitriangular=False,
)
.to(dtype=orig_dtype)
.reshape_as(X)
)
def _calc_A_and_conjB(exprA, G, Q):
    """Compute A = (all Q factors applied to G) and the conjugate term conjB.

    NOTE: perturbs G in place with tiny Gaussian noise (scaled by G's mean
    absolute value) for numerical stability. conjB is built by solving each
    triangular factor (or dividing by each diagonal factor) along successive
    dims of a fresh Gaussian tensor V.
    """
    order = G.dim()
    V = torch.randn_like(G)
    eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=G.dtype, device=G.device)
    G += eps.sqrt() * G.abs().mean() * V
    # cycle the first dim to the back; each loop iteration consumes one factor
    conjB = V.permute(*range(1, order), 0)
    for i, q in enumerate(Q):
        if q.dim() < 2:
            conjB = conjB / q
        else:
            conjB = _solve_triangular_right(conjB, q)
        if i < order - 1:
            conjB = torch.transpose(conjB, i, order - 1)
    A = torch.einsum(exprA, *Q, G)
    return A, conjB
@torch.compile
def _update_precond(Q, exprs, G, step, tiny):
    """Update Kronecker product preconditioner Q with pair (V, G).
    Thanks to @ClashLuke heavyball repo for many of the optimizations in this function.

    Mutates the factors in Q in place. `step` is the preconditioner learning
    rate and `tiny` a divide-by-zero guard (smallest positive normal).
    """
    exprA, exprGs, _ = exprs
    A, conjB = _calc_A_and_conjB(exprA, G, Q)
    for q, exprG in zip(Q, exprGs):
        # contract everything except this factor's dim(s)
        term1 = torch.einsum(exprG, A, A)
        term2 = torch.einsum(exprG, conjB, conjB)
        # reuse buffers: term1 <- (AA - BB) (the update direction),
        #                term2 <- (AA + BB) (the normalizer)
        term1, term2 = term1 - term2, term1 + term2
        term1 *= step
        norm = term2.norm(float("inf"))
        if q.dim() < 2:
            # diagonal factor: simple elementwise relative update
            term1 *= q / norm.clamp_(min=tiny)
        else:
            # triangular factor: keep the upper triangle and normalize by a
            # cheap spectral-norm lower bound of the normalizer
            torch.triu(term1, out=term1)
            term1 /= torch.where(norm > 0, _lb(term2, norm), norm).clamp_(tiny)
            term1 = torch.mm(term1, q)
        q.sub_(term1)
@torch.compile
def _precond_grad(Q, exprs, G):
"""Precondition gradient G with preconditioner Q."""
return torch.einsum(exprs[-1], *Q, *Q, G)
@torch.compile
def _clip_update_rms(g):
g.mul_(
torch.minimum(
torch.tensor(1.0, dtype=g.dtype, device=g.device),
1.1 / g.square().mean().sqrt().add(1e-12),
)
)
@torch.compile
def _update_params(params_world, updates, weight_decay, lr):
if weight_decay > 0:
torch._foreach_add_(updates, params_world, alpha=weight_decay)
torch._foreach_add_(params_world, updates, alpha=-lr)
class BlockPartitioner:
    """Splits a tensor into blocks along its non-diagonal dims, and merges back.
    Modified from distributed_shampoo.
    https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
    """
    def __init__(self, param_shape, block_size, dim_diag):
        assert len(dim_diag) == len(param_shape), "dim_diag must have same length as param_shape"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._shape = param_shape
        self._split_indices = []
        self._split_dims = []
        for axis, (extent, is_diag) in enumerate(zip(param_shape, dim_diag)):
            # only split dims that use a triangular preconditioner and exceed block_size
            if is_diag or not (0 < block_size < extent):
                continue
            nsplit = (extent - 1) // block_size
            if nsplit > 0:
                self._split_indices.append([(j + 1) * block_size for j in range(nsplit)])
                self._split_dims.append(axis)
        if self._split_indices:
            self._total_blocks = np.prod([len(ix) + 1 for ix in self._split_indices])
        else:
            self._total_blocks = 1

    def partition(self, tensor):
        """Split `tensor` into blocks; the inverse of merge_partitions."""
        assert tensor.shape == self._shape
        blocks = [tensor]
        for dim, indices in zip(self._split_dims, self._split_indices):
            blocks = [piece
                      for blk in blocks
                      for piece in torch.tensor_split(blk, indices, dim=dim)]
        return blocks

    def merge_partitions(self, partitions):
        """Concatenate partitioned blocks back into one tensor."""
        blocks = list(partitions)
        # undo the splits in reverse order
        for dim, indices in zip(reversed(self._split_dims), reversed(self._split_indices)):
            group_len = len(indices) + 1
            blocks = [torch.cat(blocks[i:i + group_len], dim=dim)
                      for i in range(0, len(blocks), group_len)]
        return blocks[0]
########################################
# Setup #
########################################
# torchrun sets these env variables
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
# NCCL backend with an explicit device binding for this rank
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
# this code can be run equivalently with 1, 2, 4, or 8 gpus.
assert 8 % dist.get_world_size() == 0
# logging setup: only rank 0 creates and writes the logfile
if dist.get_rank() == 0:
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{uuid.uuid4()}.txt"
    print(logfile)
def print0(s, console=False, log=True):
    """Rank-0-only print: optionally echo to stdout and append to the run logfile."""
    if dist.get_rank() != 0:
        return
    if console:
        print(s)
    if log:
        with open(logfile, "a") as f:
            print(s, file=f)
# we begin by logging this file itself
print0(code)
print0("="*100)
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
       + f" on {torch.cuda.get_device_name(device)} with world_size {dist.get_world_size()}")
print0("="*100)
# batch sizes are measured in tokens (global, across all ranks);
# mbs is the per-forward microbatch size in sequences
val_tokens = 20 * 524288
batch_size = 8 * 64 * 1024
mbs = 64
# fixed validation set, loaded once up front
val_inputs, val_targets = next(distributed_data_generator("data/fineweb10B/fineweb_val_*.bin", val_tokens))
model = GPT(vocab_size=50304, num_layers=12, model_dim=768).cuda()
model.compile(dynamic=False)
# optional CLI arg: number of independent training trials to run back-to-back
num_trials = int(sys.argv[-1]) if len(sys.argv) > 1 else 1
for _ in range(num_trials):
    ########################################
    #       Init & Optim Hyperparams       #
    ########################################
    # we want to minimize this while still reaching 3.28 val loss
    train_steps = 5750
    # re-initialize model parameters each trial so trials are independent
    for name, p in model.named_parameters():
        if name.endswith("weight"):
            if "proj" in name:
                p.data.zero_()  # zero-init all output projections
            elif "embed" in name:
                p.data.normal_() # default torch init
            else:
                p.data.normal_(std=0.33**0.5 / p.size(-1)**0.5) # default torch init
        elif name.endswith("bias"):
            p.data.zero_()
        elif name.endswith("gains"):
            p.data.normal_(mean=1, std=0)  # std=0 => constant 1
        else:
            raise Exception(f"Uninitialized parameter: {name}")
    # create the optimizer(s): AdamW for embedding/head/1-D params,
    # PSGD Kron for the transformer block matrices
    optimizer1 = AdamW([dict(params=[model.embed.weight], lr=0.3),
                        dict(params=[model.proj.weight], lr=1/320),
                        dict(params=[p for p in model.parameters() if p.ndim < 2], lr=0.01)],
                       betas=(0.8, 0.95), eps=1e-10, weight_decay=0, fused=True)
    optimizer2 = PSGDKron([p for p in model.blocks.parameters() if p.ndim >= 2],
                          lr=0.0005, weight_decay=0.625, b1=0.9,
                          rank=dist.get_rank(), world_size=dist.get_world_size(),
                          memory_save_mode="one_diag", precond_lr=0.1,
                          dtype=torch.float32)
    optimizers = [optimizer1, optimizer2]
    # every parameter must be covered by exactly one optimizer
    assert set(p for opt in optimizers for group in opt.param_groups
               for p in group["params"]) == set(model.parameters())
    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
            group["warmup_steps"] = 250 if opt is optimizer2 else 0  # warm up PSGD only
    # learning rate schedule: stable then decay
    def set_hparams(step, cooldown_frac=0.7):
        progress = step / train_steps
        assert 0 <= progress < 1
        if progress < 1 - cooldown_frac:
            eta = 1.0
        else:
            eta = (1 - progress) / cooldown_frac  # linear decay to 0
        for opt in optimizers:
            for group in opt.param_groups:
                warmup_steps = group.get("warmup_steps", 0)
                warmup = min(1.0, (step + 1) / warmup_steps) if warmup_steps > 0 else 1.0
                group["lr"] = group["initial_lr"] * warmup * eta
    ########################################
    #      Training and Validation         #
    ########################################
    train_loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size)
    # make sure all ranks start from identical weights
    for p in model.parameters():
        dist.broadcast(p.detach(), 0)
    # start the clock
    training_time = 0
    last_val_step = 0
    dist.barrier()
    t0 = time.perf_counter()
    for step in range(train_steps + 1):
        # --------------- VALIDATION SECTION -----------------
        if step == train_steps or step % 125 == 0:
            # stop the clock (validation time is excluded from training_time)
            dist.barrier()
            time_since_last_val = time.perf_counter() - t0
            step_avg = time_since_last_val / (step - last_val_step) if step > 0 else float("nan")
            last_val_step = step
            training_time += time_since_last_val
            model.eval()
            val_loss = 0
            with torch.no_grad():
                assert len(val_inputs) % mbs == 0
                for i in range(len(val_inputs) // mbs):
                    val_loss += model(val_inputs[i*mbs:(i+1)*mbs], val_targets[i*mbs:(i+1)*mbs])
            # model returns summed CE; sum across ranks, then normalize by token count
            dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
            val_loss /= val_tokens
            print0(f"step:{step}/{train_steps} val_loss:{val_loss:.5f} train_time:{training_time:.3f}s"
                   + f" step_avg:{1000*step_avg:.2f}ms", console=True)
            model.train()
            # start the clock again
            dist.barrier()
            t0 = time.perf_counter()
        if step == train_steps:
            break
        # --------------- TRAINING SECTION -----------------
        inputs, targets = next(train_loader)
        # accumulate across microbatches in case we are running with fewer than 8 gpus
        assert len(inputs) % mbs == 0
        for i in range(len(inputs) // mbs):
            model(inputs[i*mbs:(i+1)*mbs], targets[i*mbs:(i+1)*mbs]).backward()
        # average nothing, just sum gradients across ranks (loss is a sum too)
        for name, p in model.named_parameters():
            assert p.grad is not None, name
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        # set optimization hyperparameters and take a step
        set_hparams(step)
        for opt in optimizers:
            opt.step()
        model.zero_grad(set_to_none=True)
        approx_training_time = training_time + (time.perf_counter() - t0)
        print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time:.3f}s"
               + f" step_avg:{1000*approx_training_time/(step + 1):.2f}ms", console=True, log=False)
dist.destroy_process_group()

Xet Storage Details

Size:
36 kB
·
Xet hash:
ac179f3952714112fa5fc2b77a3b46b02982b70abf03adf78440578d96861140

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.