|
|
|
|
|
|
| import os
|
| import sys
|
# Keep this script's own source text: it is written to the run log and stored
# inside every checkpoint.
with open(sys.argv[0]) as script_file:
    code = script_file.read()
|
| import uuid
|
| import time
|
| import copy
|
| from dataclasses import dataclass
|
| from functools import lru_cache
|
| from pathlib import Path
|
| import numpy as np
|
| import wandb
|
|
|
# Allocator configuration is placed before `import torch` and before the first CUDA
# allocation below, so it is in effect for the whole process.
# NOTE(review): older PyTorch releases only read PYTORCH_CUDA_ALLOC_CONF — confirm the
# installed version honors PYTORCH_ALLOC_CONF.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import torch

# NOTE(review): runs a trivial backward pass on a CUDA tensor immediately after import;
# presumably to initialize CUDA/autograd state early — the intent is not documented here, confirm.
torch.empty(1, device="cuda", requires_grad=True).backward()

from torch import Tensor, nn

import torch.nn.functional as F

import torch.distributed as dist

from torch.nn.attention.flex_attention import BlockMask, flex_attention

# Enable inductor's coordinate-descent kernel tuning for everything compiled below
# (longer compile, potentially faster kernels).
torch._inductor.config.coordinate_descent_tuning = True
|
|
|
|
|
|
|
|
|
def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    """Approximately orthogonalize ``G`` (push its singular values toward 1).

    Five quintic Newton-Schulz iterations are run in bfloat16 over the last two
    dimensions (leading dimensions act as a batch), each with its own (a, b, c)
    coefficients. Tall inputs are processed in transposed form so the Gram matrix
    ``X @ X.mT`` is formed on the smaller side, and transposed back at the end.
    """
    assert G.ndim >= 2
    is_tall = G.size(-2) > G.size(-1)

    X = G.bfloat16()
    if is_tall:
        X = X.mT

    # Dividing by the Frobenius norm bounds the spectral norm by ~1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    schedule = (
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    )
    for a, b, c in schedule:
        gram = X @ X.mT
        # Same evaluation order as (c * gram) @ gram — kept for bit-identical bf16 results.
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X

    return X.mT if is_tall else X
|
|
|
@torch.compile
def update(acc_bf16_view_u16: Tensor, mantissa: Tensor, momentum_buffer: Tensor, grad: Tensor, momentum: Tensor, eff_lr: Tensor, eff_weight_decay: Tensor):
    """One Muon step on a bfloat16 parameter kept at fp32 precision via a split encoding.

    ``acc_bf16_view_u16`` (the parameter's bf16 bits viewed as uint16) holds the upper
    16 bits of an fp32 master value and ``mantissa`` holds the lower 16 bits. The
    reassembled fp32 value is multiplied by ``1 - eff_weight_decay`` and then moved by
    ``-eff_lr`` times the orthogonalized momentum direction. Both halves and
    ``momentum_buffer`` are updated in place; nothing is returned.
    """
    assert acc_bf16_view_u16.dtype == mantissa.dtype == torch.uint16
    g32 = grad.float()

    # Exponential moving average of gradients, written back into the persistent buffer.
    momentum_buffer.copy_(momentum * momentum_buffer + (1 - momentum) * g32)
    # Look-ahead blend (Nesterov-style) of the fresh gradient with the updated buffer.
    lookahead = momentum * momentum_buffer + (1 - momentum) * g32
    direction = zeropower_via_newtonschulz5(lookahead)

    # Rebuild the fp32 master weight: high half from the bf16 bits, low half from `mantissa`.
    packed = (acc_bf16_view_u16.to(torch.uint32) << 16) | mantissa.to(torch.uint32)
    master = packed.view(torch.float32)  # aliases `packed`, so edits show up in its bits
    master.mul_(1 - eff_weight_decay)
    master.add_(other=direction, alpha=-eff_lr)

    # Split the updated fp32 bits back into the two uint16 halves.
    acc_bf16_view_u16.copy_((packed >> 16).to(torch.uint16))
    mantissa.copy_(packed.to(torch.uint16))
