ronnengmail committed on
Commit
198c544
·
verified ·
1 Parent(s): 6aac7da

Upload training_scripts/train_multilingual_3b_fsdp.py with huggingface_hub

Browse files
training_scripts/train_multilingual_3b_fsdp.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multilingual 3.14B GPT Training — FSDP Version (Arabic-Rebalanced Data)
4
+
5
+ Converted from DDP script. Key changes:
6
+ - FullyShardedDataParallel (FSDP) with FULL_SHARD strategy
7
+ - Mixed precision: bf16 compute, fp32 reduce
8
+ - transformer_auto_wrap_policy wrapping each Block
9
+ - Activation checkpointing via FSDP's apply_activation_checkpointing
10
+ - FULL_STATE_DICT for checkpoint saving
11
+ - SWA simplified (gather full state dict on rank 0)
12
+
13
+ Architecture: dim=3072, depth=26, heads=24, ~3.14B params
14
+ Data: training-data-v2 (4.48B tokens, multi-epoch)
15
+ Schedule: WSD-LINEAR
16
+ LR: 3e-4
17
+
18
+ Run with:
19
+ /opt/pytorch/bin/torchrun --nproc_per_node=8 --master_port=29500 train_multilingual_3b_fsdp.py
20
+ """
21
+
22
+ import os, sys, json, math, time, copy, functools, datetime
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torch.distributed as dist
28
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
29
+
30
+ # FSDP imports
31
+ from torch.distributed.fsdp import (
32
+ FullyShardedDataParallel as FSDP,
33
+ ShardingStrategy,
34
+ MixedPrecision,
35
+ StateDictType,
36
+ FullStateDictConfig,
37
+ BackwardPrefetch,
38
+ CPUOffload,
39
+ )
40
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
41
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
42
+
43
+ # Activation checkpointing for FSDP
44
+ try:
45
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
46
+ apply_activation_checkpointing,
47
+ checkpoint_wrapper,
48
+ CheckpointImpl,
49
+ )
50
+ HAS_FSDP_AC = True
51
+ except ImportError:
52
+ HAS_FSDP_AC = False
53
+
54
# ============ MODEL CONFIG ============
# Defaults consumed by GPT.__init__.
VOCAB_SIZE = 32000   # token ids are stored on disk as uint16 (see BinaryDataset / load_val_data)
DIM = 3072           # model width
DEPTH = 26           # number of transformer Blocks
N_HEADS = 24         # head_dim = DIM // N_HEADS = 128
MAX_SEQ_LEN = 2048   # training sequence length and size of the precomputed RoPE tables
ROPE_THETA = 10000   # base of the rotary-embedding frequency schedule
DROPOUT = 0.05       # used for attention dropout and for the residual-branch dropout

# ============ TRAINING CONFIG ============
# WSD-linear schedule (see wsd_lr_linear): linear warmup for WARMUP_STEPS, constant
# until STABLE_END, then linear decay to MIN_LR_RATIO * ADAMW_LR at TOTAL_STEPS.
TOTAL_STEPS = 20000
WARMUP_STEPS = 600
STABLE_END = 14000
MIN_LR_RATIO = 0.03

BATCH_PER_GPU = 4   # FSDP uses less memory per GPU → can increase from 2 to 4
GRAD_ACCUM = 8      # With 8 GPUs: 8*4*8 = 256 seqs = 524K tokens/step
                    # Total: 20000 * 524288 = 10.49B tokens

ADAMW_LR = 3e-4
ADAMW_BETAS = (0.9, 0.98)
ADAMW_WD = 0.02
ADAMW_EPS = 1e-8

LABEL_SMOOTHING = 0.06   # applied to the training loss only; evaluate() uses plain cross-entropy
GRAD_CLIP = 1.0          # max gradient norm passed to model.clip_grad_norm_

SWA_START_FRAC = 0.40    # SWA snapshots begin at int(TOTAL_STEPS * SWA_START_FRAC)
SWA_FREQ = 40            # ...and are taken every SWA_FREQ steps after that

EVAL_EVERY = 500
SAVE_EVERY = 500
LOG_EVERY = 50
RESUME_STEP = int(os.environ.get('RESUME_STEP', '0'))  # Set via env to resume from checkpoint

DATA_DIR = "/tmp/training-data"   # expects train.bin, val.bin and optional val_<lang>.bin
CKPT_DIR = "/tmp/checkpoints"
LOG_FILE = "/tmp/training.log"
EVAL_FILE = "/tmp/eval_results.json"

S3_BUCKET = "autoresearch-dashboard-196766918360"
# NOTE(review): prefix says "7b" although this script trains the ~3.14B model — confirm intended.
S3_PREFIX = "multilingual-7b"
VERSION = "3b-v1-fsdp"   # sub-folder / file-name tag used for every S3 upload
97
+
98
+ # ============ MODEL ============
99
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: size of the last (feature) dimension.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # The statistic is computed on a float32 copy of x; the resulting scale
        # factor is cast back to x's dtype before it is applied.
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps).type_as(x)
        return x * inv_rms * self.weight
106
+
107
class SwiGLU(nn.Module):
    """Gated feed-forward layer: down(silu(gate(x)) * up(x)), all projections bias-free.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the gated intermediate representation.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Creation order is kept (gate, up, down) so seeded initialisation is unchanged.
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        activated_gate = F.silu(self.gate(x))
        gated = activated_gate * self.up(x)
        return self.down(gated)
