ASTERIZER committed on
Commit
2be87ed
·
verified ·
1 Parent(s): cd7ee10

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +633 -0
train.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LUNA 100M — Config-Driven Dynamic Training Script
3
+ ==================================================
4
+ Reads train_config.yaml for all hyperparameters.
5
+
6
+ auto_config: true -> hardware probed; batch/lr/workers set automatically
7
+ auto_config: false -> every value in config used exactly as-is
8
+
9
+ Usage:
10
+ python train.py # uses train_config.yaml defaults
11
+ python train.py --config train_config.yaml # explicit config path
12
+ python train.py --data_path /mnt/data/litdata_final # override data path only
13
+ python train.py --max_tokens 10000000 # short smoke-test run
14
+ """
15
+
16
+ import os
17
+ import gc
18
+ import sys
19
+ import math
20
+ import time
21
+ import json
22
+ import argparse
23
+ import yaml
24
+ import psutil
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.amp import autocast, GradScaler
29
+ from pathlib import Path
30
+
31
# Ask the CUDA caching allocator for expandable segments to reduce memory
# fragmentation, unless the variable is already set in the environment.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
33
+
34
+
35
+ # ─── Model ────────────────────────────────────────────────────────────────────
36
+
37
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    The tables cover ``max_seq_len`` positions and are registered as buffers
    (``inv_freq``, ``cos_cached``, ``sin_cached``), so they move with the
    module and are part of its ``state_dict``.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # One inverse frequency per pair of rotated features.
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer("inv_freq", inv_freq)

        # Angle table: position (rows) times inverse frequency (columns).
        positions = torch.arange(max_seq_len).float()
        angles = positions[:, None] * inv_freq[None, :]

        # The frequencies are repeated so the table spans both halves that
        # ``rotate_half`` swaps.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return ``(cos, sin)`` restricted to the first ``seq_len`` positions."""
        cos_rows = self.cos_cached[:seq_len]
        sin_rows = self.sin_cached[:seq_len]
        return cos_rows, sin_rows
50
+
51
+
52
def rotate_half(x):
    """Map the last dimension ``[a, b]`` (two chunks) to ``[-b, a]``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
55
+
56
+
57
def apply_rotary(x, cos, sin):
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` get two leading singleton axes so they broadcast
    against ``x``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    swapped = rotate_half(x)
    return x * cos_b + swapped * sin_b
61
+
62
+
63
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first ``rot_dim = int(head_dim * rotary_pct)`` features of each
    head are rotated; the remaining features pass through unchanged.
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Fused projection producing queries, keys and values in one matmul.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def _rotate_prefix(self, heads, cos, sin):
        """Rotate the first ``rot_dim`` features and re-attach the rest."""
        rotated = apply_rotary(heads[..., :self.rot_dim], cos, sin)
        untouched = heads[..., self.rot_dim:]
        return torch.cat([rotated, untouched], dim=-1)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return the same shape."""
        B, T, C = x.size()
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim)
        # -> (3, B, n_head, T, head_dim), then split into q / k / v.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        cos, sin = self.rotary(T)
        q = self._rotate_prefix(q, cos, sin)
        k = self._rotate_prefix(k, cos, sin)

        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
82
+
83
+
84
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to ``4 * n_embd``, GELU, project back."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        """Return ``proj(gelu(fc(x)))``."""
        expanded = self.fc(x)
        activated = self.gelu(expanded)
        return self.proj(activated)
93
+
94
+
95
class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each
    added back onto the residual stream."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        """Run the attention and MLP sub-layers with residual connections."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