|
|
|
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz

    Distributed variant. Within each param group, parameters are dealt out round-robin:
    rank ``r`` updates indices ``r, r + world_size, r + 2 * world_size, ...``. After each
    chunk of ``world_size`` parameters, the freshly updated tensors are shared with an
    asynchronous ``all_gather``. All parameters must be bfloat16.
    """

    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, rank=0, world_size=1):
        self.rank = rank
        self.world_size = world_size
        super().__init__(params, dict(lr=lr, weight_decay=weight_decay, momentum=momentum))
        assert all(p.dtype == torch.bfloat16 for group in self.param_groups for p in group["params"])

    @torch.no_grad()
    def step(self):
        # NOTE(review): state (momentum buffer + low mantissa bits) is only created on the
        # rank that owns a parameter, so state_dict() on any single rank holds only that
        # rank's shard — confirm checkpointing/resume accounts for this.
        pending: list[torch.Future] = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            # Scratch tensors so the last chunk always offers world_size gather slots.
            # NOTE(review): all_gather needs every tensor in a chunk to share one shape —
            # confirm the group's ordering / world_size guarantees this.
            padded = params + [torch.empty_like(params[-1])] * self.world_size
            momentum = torch._as_tensor_fullprec(group["momentum"])
            for start in range(0, len(params), self.world_size):
                owned = start + self.rank
                if owned < len(params):
                    p = params[owned]
                    state = self.state[p]
                    if not state:
                        state["mantissa"] = torch.zeros_like(p, dtype=torch.uint16)
                        state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32)
                    # Matrices with more rows than columns get a sqrt(rows / cols) larger step.
                    lr_scale = max(1, p.size(-2) / p.size(-1)) ** 0.5
                    wd_scale = getattr(p, "wd_mul", 1.0)
                    update(
                        p.view(torch.uint16), state["mantissa"], state["momentum_buffer"],
                        p.grad, momentum,
                        eff_lr=torch._as_tensor_fullprec(group["lr"] * lr_scale),
                        eff_weight_decay=torch._as_tensor_fullprec(group["lr"] * group["weight_decay"] * wd_scale),
                    )
                work = dist.all_gather(padded[start:start + self.world_size], padded[owned], async_op=True)
                pending.append(work.get_future())
        torch.futures.collect_all(pending).wait()
|
|
|
|
|
|
|
|
|
| def norm(x: Tensor):
|
| return F.rms_norm(x, (x.size(-1),))
|
|
|
@torch.no_grad()
def init_linear(w: Tensor):
    """Fill ``w`` in place from a uniform distribution with std ``0.5 / sqrt(fan_in)``.

    ``fan_in`` is the size of the last dimension. Returns ``w`` itself.
    """
    fan_in = w.size(-1)
    # A uniform on (-b, b) has std b / sqrt(3), so b = sqrt(3) * target_std.
    half_width = 3 ** 0.5 * (0.5 * fan_in ** -0.5)
    return w.uniform_(-half_width, half_width)
|
|
|
| class Rotary(nn.Module):
|
| def __init__(self, dim: int, max_seq_len: int):
|
| super().__init__()
|
| angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
|
| angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
|
| t = torch.arange(max_seq_len, dtype=torch.float32)
|
| theta = torch.einsum("i,j -> ij", t, angular_freq)
|
| self.cos = nn.Buffer(theta.cos(), persistent=False)
|
| self.sin = nn.Buffer(theta.sin(), persistent=False)
|
|
|
| def forward(self, x_BTHD: Tensor):
|
| assert self.cos.size(0) >= x_BTHD.size(-3)
|
| cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
|
| x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
|
| y1 = x1 * cos + x2 * sin
|
| y2 = x1 * (-sin) + x2 * cos
|
| return torch.cat((y1, y2), 3).type_as(x_BTHD)
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head attention computed with FlexAttention under a caller-supplied BlockMask.

    q, k and v come from one stacked projection (``qkvo_w[:3]``). q and k are RMS-normed
    and rotary-encoded. v is RMS-normed, scaled by ``lambdas[0]`` and, when a value
    embedding ``ve`` is given, mixed with ``lambdas[1] * ve``. Only batch size 1 is supported.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        # Stacked q, k, v, o weights, each (inner_dim, dim); the output slice starts at zero.
        stacked = init_linear(torch.empty(4, inner_dim, dim)).bfloat16()
        self.qkvo_w = nn.Parameter(stacked)
        self.qkvo_w.detach()[3].zero_()
        self.rotary = Rotary(head_dim, max_seq_len)
        self.attn_scale = 0.12

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask, lambdas: Tensor):
        B, T = x.size(0), x.size(1)
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        fused = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1))
        q, k, v = fused.view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q = self.rotary(norm(q))
        k = self.rotary(norm(k))
        v = lambdas[0] * norm(v)
        if ve is not None:
            v = v + lambdas[1] * ve.view_as(v)
        # flex_attention expects (B, H, T, D); transpose in and back out.
        attended = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
            scale=self.attn_scale,
        ).transpose(1, 2)
        merged = attended.contiguous().view(B, T, self.num_heads * self.head_dim)
        return F.linear(merged, self.qkvo_w[3])
|
|
|
class MLP(nn.Module):
    """Feed-forward block: 4x expansion, squared-ReLU activation, zero-initialized projection."""

    def __init__(self, dim: int):
        super().__init__()
        hidden = 4 * dim
        self.fc_w = nn.Parameter(init_linear(torch.empty(hidden, dim)).bfloat16())
        self.proj_w = nn.Parameter(torch.zeros(dim, hidden).bfloat16())
        # Weight-decay multiplier; Muon reads it via getattr(p, "wd_mul", 1.0).
        for weight in (self.fc_w, self.proj_w):
            weight.wd_mul = 2.0

    def forward(self, x: Tensor):
        activated = F.relu(F.linear(x, self.fc_w)).square()
        return F.linear(activated, self.proj_w)
|
|
|
class Block(nn.Module):
    """One transformer layer: residual mixing with both embedding streams, then attention, then MLP."""

    def __init__(self, dim: int, num_heads: int, max_seq_len: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len)
        self.mlp = MLP(dim)

    def forward(self, x: Tensor, ve: Tensor | None, x00: Tensor, x01: Tensor, block_mask: BlockMask, lambdas: Tensor, sa_lambdas: Tensor):
        # Learned mix of the running residual with the two normalized embedding streams.
        h = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
        # NOTE(review): attention gets h without norm() (the MLP below gets norm(h)). q, k and v
        # are each RMS-normed inside the attention, so this looks intentional — confirm.
        h = h + self.attn(h, ve, block_mask, sa_lambdas)
        return h + self.mlp(norm(h))
|
|
|
|
|
|
|
|
|
| def next_multiple_of_n(v: float | int, *, n: int):
|
| return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
|
|
|
class GPT(nn.Module):
    """Decoder-only transformer language model trained by this script.

    - Two token embeddings (``embed1`` / ``embed2``) are RMS-normed. They are fed to
      every block as extra residual inputs (``x00`` / ``x01``).
    - Five value-embedding tables feed the attention of the first five and last five layers.
    - A few layers add back the outputs of earlier layers (``skip_map`` in ``forward``).

    ``forward`` returns the mean cross-entropy loss on soft-capped logits.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, eos_token_id: int = 3):
        super().__init__()
        # Token id that starts a new document when building the attention mask.
        self.eos_token_id = eos_token_id
        self.embed1 = nn.Embedding(vocab_size, model_dim)
        self.embed2 = nn.Embedding(vocab_size, model_dim)

        # NOTE(review): attention reshapes these with ve.view_as(v), which assumes
        # num_heads * head_dim (128 by default) == model_dim — confirm for new configs.
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(5)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len) for _ in range(num_layers)])
        # Output projection: rows padded up to a multiple of 128, zero-initialized.
        self.lm_head_w = nn.Parameter(torch.zeros(next_multiple_of_n(vocab_size, n=128), model_dim))
        assert num_layers % 2 == 0
        # All learned per-layer scalars in one flat parameter (L = num_layers), sliced in forward():
        #   [0, L)    skip-connection weights, init 1
        #   [L, 4L)   (x, x00, x01) residual-mix lambdas per layer, init (1, 0, 0)
        #   [4L, 6L)  (v, ve) attention lambdas per layer, init (0.5, 0.5)
        self.scalars = nn.Parameter(torch.cat([
            torch.ones(num_layers),
            *[torch.tensor([1.0, 0.0, 0.0]) for _ in range(num_layers)],
            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)],
        ]))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        """Build ``(long, short)`` document-causal sliding-window BlockMasks for ``input_seq``.

        A token may attend to earlier-or-equal positions of its own document. A new
        document starts at every ``eos_token_id``; the EOS token itself belongs to the
        following document. Masks work on 128-token blocks. The long mask allows up to
        ``sliding_window_num_blocks`` KV blocks per query block; the short mask allows half that.
        """
        BLOCK_SIZE = 128
        # Document index per position (cumulative count of EOS tokens; non-decreasing).
        docs = (input_seq == self.eos_token_id).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            # Token-level rule, evaluated only inside partially-visible blocks.
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            # Dense (q_block, kv_block) bool matrix -> per-row True counts plus KV indices.
            # The stable ascending argsort followed by flip(-1) puts True entries first,
            # highest block index first. Leading batch/head dims of size 1 are added.
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        # NOTE(review): block indices are always allocated on "cuda".
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        # "_any": at least one (q, kv) token pair in the block pair may be visible.
        # "_all": every token pair in the block pair is visible.
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        # First and last document index inside each block.
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        # Document ranges of the two blocks overlap / both blocks lie in one identical document.
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        # Partial blocks get mask_mod applied token by token; full blocks are attended whole.
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            # Clamping the counts keeps only the highest-indexed blocks, which are the nearest
            # ones under causality: full blocks are capped at window - 1, and partial blocks at
            # max(window - full, 1).
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        """Return the mean cross-entropy of next-token predictions against ``target_seq``.

        ``input_seq`` is a 1-D token sequence, processed as a batch of one. Its length must
        be a multiple of 128 (block masks). In eval mode the loss is computed over 4 equal
        chunks of the sequence and averaged.
        """
        assert input_seq.ndim == 1
        L = len(self.blocks)

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]

        # The first five and last five layers reuse the same five value embeddings, in order;
        # middle layers get none. Requires L >= 10 for the assert below to hold.
        ve_layers = [ve[0], ve[1], ve[2], ve[3], ve[4]] + [None] * (L - 10) + [ve[0], ve[1], ve[2], ve[3], ve[4]]
        assert len(ve_layers) == L

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)

        # Layers 0, 4, 8, ... use the long window; all others use the short one.
        block_masks = [long_bm if i % 4 == 0 else short_bm for i in range(L)]

        # Two normalized embedding streams with a leading batch dim of 1; both are
        # re-mixed into the residual by every block.
        x = x00 = norm(self.embed1(input_seq)[None])
        x01 = norm(self.embed2(input_seq)[None])

        # skip_connections[j] is the output of layer j. Before running layer i in skip_map,
        # add skip_weights[skip_map[i]] * (output of layer skip_map[i]).
        skip_connections = []
        skip_map = {
            15: 8,
            17: 6,
            19: 4,
        }
        skip_weights = self.scalars[:L]
        lambdas = self.scalars[1 * L: 4 * L].view(-1, 3)
        sa_lambdas = self.scalars[4 * L: 6 * L].view(-1, 2)

        for i in range(L):
            if i in skip_map:
                x = x + skip_weights[skip_map[i]] * skip_connections[skip_map[i]]
            x = self.blocks[i](x, ve_layers[i], x00, x01, block_masks[i], lambdas[i], sa_lambdas[i])
            skip_connections.append(x)

        x = norm(x)
        if self.training:
            logits: Tensor = F.linear(x.flatten(end_dim=1), self.lm_head_w.bfloat16()).float()
            # Soft-cap: 15 * z / sqrt(z^2 + 15^2) keeps every logit inside (-15, 15).
            loss = F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq)
            return loss

        # Eval: the same soft-capped loss over 4 equal chunks (presumably to bound peak logits
        # memory). The mean of equal-size chunk means equals the overall mean.
        loss = 0
        for i in range(4):
            logits: Tensor = F.linear(x.flatten(end_dim=1).chunk(4)[i], self.lm_head_w.bfloat16()).float()
            loss += F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq.chunk(4)[i]) / 4
        return loss