115
+
116
def apply_rope(x, cos, sin):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by the given angles.

    The last dimension of ``x`` is split into even- and odd-indexed halves; each
    (even, odd) pair is rotated with the matching ``cos``/``sin`` entry and the
    pairs are re-interleaved, so the output has the same shape as ``x``.
    ``cos`` and ``sin`` must broadcast against those halves.
    """
    even = x[..., ::2]
    odd = x[..., 1::2]
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    # stack -> (..., n_pairs, 2); flatten restores the interleaved layout.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
119
+
120
class Attention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection and rotary embeddings.

    Args:
        dim: model width; split evenly across ``n_heads`` (head_dim = dim // n_heads).
        n_heads: number of attention heads.
        dropout: attention dropout probability, applied only in training mode.
    """

    def __init__(self, dim, n_heads, dropout=0.0):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = dropout

    def forward(self, x, cos, sin):
        batch, seq_len, width = x.shape
        # (B, T, 3*C) -> (3, B, heads, T, head_dim), then split into q / k / v views.
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        # Rotary embedding is applied to queries and keys only.
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        drop_p = self.attn_dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop_p)
        # (B, heads, T, head_dim) -> (B, T, C) before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)
136
+
137
class Block(nn.Module):
    """Pre-norm transformer layer: attention and SwiGLU MLP, each on a residual branch.

    Args:
        dim: model width.
        n_heads: number of attention heads.
        mlp_dim: hidden width of the SwiGLU feed-forward layer.
        dropout: probability used inside attention and on both residual branches;
            when it is not positive the residual dropout is an ``nn.Identity``.
    """

    def __init__(self, dim, n_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.ln1 = RMSNorm(dim)
        self.attn = Attention(dim, n_heads, dropout)
        self.ln2 = RMSNorm(dim)
        self.mlp = SwiGLU(dim, mlp_dim)
        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.ln1(x), cos, sin)
        x = x + self.drop(attn_out)
        mlp_out = self.mlp(self.ln2(x))
        return x + self.drop(mlp_out)
