CLIWorks commited on
Commit
e066e25
·
verified ·
1 Parent(s): 2ae4a50

Delete mythos-fineweb-dense.py

Browse files
Files changed (1) hide show
  1. mythos-fineweb-dense.py +0 -791
mythos-fineweb-dense.py DELETED
@@ -1,791 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
-
5
- Optimized dense variant — MoE replaced with single FFN per recurrent layer.
6
- Same RDT architecture (2 prelude + 8 recurrent + 2 coda) but all parameters
7
- active every forward pass. Designed for fast convergence on English.
8
-
9
- Performance optimizations:
10
- - Dense FFN instead of MoE (eliminates Python expert loop)
11
- - torch.compile with reduce-overhead mode
12
- - F.scaled_dot_product_attention (flash attention auto-selected)
13
- - Gradient checkpointing on recurrent layers (saves ~40% VRAM)
14
- - Larger micro_batch (128) with minimal grad_accum (2)
15
-
16
- Single GPU:
17
- python mythos-fineweb-dense.py
18
-
19
- Multi-GPU:
20
- torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
21
-
22
- Dense-to-MoE conversion (after training):
23
- Each recurrent layer's ffn weights are split into 64 chunks to initialize
24
- MoE experts. Attention layers, norms, and loop infrastructure carry over.
25
- """
26
- import os
27
- import math
28
- import time
29
- import torch
30
- import torch.nn as nn
31
- import torch.nn.functional as F
32
- import torch.distributed as dist
33
- from loguru import logger
34
- from torch.distributed.fsdp import (
35
- FullyShardedDataParallel as FSDP,
36
- ShardingStrategy,
37
- MixedPrecision,
38
- FullStateDictConfig,
39
- StateDictType,
40
- )
41
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
42
- from torch.utils.data import IterableDataset, DataLoader, get_worker_info
43
- from contextlib import nullcontext
44
- from dataclasses import dataclass
45
- from typing import Optional, Tuple, Dict, List
46
- from torch.nn import CrossEntropyLoss
47
- from datasets import load_dataset
48
- from transformers import AutoTokenizer
49
-
50
-
51
- # ---------------------------------------------------------------------------
52
- # SpiderPortal Model Architecture (Dense variant)
53
- # ---------------------------------------------------------------------------
54
-
55
- @dataclass
56
- class SpiderPortalConfig:
57
- vocab_size: int = 50257
58
- hidden_size: int = 384
59
- num_hidden_layers: int = 8
60
- num_attention_heads: int = 8
61
- num_key_value_heads: int = 2
62
- intermediate_size: int = 1024
63
- hidden_act: str = "silu"
64
- num_experts: int = 64
65
- num_experts_per_tok: int = 1
66
- num_shared_experts: int = 1
67
- router_aux_loss_coef: float = 0.05
68
- max_loop_iters: int = 1
69
- act_threshold: float = 0.5
70
- max_position_embeddings: int = 131072
71
- rope_theta: float = 10000000.0
72
- rope_scaling: dict = None
73
- sliding_window: int = 4096
74
- attention_dropout: float = 0.0
75
- rms_norm_eps: float = 1e-6
76
- initializer_range: float = 0.02
77
- use_cache: bool = True
78
- tie_word_embeddings: bool = True
79
- prelude_layers: int = 2
80
- coda_layers: int = 2
81
- lora_rank: int = 32
82
- loop_embed_dim: int = 48
83
- vision_hidden_size: int = 384
84
- audio_hidden_size: int = 512
85
- vision_num_frames: int = 60
86
- vision_tokens_per_frame: int = 256
87
- vision_temporal_tokens: int = 64
88
- vision_temporal_layers: int = 2
89
- model_type: str = "spiderportal"
90
- torch_dtype: str = "bfloat16"
91
-
92
-
93
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
94
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
95
- angles = loop_t * freqs
96
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
97
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
98
- emb_full[:loop_dim] = emb
99
- return h + emb_full.unsqueeze(0).unsqueeze(0)
100
-
101
-
102
- class SpiderPortalRMSNorm(nn.Module):
103
- def __init__(self, hidden_size, eps=1e-6):
104
- super().__init__()
105
- self.weight = nn.Parameter(torch.ones(hidden_size))
106
- self.variance_epsilon = eps
107
- def forward(self, hidden_states):
108
- input_dtype = hidden_states.dtype
109
- hidden_states = hidden_states.to(torch.float32)
110
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
111
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
112
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
113
-
114
-
115
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
116
- dim = head_dim
117
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
118
- pos_freqs = torch.arange(0, dim, 2).float() / dim
119
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
120
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
121
- return orig_inv_freq * scale
122
-
123
-
124
- class SpiderPortalGQA(nn.Module):
125
- def __init__(self, config):
126
- super().__init__()
127
- self.config = config
128
- self.hidden_size = config.hidden_size
129
- self.num_heads = config.num_attention_heads
130
- self.num_kv_heads = config.num_key_value_heads
131
- self.head_dim = config.hidden_size // config.num_attention_heads
132
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
133
- self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
134
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
135
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
136
- self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
137
- self.attention_dropout = config.attention_dropout
138
- rope_scaling = getattr(config, 'rope_scaling', None)
139
- if rope_scaling and rope_scaling.get("type") == "yarn":
140
- factor = rope_scaling.get("factor", 1.0)
141
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
142
- inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
143
- else:
144
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
145
- self.register_buffer("inv_freq", inv_freq, persistent=False)
146
- def _rotate_half(self, x):
147
- x1 = x[..., :x.shape[-1] // 2]
148
- x2 = x[..., x.shape[-1] // 2:]
149
- return torch.cat((-x2, x1), dim=-1)
150
- def _apply_rotary(self, x, cos, sin):
151
- return (x * cos) + (self._rotate_half(x) * sin)
152
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
153
- bsz, q_len, _ = hidden_states.size()
154
- query_states = self.q_proj(hidden_states)
155
- key_states = self.k_proj(hidden_states)
156
- value_states = self.v_proj(hidden_states)
157
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
158
- key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
159
- value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
160
- if position_ids is None:
161
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
162
- max_pos = position_ids.max().item() + 1
163
- seq_len = max(max_pos, q_len)
164
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
165
- freqs = torch.outer(t, self.inv_freq)
166
- emb = torch.cat((freqs, freqs), dim=-1)
167
- cos, sin = emb.cos(), emb.sin()
168
- cos = cos[position_ids].unsqueeze(1)
169
- sin = sin[position_ids].unsqueeze(1)
170
- query_states = self._apply_rotary(query_states, cos, sin)
171
- key_states = self._apply_rotary(key_states, cos, sin)
172
- if past_key_value is not None:
173
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
174
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
175
- past_kv = (key_states, value_states) if use_cache else None
176
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
177
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
178
- attn_output = F.scaled_dot_product_attention(
179
- query_states, key_states, value_states,
180
- attn_mask=attention_mask,
181
- dropout_p=self.attention_dropout if self.training else 0.0,
182
- is_causal=attention_mask is None
183
- )
184
- attn_output = attn_output.transpose(1, 2).contiguous()
185
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
186
- return self.o_proj(attn_output), past_kv
187
-
188
-
189
- class SpiderPortalExpert(nn.Module):
190
- def __init__(self, config, intermediate_size=None):
191
- super().__init__()
192
- inter_size = intermediate_size or config.intermediate_size
193
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
194
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
195
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
196
- self.act_fn = nn.SiLU()
197
- def forward(self, hidden_states):
198
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
199
-
200
-
201
- class SpiderPortalDenseLayer(nn.Module):
202
- """Prelude/coda dense layer. intermediate_size=512 (4/3 * hidden_size)."""
203
- def __init__(self, config):
204
- super().__init__()
205
- self.self_attn = SpiderPortalGQA(config)
206
- dense_intermediate = config.hidden_size * 4 // 3
207
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
208
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
209
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
210
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
211
- attn_input = self.input_layernorm(hidden_states)
212
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
213
- hidden_states = hidden_states + attn_output
214
- ffn_input = self.post_attention_layernorm(hidden_states)
215
- ffn_output = self.ffn(ffn_input)
216
- hidden_states = hidden_states + ffn_output
217
- return hidden_states, past_kv
218
-
219
-
220
- class SpiderPortalRecurrentDenseLayer(nn.Module):
221
- """Recurrent layer with DENSE FFN (not MoE). intermediate_size=1024.
222
-
223
- This replaces SpiderPortalMoELayer. After dense training converges,
224
- the ffn weights can be split into 64 chunks to initialize MoE experts.
225
- """
226
- def __init__(self, config, layer_idx):
227
- super().__init__()
228
- self.layer_idx = layer_idx
229
- self.self_attn = SpiderPortalGQA(config)
230
- self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
231
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
232
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
233
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
234
- attn_input = self.input_layernorm(hidden_states)
235
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
236
- hidden_states = hidden_states + attn_output
237
- ffn_input = self.post_attention_layernorm(hidden_states)
238
- ffn_output = self.ffn(ffn_input)
239
- hidden_states = hidden_states + ffn_output
240
- return hidden_states, 0.0, past_kv
241
-
242
-
243
- class LTIInjection(nn.Module):
244
- def __init__(self, config):
245
- super().__init__()
246
- self.hidden_size = config.hidden_size
247
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
248
- self.delta_t = nn.Parameter(torch.tensor(1.0))
249
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
250
- with torch.no_grad():
251
- self.B.weight.data.normal_(mean=0.0, std=0.01)
252
- def get_A(self):
253
- return -torch.exp(self.log_A)
254
- def forward(self, h_t, e):
255
- A = self.get_A()
256
- return A * h_t + self.B(e)
257
-
258
-
259
- class ACTHalting(nn.Module):
260
- def __init__(self, config):
261
- super().__init__()
262
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
- self.threshold = config.act_threshold
264
- def forward(self, hidden_states):
265
- return torch.sigmoid(self.halt_predictor(hidden_states))
266
-
267
-
268
- class LoRAAdapter(nn.Module):
269
- def __init__(self, config):
270
- super().__init__()
271
- rank = config.lora_rank
272
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
273
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
274
- self.scale = nn.Embedding(config.max_loop_iters, rank)
275
- with torch.no_grad():
276
- self.scale.weight.data.zero_()
277
- self.down.weight.data.normal_(mean=0.0, std=0.001)
278
- def forward(self, x, loop_t):
279
- max_t = self.scale.num_embeddings - 1
280
- t_idx = min(loop_t, max_t)
281
- s = self.scale(torch.tensor(t_idx, device=x.device))
282
- down = self.down(x) * s
283
- return down @ self.B
284
-
285
-
286
- def checkpoint(func, *args, **kwargs):
287
- """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
288
- if torch.is_grad_enabled():
289
- return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
290
- return func(*args, **kwargs)
291
-
292
-
293
- class SpiderPortalDenseModel(nn.Module):
294
- """Full RDT model with DENSE recurrent layers (no MoE).
295
-
296
- Architecture:
297
- 2x Prelude (dense, intermediate=512)
298
- 8x Recurrent (dense FFN, intermediate=1024) — with gradient checkpointing
299
- 2x Coda (dense, intermediate=512)
300
- LTI Injection + ACT Halting + LoRA Adapter
301
- """
302
- def __init__(self, config):
303
- super().__init__()
304
- self.config = config
305
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
306
- self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
307
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
308
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
309
- self.injection = LTIInjection(config)
310
- self.act_halting = ACTHalting(config)
311
- self.lora_adapter = LoRAAdapter(config)
312
- self.loop_embed_dim = config.loop_embed_dim
313
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
314
- n_loops = n_loops or self.config.max_loop_iters
315
- input_embedding = input_embedding if input_embedding is not None else hidden_states
316
- for layer in self.prelude_layers:
317
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
318
- e = hidden_states.clone()
319
- B, T_seq, D = hidden_states.shape
320
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
321
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
322
- h_out = torch.zeros_like(hidden_states)
323
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
324
- for t in range(n_loops):
325
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
326
- if t > 0:
327
- injection = self.injection(hidden_states, input_embedding)
328
- hidden_states = hidden_states + injection
329
- new_past_key_values = []
330
- for i, layer in enumerate(self.recurrent_layers):
331
- hidden_states, aux_loss, past_kv = checkpoint(
332
- layer, hidden_states,
333
- attention_mask=attention_mask,
334
- position_ids=position_ids,
335
- past_key_value=past_key_values[i] if t == 0 else None,
336
- use_cache=use_cache
337
- )
338
- new_past_key_values.append(past_kv)
339
- lora_delta = self.lora_adapter(hidden_states, t)
340
- hidden_states = hidden_states + lora_delta
341
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
342
- still_running = ~halted
343
- remainder = (1.0 - cumulative_p).clamp(min=0)
344
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
345
- weight = weight * still_running.to(hidden_states.dtype)
346
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
347
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
348
- halted = halted | (cumulative_p >= self.config.act_threshold)
349
- if halted.all() and not self.training:
350
- break
351
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
352
- hidden_states = h_out + never_halted * hidden_states
353
- for layer in self.coda_layers:
354
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
355
- hidden_states = self.norm(hidden_states)
356
- return hidden_states, 0.0, new_past_key_values
357
-
358
-
359
- class SpiderPortalForConditionalGeneration(nn.Module):
360
- def __init__(self, config):
361
- super().__init__()
362
- self.config = config
363
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
364
- self.model = SpiderPortalDenseModel(config)
365
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
366
- if config.tie_word_embeddings:
367
- self.lm_head.weight = self.embed_tokens.weight
368
- self.apply(self._init_weights)
369
- def _init_weights(self, module):
370
- if isinstance(module, nn.Linear):
371
- if hasattr(self, 'model') and module is self.model.injection.B:
372
- return
373
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
374
- if module.bias is not None:
375
- module.bias.data.zero_()
376
- elif isinstance(module, nn.Embedding):
377
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
378
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
379
- hidden_states = self.embed_tokens(input_ids)
380
- model_dtype = next(self.model.parameters()).dtype
381
- hidden_states = hidden_states.to(model_dtype)
382
- input_embedding = hidden_states.clone()
383
- if attention_mask is None:
384
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
385
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
386
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
387
- causal_mask = causal_mask.triu(1)
388
- hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
389
- logits = self.lm_head(hidden_states)
390
- loss = None
391
- if labels is not None:
392
- shift_logits = logits[..., :-1, :].contiguous()
393
- shift_labels = labels[..., 1:].contiguous()
394
- loss_fct = CrossEntropyLoss()
395
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
396
- return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
397
- def get_num_params(self):
398
- total = sum(p.numel() for p in self.parameters())
399
- return {"total": total, "trainable": total}
400
-
401
-
402
- # ---------------------------------------------------------------------------
403
- # Dataset
404
- # ---------------------------------------------------------------------------
405
-
406
- class FineWebEduDataset(IterableDataset):
407
- def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
408
- self.tokenizer = tokenizer
409
- self.seq_len = seq_len
410
- self.subset = subset
411
- self.rank = rank
412
- self.world_size = world_size
413
- def __iter__(self):
414
- worker = get_worker_info()
415
- num_workers = worker.num_workers if worker else 1
416
- worker_id = worker.id if worker else 0
417
- total_shards = self.world_size * num_workers
418
- shard_index = self.rank * num_workers + worker_id
419
- ds = load_dataset(
420
- "HuggingFaceFW/fineweb-edu",
421
- name=self.subset,
422
- split="train",
423
- streaming=True,
424
- ).shard(num_shards=total_shards, index=shard_index)
425
- buf = []
426
- for sample in ds:
427
- buf.extend(self.tokenizer.encode(sample["text"]))
428
- while len(buf) >= self.seq_len + 1:
429
- chunk = buf[: self.seq_len + 1]
430
- buf = buf[self.seq_len + 1 :]
431
- yield (
432
- torch.tensor(chunk[:-1], dtype=torch.long),
433
- torch.tensor(chunk[1:], dtype=torch.long),
434
- )
435
-
436
-
437
- # ---------------------------------------------------------------------------
438
- # LR schedule: linear warmup → cosine decay
439
- # ---------------------------------------------------------------------------
440
-
441
- def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
442
- if step < warmup:
443
- return max_lr * step / warmup
444
- if step >= total:
445
- return min_lr
446
- decay = (step - warmup) / (total - warmup)
447
- return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
448
-
449
-
450
- # ---------------------------------------------------------------------------
451
- # Checkpointing
452
- # ---------------------------------------------------------------------------
453
-
454
- def save_weights_only(model, step, epoch, ckpt_dir, ddp):
455
- if ddp:
456
- with FSDP.state_dict_type(
457
- model,
458
- StateDictType.FULL_STATE_DICT,
459
- FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
460
- ):
461
- model_state = model.state_dict()
462
- else:
463
- model_state = model.state_dict()
464
- ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
465
- tmp_path = ckpt_path + ".tmp"
466
- torch.save(model_state, tmp_path)
467
- os.replace(tmp_path, ckpt_path)
468
- size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
469
- return ckpt_path, size_mb
470
-
471
-
472
- def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
473
- if ddp:
474
- with FSDP.state_dict_type(
475
- model,
476
- StateDictType.FULL_STATE_DICT,
477
- FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
478
- ):
479
- model_state = model.state_dict()
480
- optim_state = FSDP.optim_state_dict(model, optimizer)
481
- else:
482
- model_state = model.state_dict()
483
- optim_state = optimizer.state_dict()
484
- if not master:
485
- return None, 0
486
- os.makedirs(ckpt_dir, exist_ok=True)
487
- final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
488
- tmp_path = final_path + ".tmp"
489
- torch.save(
490
- {
491
- "step": step,
492
- "epoch": epoch,
493
- "model_state_dict": model_state,
494
- "optimizer_state_dict": optim_state,
495
- "cfg": cfg,
496
- "vocab_size": vocab_size,
497
- },
498
- tmp_path,
499
- )
500
- os.replace(tmp_path, final_path)
501
- size_mb = os.path.getsize(final_path) / (1024 * 1024)
502
- return final_path, size_mb
503
-
504
-
505
- def load_checkpoint(model, optimizer, path, ddp):
506
- ckpt = torch.load(path, map_location="cpu", weights_only=False)
507
- if ddp:
508
- with FSDP.state_dict_type(
509
- model,
510
- StateDictType.FULL_STATE_DICT,
511
- FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
512
- ):
513
- model.load_state_dict(ckpt["model_state_dict"])
514
- optim_state = FSDP.optim_state_dict_to_load(
515
- model=model,
516
- optim=optimizer,
517
- optim_state_dict=ckpt["optimizer_state_dict"],
518
- )
519
- optimizer.load_state_dict(optim_state)
520
- else:
521
- model.load_state_dict(ckpt["model_state_dict"])
522
- optimizer.load_state_dict(ckpt["optimizer_state_dict"])
523
- return int(ckpt["step"]), int(ckpt.get("epoch", 0))
524
-
525
-
526
- # ---------------------------------------------------------------------------
527
- # Main
528
- # ---------------------------------------------------------------------------
529
-
530
- def main():
531
- # ------------------------------------------------------------------
532
- # Distributed init
533
- # ------------------------------------------------------------------
534
- ddp = int(os.environ.get("RANK", -1)) != -1
535
- if ddp:
536
- dist.init_process_group("nccl")
537
- rank = int(os.environ["RANK"])
538
- local_rank = int(os.environ["LOCAL_RANK"])
539
- world_size = int(os.environ["WORLD_SIZE"])
540
- device = f"cuda:{local_rank}"
541
- torch.cuda.set_device(device)
542
- else:
543
- rank = local_rank = 0
544
- world_size = 1
545
- device = "cuda" if torch.cuda.is_available() else "cpu"
546
- master = rank == 0
547
- if master:
548
- logger.info(
549
- f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
550
- )
551
-
552
- # ------------------------------------------------------------------
553
- # Tokenizer
554
- # ------------------------------------------------------------------
555
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
556
- tokenizer.pad_token = tokenizer.eos_token
557
- vocab_size = tokenizer.vocab_size
558
- if master:
559
- logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
560
-
561
- # ------------------------------------------------------------------
562
- # Hyperparameters — OPTIMIZED for speed
563
- # ------------------------------------------------------------------
564
- seq_len = 2048
565
- micro_batch = 128 # Increased from 32 — RTX 6000 has 96GB VRAM
566
- target_tokens = 20_000_000_000 # 20B tokens (2 epochs on 10BT)
567
- grad_accum = 2 # Reduced from 4 — fewer backward passes
568
- global_batch_tok = world_size * micro_batch * grad_accum * seq_len
569
- total_steps = target_tokens // global_batch_tok
570
- warmup_steps = 200
571
- lr = 3e-4
572
- wd = 0.1
573
- log_every = 10
574
- ckpt_every = 500
575
- ckpt_dir = "checkpoints-dense"
576
- dataset_subset = "sample-10BT"
577
-
578
- if master:
579
- logger.info(
580
- f"[DENSE OPTIMIZED] seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
581
- f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
582
- )
583
- logger.info(
584
- f"Gradient checkpointing: enabled | torch.compile: enabled | "
585
- f"SDPA attention: enabled"
586
- )
587
-
588
- # ------------------------------------------------------------------
589
- # Model — Dense variant
590
- # ------------------------------------------------------------------
591
- cfg = SpiderPortalConfig(
592
- hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
593
- num_key_value_heads=2, intermediate_size=1024,
594
- num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
595
- router_aux_loss_coef=0.05, max_loop_iters=1,
596
- prelude_layers=2, coda_layers=2, lora_rank=32,
597
- rope_theta=10000000.0,
598
- rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
599
- max_position_embeddings=131072, sliding_window=4096,
600
- tie_word_embeddings=True,
601
- )
602
- cfg.vocab_size = vocab_size
603
-
604
- bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
605
- amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
606
-
607
- model = SpiderPortalForConditionalGeneration(cfg)
608
-
609
- if ddp:
610
- mp_policy = MixedPrecision(
611
- param_dtype=amp_dtype,
612
- reduce_dtype=amp_dtype,
613
- buffer_dtype=amp_dtype,
614
- )
615
- wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
616
- model = FSDP(
617
- model,
618
- sharding_strategy=ShardingStrategy.FULL_SHARD,
619
- mixed_precision=mp_policy,
620
- auto_wrap_policy=wrap_policy,
621
- device_id=local_rank,
622
- )
623
- else:
624
- model = model.to(device)
625
- amp_ctx = (
626
- torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
627
- if "cuda" in device
628
- else nullcontext()
629
- )
630
-
631
- amp_ctx = nullcontext() if ddp else amp_ctx
632
-
633
- if master:
634
- n_params = sum(p.numel() for p in model.parameters())
635
- n_active = n_params # Dense = all params active
636
- logger.info(f"Parameters: {n_params:,} (all active) | AMP dtype: {amp_dtype}")
637
-
638
- # Compile — ACTUAL torch.compile this time
639
- try:
640
- model = torch.compile(model, mode="reduce-overhead")
641
- if master:
642
- logger.info("torch.compile: enabled (reduce-overhead)")
643
- except Exception as e:
644
- if master:
645
- logger.warning(f"torch.compile failed ({e}), using eager mode")
646
-
647
- # ------------------------------------------------------------------
648
- # Optimizer
649
- # ------------------------------------------------------------------
650
- optimizer = torch.optim.AdamW(
651
- model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
652
- )
653
-
654
- # ------------------------------------------------------------------
655
- # Resume from latest checkpoint (if any)
656
- # ------------------------------------------------------------------
657
- start_step = 0
658
- start_epoch = 1
659
- best_loss = float("inf")
660
- existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
661
- if existing_ckpts:
662
- latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
663
- if master:
664
- logger.info(f"Resuming from checkpoint: {latest}")
665
- start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
666
- if master:
667
- logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
668
-
669
- # ------------------------------------------------------------------
670
- # Dataset + DataLoader
671
- # ------------------------------------------------------------------
672
- dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
673
- loader = DataLoader(dataset, batch_size=micro_batch, num_workers=8, pin_memory=True, prefetch_factor=2)
674
-
675
- # ------------------------------------------------------------------
676
- # Training loop
677
- # ------------------------------------------------------------------
678
- if master:
679
- os.makedirs(ckpt_dir, exist_ok=True)
680
-
681
- model.train()
682
- data_iter = iter(loader)
683
- t0 = time.perf_counter()
684
- step = start_step
685
- epoch = start_epoch
686
- step_ckpt_files = []
687
- tokens_in_epoch = 0
688
- tokens_per_epoch = target_tokens
689
-
690
- while step < total_steps:
691
- cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
692
- for g in optimizer.param_groups:
693
- g["lr"] = cur_lr
694
-
695
- optimizer.zero_grad()
696
- loss_accum = 0.0
697
-
698
- for micro_step in range(grad_accum):
699
- try:
700
- x, y = next(data_iter)
701
- except StopIteration:
702
- data_iter = iter(loader)
703
- x, y = next(data_iter)
704
-
705
- x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
706
- y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
707
-
708
- sync = (
709
- nullcontext()
710
- if (not ddp or micro_step == grad_accum - 1)
711
- else model.no_sync()
712
- )
713
- with sync, amp_ctx:
714
- output = model(x)
715
- if isinstance(output, dict):
716
- logits = output["logits"]
717
- else:
718
- logits = output
719
- loss = nn.functional.cross_entropy(
720
- logits.view(-1, vocab_size), y.view(-1)
721
- )
722
- loss = loss / grad_accum
723
-
724
- loss.backward()
725
- loss_accum += loss.item()
726
-
727
- if ddp:
728
- grad_norm = model.clip_grad_norm_(1.0)
729
- else:
730
- grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
731
- optimizer.step()
732
- step += 1
733
- tokens_in_epoch += global_batch_tok
734
-
735
- if master and step % log_every == 0:
736
- dt = time.perf_counter() - t0
737
- tok_per_sec = global_batch_tok * log_every / dt
738
- tokens_seen = step * global_batch_tok
739
- logger.info(
740
- f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
741
- f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
742
- f"| {tok_per_sec / 1e6:.2f}M tok/s "
743
- f"| {tokens_seen / 1e9:.2f}B tokens seen"
744
- )
745
- t0 = time.perf_counter()
746
-
747
- if step % ckpt_every == 0 and master:
748
- ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
749
- step_ckpt_files.append(ckpt_path)
750
- logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
751
-
752
- if tokens_in_epoch >= tokens_per_epoch:
753
- epoch_loss = loss_accum
754
- if master:
755
- epoch_time = (time.perf_counter() - t0) / 60
756
- logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
757
-
758
- for f in step_ckpt_files:
759
- if os.path.exists(f):
760
- os.remove(f)
761
- logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
762
- step_ckpt_files.clear()
763
-
764
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
765
- if ckpt_path:
766
- logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
767
-
768
- if epoch_loss < best_loss:
769
- best_loss = epoch_loss
770
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
771
- if ckpt_path:
772
- logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
773
-
774
- epoch += 1
775
- tokens_in_epoch = 0
776
-
777
- if step > start_step and master:
778
- ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
779
- if ckpt_path:
780
- logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
781
-
782
- if ddp:
783
- dist.barrier()
784
- dist.destroy_process_group()
785
-
786
- if master:
787
- logger.success("Training complete.")
788
-
789
-
790
- if __name__ == "__main__":
791
- main()