|
|
|
|
|
|
|
|
|
| def _load_data_shard(file: Path):
|
| header = torch.from_file(str(file), False, 256, dtype=torch.int32)
|
| assert header[0] == 20240520, "magic number mismatch in the data .bin file"
|
| assert header[1] == 1, "unsupported version"
|
| num_tokens = int(header[2])
|
| with file.open("rb", buffering=0) as f:
|
| tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
|
| f.seek(256 * 4)
|
| nbytes = f.readinto(tokens.numpy())
|
| assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
|
| return tokens
|
|
|
def distributed_data_generator(filename_pattern: str, batch_size: int, rank: int, world_size: int):
    """Yield ``(inputs, targets)`` CUDA batches for this rank, cycling over epochs forever.

    Shards matching ``filename_pattern`` (relative to the current working directory) are
    visited in a new order every epoch, shuffled with seed ``42 + epoch`` so every rank
    agrees. Each global batch of ``batch_size`` tokens is split evenly across ranks.
    Rank ``r`` receives its ``batch_size // world_size`` tokens as int32 inputs, and the
    same window shifted by one token as int64 targets.

    Errors surface on the first ``next()`` call:
        AssertionError: ``batch_size`` is not divisible by ``world_size``.
        FileNotFoundError: no file matches ``filename_pattern``.
        ValueError: a full epoch produced no batch (every shard is too short).
    """
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    if not files:
        # Without this check the epoch loop below would spin forever without yielding.
        raise FileNotFoundError(f"no data files match {filename_pattern!r} under {Path.cwd()}")

    epoch = 0
    while True:
        rng = np.random.default_rng(seed=42 + epoch)
        shuffled_files = rng.permutation(files).tolist()
        produced_any = False

        for file in shuffled_files:
            tokens = _load_data_shard(file)
            pos = 0
            # Targets are shifted by one, so a batch needs batch_size + 1 tokens from pos.
            while pos + batch_size + 1 < len(tokens):
                buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
                inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
                targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
                pos += batch_size
                produced_any = True
                yield inputs, targets

        if not produced_any:
            # Every shard is too short for one batch; looping again would never yield.
            raise ValueError(
                f"no shard matching {filename_pattern!r} holds more than {batch_size + 1} tokens"
            )

        epoch += 1
        if rank == 0:
            print(f"Completed epoch {epoch}, shuffling for next epoch...")