149
+
150
class GPT(nn.Module):
    """Decoder-only transformer with tied input/output embeddings and precomputed RoPE tables.

    Args:
        vocab_size: number of token ids.
        dim: model width.
        depth: number of ``Block`` layers.
        n_heads: attention heads per layer.
        max_seq_len: number of positions in the rotary cos/sin tables.
        rope_theta: base of the rotary frequency schedule.
        dropout: dropout probability forwarded to every ``Block``.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, dim=DIM, depth=DEPTH, n_heads=N_HEADS,
                 max_seq_len=MAX_SEQ_LEN, rope_theta=ROPE_THETA, dropout=DROPOUT):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        # ~8/3 * dim rounded up to a multiple of 64 (= 8192 for dim=3072).
        raw_mlp_dim = int(2 * dim * 4 / 3)
        mlp_dim = ((raw_mlp_dim + 63) // 64) * 64
        layers = [Block(dim, n_heads, mlp_dim, dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # weight tying

        # Rotary tables: one row per position, one column per (even, odd) feature pair.
        head_dim = dim // n_heads
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (rope_theta ** exponents)
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        self.register_buffer('rope_cos', angles.cos())
        self.register_buffer('rope_sin', angles.sin())

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero any Linear bias."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        _, seq_len = idx.shape
        hidden = self.tok_emb(idx)
        # Add two leading axes so the tables broadcast over batch and heads.
        cos = self.rope_cos[:seq_len][None, None]
        sin = self.rope_sin[:seq_len][None, None]
        for layer in self.blocks:
            # No manual torch_checkpoint here: activation checkpointing, when used,
            # is applied externally through apply_activation_checkpointing.
            hidden = layer(hidden, cos, sin)
        return self.head(self.ln_f(hidden))
186
+
187
+ # ============ WSD LINEAR SCHEDULE ============
188
def wsd_lr_linear(step, total_steps, warmup_steps, stable_end, min_lr_ratio, base_lr):
    """Return the learning rate for ``step`` under a warmup / stable / linear-decay schedule.

    Args:
        step: current optimizer step.
        total_steps: step at which the decay reaches its floor.
        warmup_steps: number of linear warmup steps.
        stable_end: first step of the decay phase.
        min_lr_ratio: floor of the schedule as a fraction of ``base_lr``.
        base_lr: peak learning rate used during the stable phase.

    Returns:
        ``base_lr * (step + 1) / warmup_steps`` during warmup, ``base_lr`` while
        ``step < stable_end``, then a linear ramp down to ``base_lr * min_lr_ratio``
        at ``total_steps``. Steps past ``total_steps`` stay at the floor.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / max(warmup_steps, 1)
    if step < stable_end:
        return base_lr
    decay_span = max(total_steps - stable_end, 1)
    # Clamp so that overshooting total_steps cannot push the LR below the floor
    # (unclamped, the ramp keeps falling and eventually turns negative).
    progress = min((step - stable_end) / decay_span, 1.0)
    return base_lr * (1.0 - progress * (1.0 - min_lr_ratio))
196
+
197
+ # ============ DATA LOADING ============
198
class BinaryDataset:
    """Random-window sampler over a flat file of uint16 token ids.

    Args:
        path: file of raw uint16 tokens; it is memory-mapped read-only.
        seq_len: number of tokens per sampled input sequence.
    """

    def __init__(self, path, seq_len):
        self.data = np.memmap(path, dtype=np.uint16, mode='r')
        self.seq_len = seq_len
        self.n_tokens = len(self.data)

    def get_batch(self, batch_size, device, rng):
        """Sample ``batch_size`` random windows and return ``(x, y)`` on ``device``.

        ``rng`` is a NumPy ``Generator``; start offsets are drawn from
        ``[0, n_tokens - seq_len - 1)``. Both tensors are int64 with shape
        ``(batch_size, seq_len)`` and ``y`` is ``x`` shifted one token to the right.
        """
        upper = self.n_tokens - self.seq_len - 1
        offsets = rng.integers(0, upper, size=(batch_size,))
        inputs, targets = [], []
        for offset in offsets:
            begin = int(offset)
            # One read of seq_len + 1 tokens serves both the input and the shifted target.
            window = torch.from_numpy(
                self.data[begin:begin + self.seq_len + 1].astype(np.int64))
            inputs.append(window[:-1])
            targets.append(window[1:])
        x = torch.stack(inputs)
        y = torch.stack(targets)
        return x.to(device), y.to(device)