107
+
108
+
109
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        vocab_size: Number of rows in the token embedding / output projection.
        block_size: Maximum sequence length (size of the rotary tables).
        n_layer: Number of transformer blocks.
        n_embd: Width of the residual stream.
        n_head: Number of attention heads per block.
    """

    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights to N(0, 0.02) and zero Linear biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()

    def forward(self, idx, targets=None, return_logits=True):
        """Run the model on token ids ``idx``.

        Args:
            idx: Integer token ids fed to the embedding.
            targets: Optional target ids; when given, the cross-entropy loss
                over all positions is computed.
            return_logits: When False, ``None`` is returned in place of the
                logits (they are still computed to obtain the loss).

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        x = self.wte(idx)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        if not return_logits:
            logits = None
        return logits, loss

    @property
    def num_params(self):
        """Number of unique trainable-tensor elements in the model.

        ``nn.Module.parameters()`` yields each parameter object once, so the
        matrix shared by ``wte`` and ``lm_head`` is already counted a single
        time. Subtracting it again would drop the embedding from the total.
        """
        return sum(p.numel() for p in self.parameters())
141
+
142
+
143
+ # ─── Dataset ──────────────────────────────────────────────────────────────────
144
+
145
class LitDataDataset(torch.utils.data.Dataset):
    """Map-style dataset over token chunks listed in ``index.json``.

    Each chunk is cut into non-overlapping windows of ``block_size + 1``
    tokens; item ``i`` returns the window's first ``block_size`` tokens as the
    input and the same window shifted by one as the target.

    Args:
        data_path: Directory containing ``index.json`` and the chunk files.
        block_size: Number of tokens in each input sequence.
    """

    # Maximum number of decoded chunks kept in memory per dataset instance.
    _MAX_CACHED_CHUNKS = 4

    def __init__(self, data_path: str, block_size: int = 1024):
        self.block_size = block_size
        self.data_path = Path(data_path)
        with open(self.data_path / "index.json", encoding="utf-8") as f:
            idx = json.load(f)
        self.chunks_meta = idx["chunks"]

        # _cum_blocks[i] = number of windows in chunks 0..i (inclusive).
        self._cum_blocks = []
        total = 0
        for c in self.chunks_meta:
            total += c["dim"] // (block_size + 1)
            self._cum_blocks.append(total)
        self.total_blocks = total
        self._chunk_cache = {}

    def _load_chunk(self, chunk_idx: int):
        """Return the tokens of chunk ``chunk_idx`` as an int32 tensor (cached)."""
        if chunk_idx in self._chunk_cache:
            return self._chunk_cache[chunk_idx]
        import struct
        import numpy as np
        meta = self.chunks_meta[chunk_idx]
        with open(self.data_path / meta["filename"], "rb") as f:
            raw = f.read()
        # The file starts with a little-endian uint32 item count; the header
        # occupies (num_items + 2) 4-byte words, the rest is int32 tokens.
        num_items = struct.unpack_from("<I", raw, 0)[0]
        header_bytes = (num_items + 2) * 4
        # `offset=` avoids slicing (and thereby copying) the raw bytes first;
        # `.copy()` makes the array writable and independent of `raw`.
        tokens = torch.from_numpy(
            np.frombuffer(raw, dtype=np.int32, offset=header_bytes).copy()
        )
        if len(self._chunk_cache) >= self._MAX_CACHED_CHUNKS:
            # Evict the oldest inserted chunk (dicts keep insertion order).
            del self._chunk_cache[next(iter(self._chunk_cache))]
        self._chunk_cache[chunk_idx] = tokens
        return tokens

    def __len__(self):
        return self.total_blocks

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for window ``idx`` as int64 tensors.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        import bisect

        if idx < 0:
            idx += self.total_blocks
        if not 0 <= idx < self.total_blocks:
            # Without this check an out-of-range index would fall through to
            # chunk 0 and silently return a zero-padded window.
            raise IndexError(
                f"block index {idx} out of range for {self.total_blocks} blocks"
            )

        # First chunk whose cumulative window count exceeds idx.
        chunk_idx = bisect.bisect_right(self._cum_blocks, idx)
        prev = self._cum_blocks[chunk_idx - 1] if chunk_idx > 0 else 0
        tokens = self._load_chunk(chunk_idx)

        window = self.block_size + 1
        s = (idx - prev) * window
        chunk = tokens[s:s + window]
        if len(chunk) < window:
            # Chunk file shorter than index.json claims: right-pad with zeros.
            pad = torch.zeros(window, dtype=torch.int32)
            pad[:len(chunk)] = chunk
            chunk = pad
        chunk = chunk.long()
        return chunk[:self.block_size], chunk[1:window]