|
|
|
|
|
|
|
|
|
| @dataclass
|
| class Hyperparameters:
|
|
|
| train_files = "data/gd_levels/train_*.bin"
|
| val_files = "data/gd_levels/val_*.bin"
|
| val_tokens = 10420224
|
|
|
|
|
| train_seq_len = 16 * 1024
|
| val_seq_len = 16 * 1024
|
|
|
|
|
| num_iterations = 109063
|
| cooldown_frac = 0.7
|
|
|
|
|
| vocab_size = 32000
|
| num_layers = 24
|
| num_heads = 10
|
| model_dim = 1280
|
| eos_token_id = 3
|
|
|
|
|
| val_loss_every = 5000
|
| wandb_log_every = 1
|
| save_every = 10000
|
| save_checkpoint = True
|
| resume_from = None
|
|
|
args = Hyperparameters()

# A non-empty RESUME_FROM environment variable overrides the checkpoint to resume from.
resume_override = os.environ.get("RESUME_FROM")
if resume_override:
    args.resume_from = resume_override
|
|
|
|
|
|
|
|
|
# Process-group setup from the RANK / WORLD_SIZE / LOCAL_RANK environment variables;
# each process binds to the CUDA device LOCAL_RANK.
run_id = int(os.environ.get("RUN_ID", 0))
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert torch.cuda.is_available()
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = rank == 0