208
+
209
def load_val_data(path, seq_len, max_batches=20, batch_size=8):
    """Build a fixed list of ``(x, y)`` validation batches from a uint16 token file.

    The file is cut into non-overlapping windows of ``seq_len + 1`` tokens. When
    there are more windows than ``max_batches * batch_size``, they are subsampled
    at a regular stride. Only full batches are returned; a trailing partial batch
    is dropped.

    Returns:
        List of ``(x, y)`` int64 CPU tensors, each of shape ``(batch_size, seq_len)``,
        where ``y`` is ``x`` shifted by one token.
    """
    tokens = np.memmap(path, dtype=np.uint16, mode='r')
    window = seq_len + 1
    starts = list(range(0, len(tokens) - window, window))

    sample_budget = max_batches * batch_size
    if len(starts) > sample_budget:
        every = len(starts) // sample_budget
        starts = starts[::every][:sample_budget]

    def _rows(group, shift):
        # Stack one int64 row per start offset, optionally shifted by one token.
        return torch.stack([
            torch.from_numpy(tokens[s + shift:s + shift + seq_len].astype(np.int64))
            for s in group
        ])

    batches = []
    for lo in range(0, len(starts), batch_size):
        group = starts[lo:lo + batch_size]
        if len(group) == batch_size:
            batches.append((_rows(group, 0), _rows(group, 1)))
    return batches
227
+
228
+ @torch.no_grad()
229
+ def evaluate(model, val_batches, device):
230
+ """Evaluate model. Works with both FSDP-wrapped and unwrapped models."""
231
+ model.eval()
232
+ total_loss = 0.0
233
+ total_tokens = 0
234
+ for x, y in val_batches:
235
+ x, y = x.to(device), y.to(device)
236
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
237
+ logits = model(x)
238
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='sum')
239
+ total_loss += loss.item()
240
+ total_tokens += y.numel()
241
+ model.train()
242
+ return (total_loss / total_tokens) / math.log(2) if total_tokens > 0 else float('inf')
243
+
244
class Logger:
    """Timestamped logger that is active on rank 0 only.

    On rank 0 every message is printed to stdout and written (and flushed) to
    ``log_file``, which is opened in write mode and therefore truncated. On every
    other rank all methods are no-ops and no file is opened.
    """

    def __init__(self, log_file, rank):
        self.rank = rank
        if rank == 0:
            self.f = open(log_file, 'w')

    def log(self, msg):
        """Emit ``msg`` prefixed with the local wall-clock time."""
        if self.rank != 0:
            return
        ts = time.strftime('%Y-%m-%d %H:%M:%S')
        line = f"[{ts}] {msg}"
        print(line, flush=True)
        self.f.write(f"{line}\n")
        self.f.flush()

    def close(self):
        """Close the log file (rank 0 only)."""
        if self.rank != 0:
            return
        self.f.close()
259
+
260
class SWAState:
    """Running average of model weights (SWA) kept on rank 0 for an FSDP model.

    Attributes are ``avg_state`` (dict of float32 tensors, ``None`` until the first
    update) and ``n_averaged`` (number of snapshots folded into the average).
    """

    def __init__(self):
        self.avg_state = None
        self.n_averaged = 0

    def update(self, model):
        """Fold the current weights of ``model`` into the running average.

        Every rank must call this: the full state dict is requested from the FSDP
        model with CPU offload and ``rank0_only=True``. Only rank 0 then updates
        ``avg_state`` / ``n_averaged``; other ranks return right after the gather.
        """
        gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, gather_cfg):
            full_state = model.state_dict()

        # With rank0_only=True the gathered dict is only populated on rank 0.
        if dist.get_rank() != 0:
            return

        if self.avg_state is None:
            # First snapshot (re)starts the average, whatever n_averaged held before.
            self.avg_state = {name: tensor.float().clone()
                              for name, tensor in full_state.items()}
            self.n_averaged = 1
            return

        count = self.n_averaged
        for name in self.avg_state:
            running = self.avg_state[name]
            self.avg_state[name] = (running * count + full_state[name].float()) / (count + 1)
        self.n_averaged = count + 1
284
+
285
+ # ============ FSDP HELPERS ============
286
def get_fsdp_wrap_policy():
    """Return an FSDP auto-wrap policy that treats each ``Block`` as a transformer layer."""
    layer_classes = {Block}
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=layer_classes,
    )
    return policy
292
+
293
def get_mixed_precision():
    """Return the FSDP ``MixedPrecision`` config used for training.

    Parameters and buffers use bfloat16; gradient reduction uses float32.
    """
    dtypes = {
        "param_dtype": torch.bfloat16,
        "reduce_dtype": torch.float32,
        "buffer_dtype": torch.bfloat16,
    }
    return MixedPrecision(**dtypes)