197
+
198
+
199
+ # ─── Hardware Detection ────────────────────────────────────────────────────────
200
+
201
def probe_hardware():
    """Describe the host: CPU cores, RAM, and the first CUDA device if any.

    Returns a dict with keys ``cpu_cores``, ``ram_gb``, ``device``,
    ``gpu_name``, ``vram_gb``, ``sm_major``, ``precision`` and ``dtype``.

    Side effect: when the GPU's compute-capability major version is >= 8,
    TF32 is enabled for CUDA matmul and cuDNN.
    """
    info = {
        "cpu_cores": os.cpu_count() or 4,
        "ram_gb": psutil.virtual_memory().total / 1024**3,
    }

    # No GPU: full-precision CPU settings.
    if not torch.cuda.is_available():
        info.update(
            device="cpu",
            gpu_name="CPU",
            vram_gb=0,
            sm_major=0,
            precision="fp32",
            dtype=torch.float32,
        )
        return info

    props = torch.cuda.get_device_properties(0)
    info.update(
        device="cuda",
        gpu_name=props.name,
        vram_gb=props.total_memory / 1024**3,
        sm_major=props.major,
    )
    if props.major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        info["precision"], info["dtype"] = "bf16", torch.bfloat16
    else:
        info["precision"], info["dtype"] = "fp16", torch.float16
    return info
232
+
233
+
234
def probe_max_batch(model, device, dtype, seq_len, vocab_size, max_search=4096, grad_accum_sim=4):
    """Binary-search the largest micro-batch that survives a training step.

    Every candidate size runs ``grad_accum_sim`` forward/backward passes on
    random token ids followed by one AdamW step, mirroring the memory pattern
    of real training. The model's parameters ARE modified by those optimizer
    steps; callers should re-initialise the weights afterwards.

    Args:
        model: Module called as ``model(x, t, return_logits=False)`` and
            returning ``(_, loss)``.
        device: Device on which the random batches are created.
        dtype: Autocast dtype.
        seq_len: Sequence length of the probe batches.
        vocab_size: Exclusive upper bound for the random token ids.
        max_search: Largest micro-batch size to try.
        grad_accum_sim: Number of accumulation passes per candidate.

    Returns:
        70% of the largest size that fit (at least 1).

    Raises:
        RuntimeError: Any runtime error that is not an out-of-memory error.
    """
    # Empty tuple makes isinstance() False on builds without this class.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", ())
    tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lo, hi, best = 1, max_search, 1
    while lo <= hi:
        mid = (lo + hi) // 2
        x = t = loss = None
        fits = True
        torch.cuda.empty_cache()
        gc.collect()
        tmp_opt.zero_grad(set_to_none=True)
        try:
            # Simulate grad_accum micro-batches (real training pattern)
            for _ in range(grad_accum_sim):
                x = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                t = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                with autocast(device_type="cuda", dtype=dtype):
                    _, loss = model(x, t, return_logits=False)
                loss = loss / grad_accum_sim
                loss.backward()
                x = t = loss = None
            tmp_opt.step()
        except RuntimeError as err:
            # Older builds report OOM as a plain RuntimeError, so the message
            # is checked as well as the exception type.
            if not (isinstance(err, oom_type) or "out of memory" in str(err).lower()):
                raise
            fits = False

        # Cleanup happens outside the handler, once the exception has been
        # released, and is shared by the success and failure paths.
        x = t = loss = None
        tmp_opt.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()

        if fits:
            best = mid
            lo = mid + 1
        else:
            hi = mid - 1

    del tmp_opt
    torch.cuda.empty_cache()
    gc.collect()
    safe = max(1, int(best * 0.70))
    print(f" Probe found max_batch={best}, using {safe} (70% safety, tested with {grad_accum_sim} accum steps)")
    return safe