# Only rank 0 creates the log file and the wandb run (so run_id_full / logfile exist only there).
if master_process:
    run_id_full = f"{run_id:03d}_{uuid.uuid4()}"
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id_full}.txt"
    print(logfile)

    wandb_config = {
        "vocab_size": args.vocab_size,
        "num_layers": args.num_layers,
        "model_dim": args.model_dim,
        "num_heads": args.num_heads,
        "train_seq_len": args.train_seq_len,
        "num_iterations": args.num_iterations,
        "cooldown_frac": args.cooldown_frac,
    }
    wandb.init(project="gd-level-generation", name=run_id_full, config=wandb_config)
|
|
|
def print0(s, console=False):
    """Append ``s`` to the run log file (and echo it to stdout if ``console``); rank 0 only."""
    if not master_process:
        return
    with open(logfile, "a") as log_fh:
        if console:
            print(s)
        print(s, file=log_fh)


# Head the log with the exact source and the software environment.
print0(code)
print0("=" * 100)
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")


def nvidia_smi():
    """Return the stdout of ``nvidia-smi``; stderr is captured and discarded."""
    import subprocess
    completed = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return completed.stdout


print0(nvidia_smi())
print0("=" * 100)
|
|
|
|
|
|
|
|
|
model: nn.Module = GPT(
    vocab_size=args.vocab_size,
    num_layers=args.num_layers,
    num_heads=args.num_heads,
    model_dim=args.model_dim,
    max_seq_len=max(args.train_seq_len, args.val_seq_len),
    eos_token_id=args.eos_token_id,
).cuda()