300
+
301
def save_fsdp_checkpoint(model, optimizer, scaler, step, best_val_bpb,
                         tokens_processed, eval_results, swa_n, ckpt_dir, logger):
    """Gather the full model state dict and write a training checkpoint on rank 0.

    Must be called by every rank: the state-dict gather and the trailing barrier
    are collective operations.

    Args:
        model: FSDP-wrapped model.
        optimizer, scaler: accepted for interface compatibility; their state is
            NOT saved, so a resumed run starts with a fresh optimizer.
        step: step number, stored in the checkpoint and used in the file name.
        best_val_bpb, tokens_processed, eval_results, swa_n: bookkeeping values
            stored alongside the weights.
        ckpt_dir: directory that receives ``ckpt_step_{step}.pt``.
        logger: object with a ``log(msg)`` method.
    """
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        model_state = model.state_dict()

    if dist.get_rank() == 0:
        ckpt = {
            'step': step,
            'model': model_state,
            'best_val_bpb': best_val_bpb,
            'tokens_processed': tokens_processed,
            'eval_results': eval_results,
            'swa_n': swa_n,
            # NOTE: We don't save optimizer/scaler state for simplicity with FSDP.
            # For full resumability, use FSDP's ShardedStateDictConfig instead.
        }
        final_path = f"{ckpt_dir}/ckpt_step_{step}.pt"
        tmp_path = f"{final_path}.tmp"
        # Write to a sibling temp file and rename it into place, so a crash in the
        # middle of the write can never leave a truncated ckpt_step_N.pt behind for
        # a later resume to pick up.
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, final_path)
        logger.log(f"Saved checkpoint at step {step}")
    dist.barrier()
322
+
323
def save_fsdp_model(model, path, logger):
    """Gather the full model state dict and save only the weights to ``path`` on rank 0.

    Must be called by every rank: the state-dict gather and the trailing barrier
    are collective operations.

    Args:
        model: FSDP-wrapped model.
        path: destination file for the state dict.
        logger: object with a ``log(msg)`` method.
    """
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        state = model.state_dict()
    if dist.get_rank() == 0:
        tmp_path = f"{path}.tmp"
        # The same path (e.g. best_model.pt) is overwritten repeatedly while a
        # background upload may still be reading the previous file. Writing to a
        # temp file and renaming keeps the old file intact for such readers and
        # never exposes a partially written file under the final name.
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
        logger.log(f"Saved model to {path}")
    dist.barrier()