276
+
277
+
278
+ # ─── LR Schedule ──────────────────────────────────────────────────────────────
279
+
280
def cosine_lr(step, warmup, total, lr_max, lr_min):
    """Learning rate with linear warmup followed by cosine decay.

    Args:
        step: Zero-based optimizer step.
        warmup: Number of warmup steps; step ``warmup - 1`` reaches ``lr_max``.
        total: Step at which the decay reaches ``lr_min``.
        lr_max: Peak learning rate.
        lr_min: Final learning rate.

    Returns:
        The learning rate for ``step``. Steps at or beyond ``total`` return
        ``lr_min``.
    """
    if step < warmup:
        return lr_max * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    # Clamp so steps past `total` stay at lr_min; unclamped, the cosine would
    # swing back up towards lr_max.
    progress = min(1.0, progress)
    return lr_min + 0.5 * (1 + math.cos(math.pi * progress)) * (lr_max - lr_min)
285
+
286
+
287
+ # ─── Config Loading ───────────────────────────────────────────────────────────
288
+
289
def load_config(config_path: str) -> dict:
    """Read the YAML config and flatten it into a single-level dict.

    The four top-level keys fall back to defaults when absent; every key in
    the ``model``, ``train``, ``optimizer``, ``batch``, ``dataloader`` and
    ``hardware`` sections is required (a missing one raises ``KeyError``).
    ``betas`` is converted to a tuple.
    """
    with open(config_path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)

    top_level_defaults = (
        ("auto_config", True),
        ("data_path", "Base/data/litdata_pretrain_final"),
        ("out_dir", "out/pretrain/luna-100m"),
        ("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
    )
    required_sections = (
        ("model", ("vocab_size", "seq_len", "n_layer", "n_embd", "n_head")),
        ("train", ("max_tokens", "lr_warmup_steps", "save_interval",
                   "log_interval", "max_norm")),
        ("optimizer", ("lr", "min_lr", "weight_decay", "betas", "eps")),
        ("batch", ("global_batch", "micro_batch", "grad_accum")),
        ("dataloader", ("num_workers", "pin_memory")),
        ("hardware", ("precision", "compile")),
    )

    cfg = {name: raw.get(name, default) for name, default in top_level_defaults}
    for section, names in required_sections:
        for name in names:
            value = raw[section][name]
            cfg[name] = tuple(value) if name == "betas" else value
    return cfg
330
+
331
+
332
def apply_cli_overrides(cfg: dict, cli_args: argparse.Namespace) -> dict:
    """Copy explicitly provided CLI values into ``cfg`` (in place) and return it.

    ``config`` is skipped, and so is any argument whose value is ``None``
    (the argparse default meaning "not provided").
    """
    provided = {
        name: value
        for name, value in vars(cli_args).items()
        if name != "config" and value is not None
    }
    cfg.update(provided)
    return cfg
340
+
341
+
342
def resolve_auto(cfg: dict, hw: dict) -> dict:
    """Attach hardware info to ``cfg`` and, in auto mode, tune it to the host.

    With ``auto_config`` false, ``cfg`` only gains ``_hw``. With it true the
    following keys are overwritten from ``hw``: ``precision``, ``_dtype``,
    ``num_workers``, ``pin_memory``; ``lr`` and ``min_lr`` are rescaled by
    ``sqrt(global_batch / 120)``; and ``_auto_warmup`` is set so the warmup
    length is derived later, once the total step count is known. Model
    architecture and ``max_tokens`` are never touched.

    Args:
        cfg: Flat config dict; mutated in place.
        hw: Hardware description with ``precision``, ``dtype``, ``cpu_cores``,
            ``ram_gb`` and ``device``.

    Returns:
        The same ``cfg`` object.
    """
    if not cfg["auto_config"]:
        print(" [CONFIG] auto_config=false -- using manual values as-is")
        cfg.update({"_hw": hw})
        return cfg

    print(" [CONFIG] auto_config=true -- tuning settings to this hardware")

    # Precision
    cfg["precision"] = hw["precision"]
    cfg["_dtype"] = hw["dtype"]

    # Workers: half the cores, capped by RAM. Each worker caches up to
    # 4 chunks × ~67MB (= 268 MB); a quarter of RAM is budgeted for that.
    auto_workers = hw["cpu_cores"] // 2
    max_by_ram = int(hw["ram_gb"] * 0.25 * 1024 / 268)
    # Clamp at zero explicitly; the result can never equal a -1 sentinel.
    cfg["num_workers"] = max(0, min(auto_workers, max_by_ram, hw["cpu_cores"]))

    # Pin memory
    cfg["pin_memory"] = hw["ram_gb"] > 16 and hw["device"] == "cuda"

    # LR warmup: flag it for recomputation once total_steps is known.
    cfg["_auto_warmup"] = True

    # LR scaling: sqrt(global_batch / 120) relative to base lr
    base_global = 120
    lr_scale = math.sqrt(cfg["global_batch"] / base_global)
    cfg["lr"] = cfg["lr"] * lr_scale
    cfg["min_lr"] = cfg["min_lr"] * lr_scale

    cfg["_hw"] = hw
    return cfg
380
+
381
+
382
+ # ─── Training ─────────────────────────────────────────────────────────────────
383
+
384
# Horizontal rule (72 '=' characters) printed around console banners.
SEP = 72 * "="
385
+
386
def _build_optimizer(model, cfg: dict, device):
    """Create AdamW from ``cfg``, preferring the fused kernel on CUDA."""
    common = dict(
        lr=cfg["lr"], weight_decay=cfg["weight_decay"],
        betas=cfg["betas"], eps=cfg["eps"],
    )
    if device.type == "cuda":
        try:
            return torch.optim.AdamW(model.parameters(), fused=True, **common)
        except (TypeError, RuntimeError):
            # Presumably TypeError when the build has no `fused` argument and
            # RuntimeError when it rejects fused for these parameters; either
            # way fall back to the default implementation.
            pass
    return torch.optim.AdamW(model.parameters(), **common)


def _save_checkpoint(model, optimizer, scaler, step, tokens_seen, out_dir):
    """Write ``step-XXXXXXXX/lit_model.pth`` and refresh ``latest.pt``."""
    raw = model._orig_mod if hasattr(model, "_orig_mod") else model
    step_dir = out_dir / f"step-{step:08d}"
    step_dir.mkdir(parents=True, exist_ok=True)
    torch.save(raw.state_dict(), step_dir / "lit_model.pth")
    # Write to a temp file and rename, so an interrupted save cannot leave a
    # truncated latest.pt that breaks the next resume.
    tmp_path = out_dir / "latest.pt.tmp"
    torch.save({"step": step, "model": raw.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "tokens_seen": tokens_seen},
               tmp_path)
    tmp_path.replace(out_dir / "latest.pt")
    return step_dir


def train(cfg: dict):
    """Train LUNA according to ``cfg`` (as returned by ``resolve_auto``).

    Builds the model, picks micro-batch / grad-accum (probing VRAM in auto
    mode on CUDA), builds the dataloader and optimizer, resumes from
    ``<out_dir>/latest.pt`` when present, runs the loop with cosine LR and
    gradient clipping, and writes periodic plus final checkpoints.

    Args:
        cfg: Flat config dict; must contain ``_hw`` (hardware info).
    """
    hw = cfg["_hw"]
    device = torch.device(hw["device"])

    # Release this process's cached GPU blocks before loading the model.
    if device.type == "cuda":
        torch.cuda.empty_cache()
        gc.collect()
        free_gb = (torch.cuda.get_device_properties(0).total_memory
                   - torch.cuda.memory_allocated()) / 1024**3
        print(f" GPU free before model load: {free_gb:.1f} GB")

    # Pick precision dtype
    if cfg["auto_config"]:
        dtype = hw.get("dtype", torch.float32)
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16,
                 "fp32": torch.float32}.get(cfg["precision"], torch.float32)

    print(SEP)
    print(" LUNA 100M - Training")
    print(SEP)
    mode = "AUTO" if cfg["auto_config"] else "MANUAL"
    print(f" Config mode : {mode}")
    print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" Workers : {cfg['num_workers']} pin_memory={cfg['pin_memory']}")

    # ── Model ─────────────────────────────────────────────────────────────────
    print(f"\n Building LUNA-100M...")
    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)

    # torch.compile disabled: causes CUDA graph / OOM issues with tied
    # embeddings at this model size. Raw PyTorch + SDPA is already fast.
    print(" torch.compile: disabled (not needed for 100M params)")

    print(f" Parameters: {model.num_params:,} (unique)")

    # ── Batch sizing ──────────────────────────────────────────────────────────
    if cfg["auto_config"] and device.type == "cuda":
        print(f"\n Probing max micro_batch_size (VRAM search)...")
        # Probe using the actual model — no second copy wasting VRAM
        max_mbs = probe_max_batch(
            model, device, dtype, cfg["seq_len"], cfg["vocab_size"]
        )
        # The probe takes real optimizer steps, so restore a fresh init.
        # LayerNorm is reset explicitly because _init_weights only covers
        # Linear and Embedding modules.
        for module in model.modules():
            if isinstance(module, nn.LayerNorm):
                module.reset_parameters()
        model.apply(model._init_weights)
        torch.cuda.empty_cache()
        gc.collect()
        # grad_accum to hit global_batch
        grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs))
        effective_batch = max_mbs * grad_accum
        print(f" AUTO -> micro_batch={max_mbs}, grad_accum={grad_accum}, "
              f"effective_batch={effective_batch}")
    else:
        max_mbs = cfg["micro_batch"]
        grad_accum = cfg["grad_accum"]
        effective_batch = max_mbs * grad_accum
        print(f"\n MANUAL -> micro_batch={max_mbs}, grad_accum={grad_accum}, "
              f"effective_batch={effective_batch}")

    tokens_per_step = effective_batch * cfg["seq_len"]
    print(f" Tokens/step : {tokens_per_step:,}")

    # ── Dataset ───────────────────────────────────────────────────────────────
    print(f"\n Dataset: {cfg['data_path']}")
    dataset = LitDataDataset(cfg["data_path"], block_size=cfg["seq_len"])
    print(f" Blocks : {len(dataset):,} ({len(dataset) * cfg['seq_len']:,} tokens)")

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=max_mbs,
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )

    # ── Optimiser ─────────────────────────────────────────────────────────────
    optimizer = _build_optimizer(model, cfg, device)

    # Loss scaling is only needed for fp16; bf16/fp32 run with it disabled.
    use_scaler = dtype == torch.float16
    scaler = GradScaler(enabled=use_scaler)

    # ── Schedule ──────────────────────────────────────────────────────────────
    total_steps = max(1, cfg["max_tokens"] // tokens_per_step)
    if cfg["auto_config"] and cfg.get("_auto_warmup"):
        # 5% of the run, kept within [50, 500] but never longer than the run
        # itself (short smoke tests would otherwise never leave warmup).
        warmup_steps = min(total_steps, max(50, min(500, total_steps // 20)))
    else:
        warmup_steps = min(cfg["lr_warmup_steps"], total_steps)

    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n max_tokens : {cfg['max_tokens']:,}")
    print(f" total_steps : {total_steps:,}")
    print(f" warmup_steps : {warmup_steps}")
    print(f" lr : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}")
    print(f" save every : {cfg['save_interval']} steps")
    print(f" out_dir : {out_dir}")
    print(SEP)

    # ── Resume ────────────────────────────────────────────────────────────────
    start_step = 0
    resumed_tokens = None
    ckpt_path = out_dir / "latest.pt"
    if ckpt_path.exists():
        print(f"\n Resuming from {ckpt_path}...")
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # Checkpoints written before the scaler was saved have no such key.
        if use_scaler and ckpt.get("scaler"):
            scaler.load_state_dict(ckpt["scaler"])
        start_step = ckpt["step"]
        resumed_tokens = ckpt.get("tokens_seen")
        print(f" Resumed at step {start_step}")

    # ── Loop ──────────────────────────────────────────────────────────────────
    model.train()
    data_iter = iter(loader)

    def get_batch():
        """Return the next batch, restarting the loader when it is exhausted."""
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    run_t0 = time.perf_counter()
    if resumed_tokens is not None:
        tokens_seen = resumed_tokens
    else:
        tokens_seen = start_step * tokens_per_step
    step = start_step

    print(f"\n Starting training (step {start_step} -> {total_steps})...")

    while step < total_steps:
        t0 = time.perf_counter()
        lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
        for pg in optimizer.param_groups:
            pg["lr"] = lr_now

        optimizer.zero_grad(set_to_none=True)
        total_loss = 0.0

        for _ in range(grad_accum):
            x, t = get_batch()
            x = x.to(device, non_blocking=True)
            t = t.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(x, t, return_logits=False)
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / grad_accum
            scaler.scale(loss).backward()
            total_loss += loss.item()

        # Unscale before clipping so max_norm applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"])
        scaler.step(optimizer)
        scaler.update()

        # Wait for the GPU so `dt` measures the whole step.
        if device.type == "cuda":
            torch.cuda.synchronize()

        dt = time.perf_counter() - t0
        step += 1
        tokens_seen += tokens_per_step

        if step % cfg["log_interval"] == 0 or step <= 2:
            tps = tokens_per_step / dt
            steps_left = total_steps - step
            eta_h = steps_left * dt / 3600
            vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
            print(f" step {step:6d}/{total_steps} | loss {total_loss:.4f} | "
                  f"lr {lr_now:.2e} | {tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h")

        if step % cfg["save_interval"] == 0 or step == total_steps:
            step_dir = _save_checkpoint(model, optimizer, scaler, step, tokens_seen, out_dir)
            print(f" Saved -> {step_dir}")

    # ── Final ─────────────────────────────────────────────────────────────────
    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    raw = model._orig_mod if hasattr(model, "_orig_mod") else model
    torch.save(raw.state_dict(), final_dir / "lit_model.pth")

    # Ship the tokenizer next to the final weights when it is available.
    import shutil
    tok_src = Path(cfg["tokenizer_dir"])
    if tok_src.exists():
        shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True)

    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" Done! {total_h:.2f} h -> {final_dir}")
    print(SEP)
604
+
605
+
606
+ # ─── Entry point ──────────────────────────────────────────────────────────────
607
+
608
def _str2bool(value: str) -> bool:
    """Parse a true/false command-line value, rejecting anything unrecognised."""
    text = value.strip().lower()
    if text in ("1", "true", "yes", "y", "on"):
        return True
    if text in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected true/false, got {value!r}")


def parse_args():
    """Parse the command line.

    Every option except ``--config`` defaults to ``None``, which means "not
    provided: keep the value from the config file".

    Returns:
        The parsed ``argparse.Namespace``. An unparseable ``--auto_config``
        value makes argparse exit with a usage error instead of being
        silently read as false.
    """
    p = argparse.ArgumentParser(description="LUNA 100M Trainer")
    p.add_argument("--config", type=str, default="train_config.yaml",
                   help="Path to train_config.yaml")
    # CLI overrides (all optional - omit to use config value)
    p.add_argument("--data_path", type=str, default=None)
    p.add_argument("--out_dir", type=str, default=None)
    p.add_argument("--max_tokens", type=int, default=None)
    p.add_argument("--micro_batch", type=int, default=None)
    p.add_argument("--global_batch", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--num_workers", type=int, default=None)
    p.add_argument("--save_interval", type=int, default=None)
    p.add_argument("--log_interval", type=int, default=None)
    p.add_argument("--auto_config", type=_str2bool,
                   default=None, help="Override auto_config (true/false)")
    return p.parse_args()
625
+
626
+
627
if __name__ == "__main__":
    # Pipeline: CLI -> YAML config -> CLI overrides -> hardware tuning -> train.
    cli = parse_args()
    settings = apply_cli_overrides(load_config(cli.config), cli)
    hardware = probe_hardware()
    train(resolve_auto(settings, hardware))