# Embedding tables are kept in bfloat16; then every rank adopts rank 0's initial weights.
for module in model.modules():
    if isinstance(module, nn.Embedding):
        module.bfloat16()
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

if master_process:
    total_params = sum(p.numel() for p in model.parameters())
    print0(f"Total parameters: {total_params:,} ({total_params/1e6:.1f}M)", console=True)

# Parameter partition:
# - Muon gets the >=2-D matrices inside the transformer blocks, sorted by shape tuple
#   in descending order.
# - AdamW gets the embeddings, the learned scalars and the output head.
hidden_matrix_params = sorted(
    (p for p in model.blocks.parameters() if p.ndim >= 2),
    key=lambda t: t.size(),
    reverse=True,
)
embed_params = [*model.embed1.parameters(), *model.embed2.parameters(), *model.value_embeds.parameters()]
scalar_params = [model.scalars]
head_params: list[nn.Parameter] = [model.lm_head_w]

# Every model parameter must belong to exactly one collection.
params_collections = [hidden_matrix_params, embed_params, scalar_params, head_params]
optimized_parameters_set = {p for params in params_collections for p in params}
assert optimized_parameters_set == {*model.parameters()}
assert len(optimized_parameters_set) == sum(len(lst) for lst in params_collections)

adam_param_groups = [
    dict(params=head_params, lr=1/320),
    dict(params=embed_params, lr=0.3),
    dict(params=scalar_params, lr=0.015),
]
optimizer1 = torch.optim.AdamW(adam_param_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.025, momentum=0.95, rank=rank, world_size=world_size)
optimizers: list[torch.optim.Optimizer] = [optimizer1, optimizer2]
|
|
|
def opt_params(opt: torch.optim.Optimizer) -> list[nn.Parameter]:
    """Flatten all param groups of ``opt`` into one list, preserving group order."""
    collected: list[nn.Parameter] = []
    for group in opt.param_groups:
        collected.extend(group["params"])
    return collected