332
+
333
+ # ============ MAIN ============
334
def main():
    """Run the FSDP training job on the current torchrun rank.

    Steps: initialise NCCL, load the memory-mapped data, build the model on CPU and
    wrap it with FSDP, optionally resume weights from ``RESUME_STEP``, then loop over
    optimizer steps with gradient accumulation, periodic SWA snapshots, logging,
    checkpointing (with background S3 sync) and evaluation. After the loop the final
    weights are saved and rank 0 evaluates and stores the SWA-averaged model.
    """
    dist.init_process_group('nccl', timeout=datetime.timedelta(minutes=30))
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    device = torch.device(f'cuda:{local_rank}')
    torch.cuda.set_device(device)

    effective_batch = BATCH_PER_GPU * GRAD_ACCUM * world_size
    tokens_per_step = effective_batch * MAX_SEQ_LEN

    logger = Logger(LOG_FILE, rank)
    logger.log(f"=== Multilingual 3.14B Training — FSDP (Arabic-Rebalanced) ===")
    logger.log(f"World size: {world_size}, Batch/GPU: {BATCH_PER_GPU}, Grad accum: {GRAD_ACCUM}")
    logger.log(f"Effective batch: {effective_batch} seqs = {tokens_per_step:,} tokens/step")
    logger.log(f"Total steps: {TOTAL_STEPS} = {TOTAL_STEPS * tokens_per_step:,} tokens")
    logger.log(f"Schedule: WSD-LINEAR | warmup={WARMUP_STEPS} | stable_end={STABLE_END} | total={TOTAL_STEPS}")
    logger.log(f"AdamW LR={ADAMW_LR}, betas={ADAMW_BETAS}, WD={ADAMW_WD}")
    logger.log(f"Label smoothing={LABEL_SMOOTHING}, min_lr={MIN_LR_RATIO}, grad_clip={GRAD_CLIP}")
    logger.log(f"SWA: start={int(TOTAL_STEPS*SWA_START_FRAC)}, freq={SWA_FREQ}")
    logger.log(f"Model: dim={DIM}, depth={DEPTH}, heads={N_HEADS}, dropout={DROPOUT}")
    logger.log(f"FSDP: FULL_SHARD, MixedPrecision(bf16/fp32), Block-level wrapping")

    os.makedirs(CKPT_DIR, exist_ok=True)

    # Data
    logger.log("Loading training data...")
    train_ds = BinaryDataset(f"{DATA_DIR}/train.bin", MAX_SEQ_LEN)
    logger.log(f"Train tokens: {train_ds.n_tokens:,}")

    logger.log("Loading validation data...")
    val_batches = load_val_data(f"{DATA_DIR}/val.bin", MAX_SEQ_LEN)
    val_lang_batches = {}
    for lang in ['en', 'ar', 'he', 'fa']:
        vpath = f"{DATA_DIR}/val_{lang}.bin"
        if os.path.exists(vpath):
            val_lang_batches[lang] = load_val_data(vpath, MAX_SEQ_LEN)
            logger.log(f" val_{lang}: {len(val_lang_batches[lang])} batches")

    # Model — create on CPU first, then FSDP wraps and shards to GPUs
    logger.log("Creating model...")
    torch.manual_seed(42)  # same seed on every rank -> identical initial weights
    model = GPT()
    n_params = sum(p.numel() for p in model.parameters())
    n_params_no_emb = n_params - model.tok_emb.weight.numel()
    logger.log(f"Model params: {n_params:,} (non-embedding: {n_params_no_emb:,})")

    # Wrap with FSDP
    logger.log("Wrapping model with FSDP...")
    wrap_policy = get_fsdp_wrap_policy()
    mixed_precision = get_mixed_precision()

    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        auto_wrap_policy=wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=local_rank,
        limit_all_gathers=True,
        use_orig_params=True,  # Required for weight decay on specific params
    )

    logger.log(f"FSDP wrapped. GPU memory: {torch.cuda.memory_allocated(device)/1e9:.1f} GB")

    # Apply activation checkpointing to each Block within FSDP
    if HAS_FSDP_AC:
        check_fn = lambda submodule: isinstance(submodule, Block)
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper,
            check_fn=check_fn,
        )
        logger.log("Applied FSDP activation checkpointing to Block layers")
    else:
        # GPT.forward contains no manual checkpointing, so in this branch the model
        # really runs without any activation checkpointing; say so explicitly.
        logger.log("WARNING: FSDP activation checkpointing not available, "
                   "training WITHOUT activation checkpointing")

    # Optimizer — prefer 8-bit AdamW from bitsandbytes when it is installed,
    # otherwise fall back to standard torch AdamW.
    # NOTE(review): bitsandbytes may not work well with FSDP sharded params — confirm.
    try:
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WD,
                                        betas=ADAMW_BETAS, eps=ADAMW_EPS)
        logger.log("Using 8-bit AdamW (bitsandbytes)")
    except ImportError:
        optimizer = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WD,
                                      betas=ADAMW_BETAS, eps=ADAMW_EPS)
        logger.log("Using standard AdamW (bitsandbytes not available)")

    swa = SWAState()
    swa_start_step = int(TOTAL_STEPS * SWA_START_FRAC)
    rng = np.random.default_rng(42 + rank)  # per-rank stream -> each rank samples different windows

    # Use ShardedGradScaler for FSDP (handles sharded gradients correctly)
    scaler = ShardedGradScaler()

    best_val_bpb = float('inf')
    eval_results = []
    tokens_processed = 0
    start_step = 1

    # Resume from checkpoint if requested
    if RESUME_STEP > 0:
        ckpt_path = f"{CKPT_DIR}/ckpt_step_{RESUME_STEP}.pt"
        # Only rank 0 downloads. The barrier is unconditional so that every rank
        # executes the same number of collectives even if ranks observe the file
        # appearing at different moments.
        if rank == 0 and not os.path.exists(ckpt_path):
            # Try downloading from S3
            logger.log(f"Downloading checkpoint from S3 for step {RESUME_STEP}...")
            os.system(f"aws s3 cp s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}/ckpt_step_{RESUME_STEP}.pt {ckpt_path}")
        dist.barrier()
        if os.path.exists(ckpt_path):
            if rank == 0:
                logger.log(f"Loading checkpoint from step {RESUME_STEP}...")
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load
            # checkpoints produced by this job.
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
            # Load model weights into FSDP
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                model.load_state_dict(ckpt['model'])
            start_step = ckpt['step'] + 1
            tokens_processed = ckpt.get('tokens_processed', 0)
            best_val_bpb = ckpt.get('best_val_bpb', float('inf'))
            eval_results = ckpt.get('eval_results', [])
            # Only the count is restored; the averaged weights are not in the
            # checkpoint, so the next SWA update starts a new average.
            swa.n_averaged = ckpt.get('swa_n', 0)
            if rank == 0:
                logger.log(f"Resumed from step {RESUME_STEP}. Tokens: {tokens_processed:,}, Best BPB: {best_val_bpb:.4f}")
                logger.log(f"NOTE: Optimizer state reset (fresh AdamW). LR schedule continues from step {start_step}.")
            del ckpt
            # Re-seed so the resumed run does not replay the sample stream of a fresh run
            rng = np.random.default_rng(42 + rank + RESUME_STEP * 1000)
            dist.barrier()
        else:
            if rank == 0:
                logger.log(f"WARNING: No checkpoint found for step {RESUME_STEP}, starting from scratch")

    # Throughput is measured over this process's lifetime only; remember how many
    # tokens were already counted before it started (non-zero after a resume).
    tokens_at_start = tokens_processed
    start_time = time.time()
    logger.log(f"Starting training from step {start_step}...")

    for step in range(start_step, TOTAL_STEPS + 1):
        model.train()
        lr = wsd_lr_linear(step, TOTAL_STEPS, WARMUP_STEPS, STABLE_END, MIN_LR_RATIO, ADAMW_LR)
        for g in optimizer.param_groups:
            g['lr'] = lr

        optimizer.zero_grad()
        accum_loss = 0.0

        for micro in range(GRAD_ACCUM):
            x, y = train_ds.get_batch(BATCH_PER_GPU, device, rng)

            # For FSDP with gradient accumulation, we need to use no_sync() for all
            # micro-steps except the last one to avoid unnecessary all-reduce
            ctx = model.no_sync() if micro < GRAD_ACCUM - 1 else nullcontext()
            with ctx:
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1),
                                           label_smoothing=LABEL_SMOOTHING) / GRAD_ACCUM
                scaler.scale(loss).backward()
            accum_loss += loss.item()

        scaler.unscale_(optimizer)
        # FSDP clip_grad_norm_ works on the sharded params
        model.clip_grad_norm_(GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        tokens_processed += tokens_per_step

        # SWA: gather full state dict and average on rank 0
        if step >= swa_start_step and step % SWA_FREQ == 0:
            swa.update(model)

        if step % LOG_EVERY == 0 and rank == 0:
            elapsed = time.time() - start_time
            # Count only tokens processed since start_time, so throughput is not
            # inflated by tokens restored from a checkpoint.
            tps = (tokens_processed - tokens_at_start) / elapsed
            bpb = accum_loss / math.log(2)
            phase = "warmup" if step < WARMUP_STEPS else ("stable" if step < STABLE_END else "decay")
            mem = torch.cuda.max_memory_allocated(device) / 1e9
            logger.log(f"Step {step}/{TOTAL_STEPS} [{phase}] | Loss: {accum_loss:.4f} | "
                       f"BPB: {bpb:.4f} | LR: {lr:.6f} | Tokens: {tokens_processed:,} | "
                       f"TPS: {tps:,.0f} | SWA: {swa.n_averaged} | Mem: {mem:.1f}GB | {elapsed/60:.1f}min")

        # Save checkpoint BEFORE eval to avoid FSDP state dict deadlocks
        if step % SAVE_EVERY == 0:
            save_fsdp_checkpoint(
                model, optimizer, scaler, step, best_val_bpb,
                tokens_processed, eval_results, swa.n_averaged, CKPT_DIR, logger
            )
            if rank == 0:
                # Uploads run in the background ('&') so training is not blocked.
                os.system(f"aws s3 sync {CKPT_DIR}/ s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}/ --quiet &")
                os.system(f"aws s3 cp {EVAL_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_eval_results.json --quiet &")
                os.system(f"aws s3 cp {LOG_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_training.log --quiet &")

        if step % EVAL_EVERY == 0 or step == TOTAL_STEPS:
            dist.barrier()  # Ensure all ranks are synced before eval
            # All ranks participate in eval (FSDP needs all ranks for forward pass)
            combined_bpb = evaluate(model, val_batches, device)

            if rank == 0:
                logger.log(f"--- Evaluation at step {step} ---")
                logger.log(f" Combined val BPB: {combined_bpb:.4f}")
                result = {"step": step, "tokens": tokens_processed, "combined_bpb": combined_bpb}

            for lang, batches in val_lang_batches.items():
                lang_bpb = evaluate(model, batches, device)
                if rank == 0:
                    result[f"{lang}_bpb"] = lang_bpb
                    logger.log(f" {lang} val BPB: {lang_bpb:.4f}")

            dist.barrier()  # All ranks must finish eval before proceeding

            # Determine if new best (rank 0 decides, broadcast to all)
            is_new_best = False
            if rank == 0:
                eval_results.append(result)
                with open(EVAL_FILE, 'w') as f:
                    json.dump(eval_results, f, indent=2)
                is_new_best = combined_bpb < best_val_bpb
                if is_new_best:
                    best_val_bpb = combined_bpb
                    logger.log(f" New best! BPB: {combined_bpb:.4f}")

            # All ranks participate in model save (FSDP state dict gather requires it)
            save_flag = torch.tensor([1 if is_new_best else 0], device=device)
            dist.broadcast(save_flag, src=0)
            if save_flag.item() == 1:
                save_fsdp_model(model, f"{CKPT_DIR}/best_model.pt", logger)
            dist.barrier()

    # Finalize
    save_fsdp_model(model, f"{CKPT_DIR}/final_model.pt", logger)

    if rank == 0:
        # SWA evaluation (plain, unwrapped GPT on rank 0; no collectives involved)
        if swa.avg_state is not None and swa.n_averaged > 0:
            logger.log(f"Evaluating SWA model ({swa.n_averaged} checkpoints)...")
            swa_model = GPT().to(device)
            swa_load = {k: v.float().to(device) for k, v in swa.avg_state.items()}
            swa_model.load_state_dict(swa_load)
            swa_bpb = evaluate(swa_model, val_batches, device)
            logger.log(f"SWA model combined BPB: {swa_bpb:.4f} (vs best raw: {best_val_bpb:.4f})")
            swa_result = {"step": "swa", "combined_bpb": swa_bpb, "n_averaged": swa.n_averaged}
            for lang, batches in val_lang_batches.items():
                lang_bpb = evaluate(swa_model, batches, device)
                swa_result[f"{lang}_bpb"] = lang_bpb
                logger.log(f" SWA {lang} BPB: {lang_bpb:.4f}")
            eval_results.append(swa_result)
            torch.save(swa_load, f"{CKPT_DIR}/swa_model.pt")
            with open(EVAL_FILE, 'w') as f:
                json.dump(eval_results, f, indent=2)
            del swa_model

        # Final S3 upload (foreground, so the process does not exit mid-transfer)
        logger.log("Uploading all artifacts to S3...")
        os.system(f"aws s3 sync {CKPT_DIR}/ s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}/")
        os.system(f"aws s3 cp {LOG_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_training.log")
        os.system(f"aws s3 cp {EVAL_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_eval_results.json")

    elapsed = time.time() - start_time
    logger.log(f"=== Training complete! Total time: {elapsed/3600:.2f}h ===")
    logger.log(f"Best combined BPB: {best_val_bpb:.4f}")
    logger.log(f"Total tokens: {tokens_processed:,}")

    logger.close()
    dist.destroy_process_group()
596
+
597
+ # Need nullcontext for no_sync
598
+ from contextlib import nullcontext
599
+
600
+ if __name__ == '__main__':
601
+ main()