opt2params = {opt: opt_params(opt) for opt in optimizers}
# Remember each group's base learning rate; the training loop rescales from it every step.
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]


start_step = 0
if args.resume_from:
    print0(f"Resuming from checkpoint: {args.resume_from}", console=True)
    checkpoint = torch.load(args.resume_from, map_location=device)

    # Checkpoints saved from the torch.compile wrapper prefix every key with "_orig_mod.".
    model_state = checkpoint["model"]
    compiled_prefix = "_orig_mod."
    if any(key.startswith(compiled_prefix) for key in model_state.keys()):
        model_state = {key.replace(compiled_prefix, ""): value for key, value in model_state.items()}
    model.load_state_dict(model_state)

    # NOTE(review): the saved Muon state comes from the single rank that wrote the checkpoint.
    # Muon only creates state on the rank owning each parameter, so other ranks' momentum and
    # mantissa shards appear to restart from zero after resuming — confirm.
    for opt, opt_state in zip(optimizers, checkpoint["optimizers"]):
        opt.load_state_dict(opt_state)
    saved_step = checkpoint["step"]
    start_step = saved_step + 1
    print0(f"Resumed from step {saved_step}, continuing from step {start_step}", console=True)
    del checkpoint
|
|
|
|
|
def get_lr(step: int):
    """LR multiplier: 1.0 for the first ``1 - cooldown_frac`` of training, then linear decay toward 0."""
    progress = step / args.num_iterations
    assert 0 <= progress < 1
    cooldown_start = 1 - args.cooldown_frac
    if progress >= cooldown_start:
        return (1 - progress) / args.cooldown_frac
    return 1.0
|
|
|
|
|
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    """Return ``window_size // 128`` as an int32 CUDA scalar.

    The cache holds one entry, so consecutive calls with an unchanged window reuse the
    same tensor.
    """
    num_blocks = window_size // 128
    host_scalar = torch.tensor(num_blocks, dtype=torch.int32, pin_memory=True)
    return host_scalar.cuda(non_blocking=True)


def get_window_size_blocks(step: int):
    """Attention sliding-window size, in 128-token blocks, for ``step``.

    With x = step / num_iterations, the window is 3456 * (4x^3 - 6x^2 + 3x) tokens,
    rounded up to a multiple of 128. This runs from 128 tokens at x = 0 to 3456 at x = 1.
    The cubic's slope, 3(2x - 1)^2, is steepest at both ends and flat at the midpoint.
    """
    x = step / args.num_iterations
    assert 0 <= x <= 1
    factor = 4 * x ** 3 - 6 * x ** 2 + 3 * x
    return get_window_size_blocks_helper(next_multiple_of_n(3456 * factor, n=128))
|
|
|
model: nn.Module = torch.compile(model, dynamic=False)

# Warmup: run a few full training steps on random tokens (presumably so compilation
# happens before the timed run), then restore the exact pre-warmup model and optimizer state.
warmup_steps = 10
initial_state = copy.deepcopy(dict(model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]))
for step_idx in range(1, warmup_steps + 1):
    print0(f"Warmup step {step_idx}/{warmup_steps}")
    random_tokens = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
    model(random_tokens.to(torch.int32), random_tokens, get_window_size_blocks(0)).backward()
    for param in model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del initial_state
|
|
|
|
|
|
|
|
|
torch.cuda.reset_peak_memory_stats()
train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
# The loop consumes exactly one training batch per completed step. A run resumed at
# `start_step` therefore skips that many batches so it continues on unseen data instead
# of replaying the start of the first epoch. Each skipped batch still costs one
# host-to-device copy.
for _ in range(start_step):
    next(train_loader)
training_time_ms = 0
dist.barrier()
t0 = time.perf_counter()


def _save_checkpoint(step: int):
    """Write step, source, model and optimizer state to logs/<run_id_full>/state_step<step>.pt.

    Call on the master process only; run_id_full is defined only there.
    """
    # NOTE(review): Muon creates momentum/mantissa state only on the rank that owns each
    # parameter, so optimizer2's state saved here covers this rank's shard only — confirm
    # whether resumed runs need the other ranks' state too.
    log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
    os.makedirs(f"logs/{run_id_full}", exist_ok=True)
    torch.save(log, f"logs/{run_id_full}/state_step{step:06d}.pt")


train_steps = args.num_iterations
for step in range(start_step, train_steps + 1):
    last_step = (step == train_steps)
    # Steps executed by this process so far; used for averages that stay correct after a resume.
    steps_done = step - start_step
    do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
    # Periodic checkpoints are scheduled independently of validation. The final step is
    # saved separately below, so it is excluded here to avoid writing it twice.
    do_save = (args.save_checkpoint and args.save_every > 0 and step > 0
               and step % args.save_every == 0 and not last_step)

    # ---------------- validation / checkpoint (training clock paused) ----------------
    if do_val or do_save:
        # All ranks evaluate the same condition, so these barriers always pair up.
        dist.barrier()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        if do_val:
            model.eval()
            val_batch_size = world_size * args.val_seq_len
            assert args.val_tokens % val_batch_size == 0
            val_steps = args.val_tokens // val_batch_size
            val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
            val_loss = 0
            with torch.no_grad():
                for _ in range(val_steps):
                    inputs, targets = next(val_loader)
                    val_loss += model(inputs, targets, get_window_size_blocks(step))
            val_loss /= val_steps
            del val_loader
            dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG)
            print0(f"step:{step}/{train_steps} val_loss:{val_loss:.6f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(steps_done, 1):.2f}ms", console=True)
            if master_process:
                # Log at the explicit training step. Without step=, wandb advances its own
                # counter: val metrics land one step off and the step-0 train log is dropped
                # as non-monotonic.
                wandb.log({
                    "val_loss": val_loss.item() if hasattr(val_loss, 'item') else val_loss,
                    "step": step,
                    "train_time_ms": training_time_ms,
                    "step_avg_ms": training_time_ms / max(steps_done, 1),
                    "lr_mult": get_lr(step) if step < train_steps else 0,
                }, step=step)
            model.train()
        if do_save and master_process:
            _save_checkpoint(step)
            print0(f"Saved checkpoint at step {step}", console=True)
        dist.barrier()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            _save_checkpoint(step)
        # The last step only validates (and saves); there is no training update.
        break

    # ---------------- training ----------------
    inputs, targets = next(train_loader)
    train_loss = model(inputs, targets, get_window_size_blocks(step))
    train_loss.backward()
    # One async gradient all-reduce per parameter, grouped by optimizer, so each optimizer
    # waits only on its own gradients.
    opt2futures = {
        opt: [dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True).get_future() for p in params]
        for opt, params in opt2params.items()
    }
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)
    # Muon momentum ramps linearly from 0.85 to 0.95 over the first 300 steps.
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1)
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
    for opt in optimizers:
        torch.futures.collect_all(opt2futures[opt]).wait()
        opt.step()
    model.zero_grad(set_to_none=True)

    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(steps_done + 1):.2f}ms", console=True)

    if master_process and step % args.wandb_log_every == 0:
        wandb.log({
            "train_loss": train_loss.item(),
            "step": step,
            "train_time_ms": approx_training_time_ms,
            "step_avg_ms": approx_training_time_ms / (steps_done + 1),
            "lr_mult": get_lr(step),
        }, step=step)

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
if master_process:
    # Flush and close the wandb run explicitly rather than relying on interpreter exit.
    wandb.finish()
dist.destroy_process_group()
|
|
|