CLIWorks commited on
Commit
ca96662
·
verified ·
1 Parent(s): e86c5bf

Upload mythos-fineweb-moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mythos-fineweb-moe.py +1262 -0
mythos-fineweb-moe.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
+
5
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
6
+ - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
7
+ - Engram conditional memory at recurrent layers 1 and 4
8
+ - Dense FFN (all params active, MoE conversion in Phase 2)
9
+ - LTI Injection + ACT Halting + LoRA Adapter
10
+ - 32k context (extendable to 256k at inference via YaRN)
11
+
12
+ Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
13
+
14
+ Single GPU:
15
+ python mythos-fineweb-dense.py
16
+
17
+ Multi-GPU:
18
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
19
+ """
20
+ import os
21
+ import math
22
+ import time
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import torch.distributed as dist
27
+ from loguru import logger
28
+ import sys
29
+
30
+ # Configure loguru to file + stderr
31
+ LOG_FILE = "train_spiderportal.log"
32
+ logger.remove()
33
+ logger.add(sys.stderr, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
34
+ logger.add(LOG_FILE, rotation="100 MB", retention="10 days", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
35
+
36
+ # Speed up CUDA memory allocation
37
+ import torch
38
+ torch.cuda.empty_cache()
39
+
40
+ # Numba CPU fallback
41
+ from triton_kernels import (
42
+ numba_dispatch,
43
+ NUMBA_AVAILABLE as _NUMBA_OK,
44
+ )
45
+ from numba_cuda_kernels import (
46
+ cuda_engram_hash,
47
+ cuda_engram_gate,
48
+ cuda_act_halting,
49
+ cuda_engram_conv1d,
50
+ cuda_available as _CUDA_OK,
51
+ )
52
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
53
+ from torch.distributed.fsdp import (
54
+ FullyShardedDataParallel as FSDP,
55
+ ShardingStrategy,
56
+ MixedPrecision,
57
+ FullStateDictConfig,
58
+ StateDictType,
59
+ )
60
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
61
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
62
+ from contextlib import nullcontext
63
+ from dataclasses import dataclass, field
64
+ from typing import Optional, Tuple, Dict, List
65
+ from torch.nn import CrossEntropyLoss
66
+ from datasets import load_dataset
67
+ from transformers import AutoTokenizer
68
+
69
+ # MoE imports
70
+ from torchtitan.models.moe.moe import build_moe, MoEArgs, MoE, FeedForward
71
+
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # SpiderPortal Model Architecture (Dense + MLA + Engram)
75
+ # ---------------------------------------------------------------------------
76
+
77
+ @dataclass
78
+ class SpiderPortalConfig:
79
+ vocab_size: int = 50257
80
+ hidden_size: int = 2048
81
+ num_hidden_layers: int = 6
82
+ num_attention_heads: int = 16
83
+ num_key_value_heads: int = 4
84
+ intermediate_size: int = 8192
85
+ hidden_act: str = "silu"
86
+ num_experts: int = 32
87
+ num_experts_per_tok: int = 2
88
+ num_shared_experts: int = 1
89
+ router_aux_loss_coef: float = 0.05
90
+ max_loop_iters: int = 2
91
+ act_threshold: float = 0.5
92
+ max_position_embeddings: int = 32768
93
+ rope_theta: float = 10000000.0
94
+ rope_scaling: dict = None
95
+ sliding_window: int = 4096
96
+ attention_dropout: float = 0.0
97
+ rms_norm_eps: float = 1e-6
98
+ initializer_range: float = 0.02
99
+ use_cache: bool = True
100
+ tie_word_embeddings: bool = True
101
+ prelude_layers: int = 2
102
+ coda_layers: int = 2
103
+ lora_rank: int = 128
104
+ loop_embed_dim: int = 128
105
+ vision_hidden_size: int = 2048
106
+ audio_hidden_size: int = 512
107
+ vision_num_frames: int = 60
108
+ vision_tokens_per_frame: int = 256
109
+ vision_temporal_tokens: int = 64
110
+ vision_temporal_layers: int = 2
111
+ model_type: str = "spiderportal"
112
+ torch_dtype: str = "bfloat16"
113
+
114
+ # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
115
+ kv_lora_rank: int = 128
116
+ q_lora_rank: int = 256
117
+ qk_rope_head_dim: int = 64
118
+ qk_nope_head_dim: int = 64
119
+ v_head_dim: int = 64
120
+
121
+ # Engram parameters (DeepSeek conditional memory)
122
+ engram_layers: List[int] = field(default_factory=lambda: [1, 4])
123
+ engram_ngram_orders: Tuple[int, ...] = (2, 3)
124
+ engram_hash_heads: int = 4
125
+ engram_table_size: int = 65537 # prime number for hash table
126
+ engram_conv_kernel: int = 4
127
+ engram_conv_dilation: int = 3
128
+ engram_dim: int = 128 # per-head embedding dimension
129
+
130
+
131
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
132
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
133
+ angles = loop_t * freqs
134
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
135
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
136
+ emb_full[:loop_dim] = emb
137
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
138
+
139
+
140
+ class SpiderPortalRMSNorm(nn.Module):
141
+ def __init__(self, hidden_size, eps=1e-6):
142
+ super().__init__()
143
+ self.weight = nn.Parameter(torch.ones(hidden_size))
144
+ self.variance_epsilon = eps
145
+ def forward(self, hidden_states):
146
+ # bf16-only RMSNorm: no dtype conversions inside forward.
147
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
148
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
149
+ return self.weight * hidden_states
150
+
151
+
152
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
153
+ dim = head_dim
154
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
155
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
156
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
157
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
158
+ return orig_inv_freq * scale
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
163
+ # ---------------------------------------------------------------------------
164
+
165
+ class SpiderPortalMLA(nn.Module):
166
+ """Multi-Latent Attention with compressed KV cache and sliding window.
167
+
168
+ For hidden_size=2048, num_heads=16:
169
+ - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
170
+ - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
171
+ - v_head_dim=64 → value projection
172
+ - sliding_window=4096 → local attention range
173
+ """
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.config = config
177
+ self.hidden_size = config.hidden_size
178
+ self.num_heads = config.num_attention_heads
179
+ self.kv_lora_rank = config.kv_lora_rank
180
+ self.q_lora_rank = config.q_lora_rank
181
+ self.qk_rope_head_dim = config.qk_rope_head_dim
182
+ self.qk_nope_head_dim = config.qk_nope_head_dim
183
+ self.v_head_dim = config.v_head_dim
184
+ self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
185
+ self.sliding_window = getattr(config, 'sliding_window', None)
186
+
187
+ # Q projection: optional low-rank → full Q
188
+ if self.q_lora_rank > 0:
189
+ self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
190
+ self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
191
+ self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
192
+ else:
193
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
194
+
195
+ # KV compression: hidden → kv_lora_rank (shared latent)
196
+ self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
197
+ self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
198
+ # Decompress: kv_lora_rank → nope heads + v heads
199
+ self.kv_b_proj = nn.Linear(
200
+ self.kv_lora_rank,
201
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
202
+ bias=False,
203
+ )
204
+ # Output projection
205
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
206
+
207
+ # RoPE frequencies
208
+ rope_scaling = getattr(config, 'rope_scaling', None)
209
+ if rope_scaling and rope_scaling.get("type") == "yarn":
210
+ factor = rope_scaling.get("factor", 1.0)
211
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
212
+ inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
213
+ else:
214
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
215
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
216
+
217
+ def _rotate_half(self, x):
218
+ x1 = x[..., :x.shape[-1] // 2]
219
+ x2 = x[..., x.shape[-1] // 2:]
220
+ return torch.cat((-x2, x1), dim=-1)
221
+
222
+ def _apply_rotary(self, x, cos, sin):
223
+ return (x * cos) + (self._rotate_half(x) * sin)
224
+
225
+ def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
226
+ """Unused: sliding_window disabled."""
227
+ return None
228
+
229
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
230
+ bsz, q_len, _ = hidden_states.size()
231
+ # Q projection
232
+ if self.q_lora_rank > 0:
233
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
234
+ else:
235
+ q = self.q_proj(hidden_states)
236
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
237
+ q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
238
+
239
+ # KV: compress to latent, then decompress
240
+ kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
241
+ kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
242
+ kv_latent_norm = self.kv_a_layernorm(kv_latent)
243
+ kv_b_out = self.kv_b_proj(kv_latent_norm)
244
+ k_nope, v = torch.split(kv_b_out, [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim], dim=-1)
245
+
246
+ k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
247
+ v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
248
+ k_rope = k_rope.unsqueeze(1)
249
+
250
+ # RoPE on Q and K rope parts
251
+ if position_ids is None:
252
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
253
+ max_pos = position_ids.max().item() + 1
254
+ seq_len = max(max_pos, q_len)
255
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
256
+ freqs = torch.outer(t, self.inv_freq)
257
+ emb = torch.cat((freqs, freqs), dim=-1)
258
+ cos, sin = emb.cos(), emb.sin()
259
+ cos_full = cos[position_ids].unsqueeze(1)
260
+ sin_full = sin[position_ids].unsqueeze(1)
261
+
262
+ q_rope = self._apply_rotary(q_rope, cos_full, sin_full)
263
+ k_rope = self._apply_rotary(k_rope, cos_full, sin_full)
264
+
265
+ # Assemble full K
266
+ k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
267
+ k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
268
+ q_full = torch.cat([q_nope, q_rope], dim=-1)
269
+
270
+ # KV cache
271
+ if past_key_value is not None:
272
+ k_full = torch.cat([past_key_value[0], k_full], dim=2)
273
+ v = torch.cat([past_key_value[1], v], dim=2)
274
+ past_kv = (k_full, v) if use_cache else None
275
+
276
+ # Attention with SDPA — is_causal=True for flash-attention fast path
277
+ # No 4D causal mask needed; sliding window disabled, so pure causal.
278
+ attn_output = F.scaled_dot_product_attention(
279
+ q_full, k_full, v,
280
+ attn_mask=None,
281
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
282
+ is_causal=True,
283
+ )
284
+ attn_output = attn_output.transpose(1, 2).contiguous()
285
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
286
+ return self.o_proj(attn_output), past_kv
287
+
288
+
289
+ # ---------------------------------------------------------------------------
290
+ # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
291
+ # ---------------------------------------------------------------------------
292
+
293
+ def _tokenizer_compress(token_ids, vocab_size=50257):
294
+ """Simulate NFKC + lowercase canonical ID projection."""
295
+ return token_ids % (vocab_size * 77 // 100)
296
+
297
+
298
+ class SpiderPortalEngram(nn.Module):
299
+ """Conditional memory module via NN-gram lookup.
300
+
301
+ Applied only at specific recurrent layers (config.engram_layers).
302
+ """
303
+ def __init__(self, config):
304
+ super().__init__()
305
+ self.config = config
306
+ self.ngram_orders = list(config.engram_ngram_orders)
307
+ self.num_heads_per_order = config.engram_hash_heads
308
+ self.table_size = config.engram_table_size
309
+ self.d_mem = config.engram_dim
310
+
311
+ self.total_mem_dim = len(self.ngram_orders) * self.num_heads_per_order * self.d_mem
312
+
313
+ # Stacked embedding table with offsets: [orders, heads, table_size, d_mem]
314
+ # This matches the deepseek MultiHeadEmbedding principle.
315
+ self.embed = nn.Parameter(
316
+ torch.randn(len(self.ngram_orders), self.num_heads_per_order, self.table_size, self.d_mem) * 0.02
317
+ )
318
+
319
+ # Seeds per (order, head) in a stable head_counter ordering.
320
+ seeds = []
321
+ for _order in self.ngram_orders:
322
+ for h in range(self.num_heads_per_order):
323
+ seeds.append((h + 1) * 2654435761)
324
+ self.register_buffer("hash_seeds", torch.tensor(seeds, dtype=torch.int64), persistent=False)
325
+
326
+ self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
327
+ self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
328
+
329
+ self.conv = nn.Conv1d(
330
+ config.hidden_size, config.hidden_size,
331
+ kernel_size=config.engram_conv_kernel,
332
+ padding=config.engram_conv_kernel - 1,
333
+ groups=config.hidden_size,
334
+ )
335
+ self.conv_dilation = config.engram_conv_dilation
336
+
337
+ with torch.no_grad():
338
+ self.conv.weight.zero_()
339
+ if self.conv.bias is not None:
340
+ self.conv.bias.zero_()
341
+
342
+ self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
343
+ self.k_norm = SpiderPortalRMSNorm(config.hidden_size)
344
+
345
+ # No caching: required for gradient checkpoint recomputation stability.
346
+ self._fwd_cache = None
347
+
348
+ def _compute_indices(self, compressed_ids, n, head_idx):
349
+ """Vectorized NN-gram hash indices for a single (order, head)."""
350
+ # Kept for backward compatibility; not used in the stacked embedding path.
351
+ bsz, seq_len = compressed_ids.shape
352
+ pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
353
+ padded = torch.cat([pad, compressed_ids], dim=1)
354
+
355
+ indices_list = []
356
+ for i in range(n):
357
+ indices_list.append(padded[:, i:i + seq_len])
358
+ ngrams = torch.stack(indices_list, dim=-1)
359
+
360
+ seed = int(self.hash_seeds[head_idx].item())
361
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
362
+ for i in range(n):
363
+ h_val = h_val * 31 + ngrams[:, :, i]
364
+ h_val = h_val % self.table_size
365
+ h_val = (h_val * seed) % self.table_size
366
+ return h_val
367
+
368
+ def _compute_hash(self, compressed, n, head_counter, bsz, seq_len):
369
+ """Compute n-gram hash indices, with Numba CPU fallback."""
370
+ if not compressed.is_cuda and NUMBA_AVAILABLE:
371
+ import numpy as np
372
+ h_val_np = numba_dispatch(
373
+ "hash_indices",
374
+ compressed.cpu().numpy().astype(np.int64),
375
+ n, self.table_size,
376
+ int(self.hash_seeds[head_counter].item()),
377
+ )
378
+ if h_val_np is not None:
379
+ return torch.from_numpy(h_val_np).to(compressed.device)
380
+
381
+ pad = torch.zeros(bsz, n - 1, dtype=compressed.dtype, device=compressed.device)
382
+ padded = torch.cat([pad, compressed], dim=1)
383
+ ngrams = torch.stack([padded[:, i : i + seq_len] for i in range(n)], dim=-1)
384
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed.device)
385
+ for i in range(n):
386
+ h_val = h_val * 31 + ngrams[:, :, i].to(torch.int64)
387
+ h_val = h_val % self.table_size
388
+ return h_val
389
+
390
+ def _retrieve(self, token_ids):
391
+ """Retrieve memory vectors for a batch of token sequences."""
392
+ bsz, seq_len = token_ids.shape
393
+ compressed = _tokenizer_compress(token_ids)
394
+
395
+ # Use Numba CUDA hash if faster (PyTorch path is default, ~0.2ms per call)
396
+ indices = cuda_engram_hash(
397
+ compressed, self.hash_seeds,
398
+ self.ngram_orders, self.num_heads_per_order, self.table_size,
399
+ )
400
+ if indices is not None:
401
+ all_parts = []
402
+ head_counter = 0
403
+ for order_idx, n in enumerate(self.ngram_orders):
404
+ head_indices = indices[:, :, head_counter:head_counter + self.num_heads_per_order]
405
+ emb_table = self.embed[order_idx]
406
+ idx = head_indices.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.d_mem)
407
+ mem = torch.gather(emb_table.unsqueeze(0).expand(bsz, -1, -1, -1), dim=2, index=idx)
408
+ mem = mem.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads_per_order * self.d_mem)
409
+ all_parts.append(mem)
410
+ head_counter += self.num_heads_per_order
411
+ return torch.cat(all_parts, dim=-1)
412
+
413
+ # PyTorch fallback (CPU or if CUDA kernel unavailable)
414
+ all_parts = []
415
+ head_counter = 0
416
+ for order_idx, n in enumerate(self.ngram_orders):
417
+ h_val = self._compute_hash(compressed, n, head_counter, bsz, seq_len)
418
+ seeds_slice = self.hash_seeds[head_counter : head_counter + self.num_heads_per_order]
419
+ indices_pt = (h_val.unsqueeze(-1) * seeds_slice.view(1, 1, -1)) % self.table_size
420
+ emb_table = self.embed[order_idx]
421
+ idx = indices_pt.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.d_mem)
422
+ mem = torch.gather(emb_table.unsqueeze(0).expand(bsz, -1, -1, -1), dim=2, index=idx)
423
+ mem = mem.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads_per_order * self.d_mem)
424
+ all_parts.append(mem)
425
+ head_counter += self.num_heads_per_order
426
+ return torch.cat(all_parts, dim=-1)
427
+
428
+ def forward(self, hidden_states, token_ids, layer_id: int):
429
+ mem = self._retrieve(token_ids)
430
+
431
+ q = hidden_states
432
+ k = self.W_k(mem)
433
+ v = self.W_v(mem)
434
+ q_norm = self.q_norm(q)
435
+ k_norm = self.k_norm(k)
436
+ alpha = torch.sigmoid(
437
+ (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
438
+ )
439
+ v_gated = alpha * v
440
+ v_gated_t = v_gated.transpose(1, 2)
441
+ conv_out = self.conv(v_gated_t)
442
+ conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
443
+ conv_out = conv_out.transpose(1, 2)
444
+
445
+ y = F.silu(conv_out) + v_gated
446
+ return y
447
+
448
+
449
+ # ---------------------------------------------------------------------------
450
+ # FFN Expert (dense)
451
+ # ---------------------------------------------------------------------------
452
+
453
+ class SpiderPortalExpert(nn.Module):
454
+ def __init__(self, config, intermediate_size=None):
455
+ super().__init__()
456
+ inter_size = intermediate_size or config.intermediate_size
457
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
458
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
459
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
460
+ self.act_fn = nn.SiLU()
461
+ def forward(self, hidden_states):
462
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
463
+
464
+
465
+ # ---------------------------------------------------------------------------
466
+ # Prelude/Coda Dense Layer (uses MLA)
467
+ # ---------------------------------------------------------------------------
468
+
469
+ class SpiderPortalDenseLayer(nn.Module):
470
+ """Prelude/coda dense layer with MLA attention."""
471
+ def __init__(self, config):
472
+ super().__init__()
473
+ self.self_attn = SpiderPortalMLA(config)
474
+ dense_intermediate = config.hidden_size * 4 // 3
475
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
476
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
477
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
478
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
479
+ attn_input = self.input_layernorm(hidden_states)
480
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
481
+ hidden_states = hidden_states + attn_output
482
+ ffn_input = self.post_attention_layernorm(hidden_states)
483
+ ffn_output = self.ffn(ffn_input)
484
+ hidden_states = hidden_states + ffn_output
485
+ return hidden_states, past_kv
486
+
487
+
488
+ # ---------------------------------------------------------------------------
489
+ # Recurrent Dense Layer (uses MLA + optional Engram)
490
+ # ---------------------------------------------------------------------------
491
+
492
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
493
+ """Recurrent layer with MLA attention and optional Engram memory + MoE."""
494
+ def __init__(self, config, layer_idx, has_engram=False):
495
+ super().__init__()
496
+ self.layer_idx = layer_idx
497
+ self.has_engram = has_engram
498
+ self.self_attn = SpiderPortalMLA(config)
499
+ if has_engram:
500
+ self.engram = SpiderPortalEngram(config)
501
+ moe_args = MoEArgs(
502
+ num_experts=config.num_experts,
503
+ num_shared_experts=config.num_shared_experts,
504
+ top_k=config.num_experts_per_tok,
505
+ score_func="sigmoid",
506
+ gate_bias=True,
507
+ route_scale=1.0,
508
+ load_balance_coeff=config.router_aux_loss_coef,
509
+ use_grouped_mm=True,
510
+ )
511
+ self.moe = build_moe(moe_args, dim=config.hidden_size, hidden_dim=config.intermediate_size)
512
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
513
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
514
+ self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None
515
+ def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
516
+ attn_input = self.input_layernorm(hidden_states)
517
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
518
+ hidden_states = hidden_states + attn_output
519
+
520
+ if self.has_engram and token_ids is not None:
521
+ engram_out = self.engram(hidden_states, token_ids, layer_id=self.layer_idx)
522
+ hidden_states = hidden_states + engram_out
523
+ if self.post_engram_layernorm is not None:
524
+ hidden_states = self.post_engram_layernorm(hidden_states)
525
+
526
+ ffn_input = self.post_attention_layernorm(hidden_states)
527
+ z_loss = (self.moe.router.gate(ffn_input.view(-1, ffn_input.size(-1))).logsumexp(dim=-1) ** 2).mean()
528
+ ffn_output = self.moe(ffn_input)
529
+ hidden_states = hidden_states + ffn_output
530
+ return hidden_states, 1e-4 * z_loss, past_kv
531
+
532
+
533
+ # ---------------------------------------------------------------------------
534
+ # LTI Injection, ACT Halting, LoRA Adapter
535
+ # ---------------------------------------------------------------------------
536
+
537
+ class LTIInjection(nn.Module):
538
+ def __init__(self, config):
539
+ super().__init__()
540
+ self.hidden_size = config.hidden_size
541
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
542
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
543
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
544
+ with torch.no_grad():
545
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
546
+ def get_A(self):
547
+ return -torch.exp(self.log_A)
548
+ def forward(self, h_t, e):
549
+ A = self.get_A()
550
+ return A * h_t + self.B(e)
551
+
552
+
553
+ class ACTHalting(nn.Module):
554
+ def __init__(self, config):
555
+ super().__init__()
556
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
557
+ self.threshold = config.act_threshold
558
+ def forward(self, hidden_states):
559
+ return torch.sigmoid(self.halt_predictor(hidden_states))
560
+
561
+
562
+ class LoRAAdapter(nn.Module):
563
+ def __init__(self, config):
564
+ super().__init__()
565
+ rank = config.lora_rank
566
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
567
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
568
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
569
+ with torch.no_grad():
570
+ self.scale.weight.data.zero_()
571
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
572
+ def forward(self, x, loop_t):
573
+ max_t = self.scale.num_embeddings - 1
574
+ t_idx = min(loop_t, max_t)
575
+ s = self.scale(torch.tensor(t_idx, device=x.device))
576
+ down = self.down(x) * s
577
+ return down @ self.B
578
+
579
+
580
+ def checkpoint(func, *args, **kwargs):
581
+ """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
582
+ if torch.is_grad_enabled():
583
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
584
+ return func(*args, **kwargs)
585
+
586
+
587
+ # ---------------------------------------------------------------------------
588
+ # Full Model
589
+ # ---------------------------------------------------------------------------
590
+
591
+ class SpiderPortalDenseModel(nn.Module):
592
+ """Full RDT model with MLA attention + Engram memory at layers 1,4.
593
+
594
+ Architecture:
595
+ 2x Prelude (MLA + dense FFN)
596
+ 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing
597
+ 2x Coda (MLA + dense FFN)
598
+ LTI Injection + ACT Halting + LoRA Adapter
599
+ """
600
+ def __init__(self, config):
601
+ super().__init__()
602
+ self.config = config
603
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
604
+ self.recurrent_layers = nn.ModuleList([
605
+ SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
606
+ for i in range(config.num_hidden_layers)
607
+ ])
608
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
609
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
610
+ self.injection = LTIInjection(config)
611
+ self.act_halting = ACTHalting(config)
612
+ self.lora_adapter = LoRAAdapter(config)
613
+ self.loop_embed_dim = config.loop_embed_dim
614
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
615
+ n_loops = n_loops or 1
616
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
617
+ for layer in self.prelude_layers:
618
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
619
+ e = hidden_states.clone()
620
+ B, T_seq, D = hidden_states.shape
621
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
622
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
623
+ h_out = torch.zeros_like(hidden_states)
624
+ total_aux_loss = 0.0
625
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
626
+ for t in range(n_loops):
627
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
628
+ if t > 0:
629
+ injection = self.injection(hidden_states, input_embedding)
630
+ hidden_states = hidden_states + injection
631
+ new_past_key_values = []
632
+ for i, layer in enumerate(self.recurrent_layers):
633
+ hidden_states, aux_loss, past_kv = checkpoint(
634
+ layer, hidden_states,
635
+ token_ids=token_ids,
636
+ attention_mask=attention_mask,
637
+ position_ids=position_ids,
638
+ past_key_value=past_key_values[i] if t == 0 else None,
639
+ use_cache=use_cache
640
+ )
641
+ total_aux_loss = total_aux_loss + aux_loss
642
+ new_past_key_values.append(past_kv)
643
+ lora_delta = self.lora_adapter(hidden_states, t)
644
+ hidden_states = hidden_states + lora_delta
645
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
646
+ still_running = ~halted
647
+ remainder = (1.0 - cumulative_p).clamp(min=0)
648
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
649
+ weight = weight * still_running.to(hidden_states.dtype)
650
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
651
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
652
+ halted = halted | (cumulative_p >= self.config.act_threshold)
653
+ if halted.all() and not self.training:
654
+ break
655
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
656
+ hidden_states = h_out + never_halted * hidden_states
657
+ for layer in self.coda_layers:
658
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
659
+ hidden_states = self.norm(hidden_states)
660
+ return hidden_states, total_aux_loss, new_past_key_values
661
+
662
+
663
+ class SpiderPortalForConditionalGeneration(nn.Module):
664
+ def __init__(self, config):
665
+ super().__init__()
666
+ self.config = config
667
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
668
+ self.model = SpiderPortalDenseModel(config)
669
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
670
+ if config.tie_word_embeddings:
671
+ self.lm_head.weight = self.embed_tokens.weight
672
+ self.apply(self._init_weights)
673
+ def _init_weights(self, module):
674
+ if isinstance(module, nn.Linear):
675
+ if hasattr(self, 'model') and module is self.model.injection.B:
676
+ return
677
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
678
+ if module.bias is not None:
679
+ module.bias.data.zero_()
680
+ elif isinstance(module, nn.Embedding):
681
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
682
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
683
+ hidden_states = self.embed_tokens(input_ids)
684
+ model_dtype = next(self.model.parameters()).dtype
685
+ hidden_states = hidden_states.to(model_dtype)
686
+ input_embedding = hidden_states.clone()
687
+ hidden_states, aux_loss, past_kv = self.model(
688
+ hidden_states, input_embedding=input_embedding,
689
+ attention_mask=None, position_ids=position_ids,
690
+ use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
691
+ )
692
+ logits = self.lm_head(hidden_states)
693
+ loss = None
694
+ if labels is not None:
695
+ shift_logits = logits[..., :-1, :].contiguous()
696
+ shift_labels = labels[..., 1:].contiguous()
697
+ loss_fct = CrossEntropyLoss()
698
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
699
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
700
+ def get_num_params(self):
701
+ total = sum(p.numel() for p in self.parameters())
702
+ return {"total": total, "trainable": total}
703
+
704
+
705
+ # ---------------------------------------------------------------------------
706
+ # Dataset
707
+ # ---------------------------------------------------------------------------
708
+
709
+ class FineWebEduDataset(IterableDataset):
710
+ def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
711
+ self.tokenizer = tokenizer
712
+ self.seq_len = seq_len
713
+ self.subset = subset
714
+ self.rank = rank
715
+ self.world_size = world_size
716
+
717
+ # Local tokenized data - USE mmapped binary for speed
718
+ LOCAL_TOKEN_FILE = "/data/fineweb_tokenized/train_tokens.bin"
719
+
720
+ if os.path.exists(LOCAL_TOKEN_FILE):
721
+ # Use memory-mapped file for fast I/O
722
+ import numpy as np
723
+ self.use_local = True
724
+ self.local_file = LOCAL_TOKEN_FILE
725
+ # Memory map for zero-copy reading
726
+ self.mm = np.memmap(LOCAL_TOKEN_FILE, dtype='<u4', mode='r')
727
+ self.num_tokens = len(self.mm)
728
+ self.num_samples = self.num_tokens // seq_len
729
+ else:
730
+ self.use_local = False
731
+
732
+ def __iter__(self):
733
+ if self.use_local:
734
+ # Fast: use memory-mapped array
735
+ worker = get_worker_info()
736
+ num_workers = worker.num_workers if worker else 1
737
+ worker_id = worker.id if worker else 0
738
+
739
+ samples_per_worker = self.num_samples // (self.world_size * num_workers)
740
+ start_sample = (self.rank * num_workers + worker_id) * samples_per_worker
741
+ end_sample = start_sample + samples_per_worker
742
+
743
+ # Batch read tokens - convert to numpy array slice then tensor
744
+ import numpy as np
745
+ for i in range(start_sample, end_sample):
746
+ start_idx = i * self.seq_len
747
+ # Direct slice from memory-mapped array (avoid extra copies when possible)
748
+ tokens = self.mm[start_idx : start_idx + self.seq_len + 1]
749
+ x_np = tokens[:-1].astype("int64", copy=False)
750
+ y_np = tokens[1:].astype("int64", copy=False)
751
+ yield torch.from_numpy(x_np), torch.from_numpy(y_np)
752
+ else:
753
+ # Fallback to HuggingFace
754
+ worker = get_worker_info()
755
+ num_workers = worker.num_workers if worker else 1
756
+ worker_id = worker.id if worker else 0
757
+ total_shards = self.world_size * num_workers
758
+ shard_index = self.rank * num_workers + worker_id
759
+ ds = load_dataset(
760
+ "HuggingFaceFW/fineweb-edu",
761
+ name=self.subset,
762
+ split="train",
763
+ streaming=True,
764
+ ).shard(num_shards=total_shards, index=shard_index)
765
+ buf = []
766
+ for sample in ds:
767
+ buf.extend(self.tokenizer.encode(sample["text"]))
768
+ while len(buf) >= self.seq_len + 1:
769
+ chunk = buf[: self.seq_len + 1]
770
+ buf = buf[self.seq_len + 1 :]
771
+ yield (
772
+ torch.tensor(chunk[:-1], dtype=torch.long),
773
+ torch.tensor(chunk[1:], dtype=torch.long),
774
+ )
775
+
776
+
777
+ # ---------------------------------------------------------------------------
778
+ # LR schedule
779
+ # ---------------------------------------------------------------------------
780
+
781
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
782
+ if step < warmup:
783
+ return max_lr * step / warmup
784
+ if step >= total:
785
+ return min_lr
786
+ decay = (step - warmup) / (total - warmup)
787
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
788
+
789
+
790
+ # ---------------------------------------------------------------------------
791
+ # Checkpointing
792
+ # ---------------------------------------------------------------------------
793
+
794
+ def save_weights_only(model, step, epoch, ckpt_dir, ddp):
795
+ if ddp:
796
+ with FSDP.state_dict_type(
797
+ model,
798
+ StateDictType.FULL_STATE_DICT,
799
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
800
+ ):
801
+ model_state = model.state_dict()
802
+ else:
803
+ model_state = model.state_dict()
804
+ ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-moe-ep{epoch}-step{step}.pt")
805
+ tmp_path = ckpt_path + ".tmp"
806
+ torch.save(model_state, tmp_path)
807
+ os.replace(tmp_path, ckpt_path)
808
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
809
+ return ckpt_path, size_mb
810
+
811
+
812
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
813
+ if ddp:
814
+ with FSDP.state_dict_type(
815
+ model,
816
+ StateDictType.FULL_STATE_DICT,
817
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
818
+ ):
819
+ model_state = model.state_dict()
820
+ optim_state = FSDP.optim_state_dict(model, optimizer)
821
+ else:
822
+ model_state = model.state_dict()
823
+ optim_state = optimizer.state_dict()
824
+ if not master:
825
+ return None, 0
826
+ os.makedirs(ckpt_dir, exist_ok=True)
827
+ final_path = os.path.join(ckpt_dir, f"spiderportal-v5-moe-{ckpt_name}.pt")
828
+ tmp_path = final_path + ".tmp"
829
+ torch.save(
830
+ {
831
+ "step": step,
832
+ "epoch": epoch,
833
+ "model_state_dict": model_state,
834
+ "optimizer_state_dict": optim_state,
835
+ "cfg": cfg,
836
+ "vocab_size": vocab_size,
837
+ },
838
+ tmp_path,
839
+ )
840
+ os.replace(tmp_path, final_path)
841
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
842
+ return final_path, size_mb
843
+
844
+
845
+ def load_checkpoint(model, optimizer, path, ddp):
846
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
847
+ if ddp:
848
+ with FSDP.state_dict_type(
849
+ model,
850
+ StateDictType.FULL_STATE_DICT,
851
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
852
+ ):
853
+ model.load_state_dict(ckpt["model_state_dict"])
854
+ optim_state = FSDP.optim_state_dict_to_load(
855
+ model=model,
856
+ optim=optimizer,
857
+ optim_state_dict=ckpt["optimizer_state_dict"],
858
+ )
859
+ optimizer.load_state_dict(optim_state)
860
+ else:
861
+ model.load_state_dict(ckpt["model_state_dict"])
862
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
863
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0))
864
+
865
+
866
+ # ---------------------------------------------------------------------------
867
+ # Dense → MoE checkpoint converter
868
+ # ---------------------------------------------------------------------------
869
+
870
+ def convert_dense_to_moe(model, dense_ckpt_path, device):
871
+ """Load dense checkpoint and copy FFN weights into MoE experts + shared expert."""
872
+ ckpt = torch.load(dense_ckpt_path, map_location="cpu", weights_only=False)
873
+ if "model_state_dict" in ckpt:
874
+ dense_sd = ckpt["model_state_dict"]
875
+ else:
876
+ dense_sd = ckpt
877
+
878
+ moe_sd = model.state_dict()
879
+
880
+ # For each recurrent layer, expand dense FFN → 32 experts + shared expert
881
+ for layer_idx in range(model.config.num_hidden_layers):
882
+ prefix = f"model.recurrent_layers.{layer_idx}"
883
+
884
+ # Dense FFN weight keys
885
+ gate_key = f"{prefix}.ffn.gate_proj.weight"
886
+ up_key = f"{prefix}.ffn.up_proj.weight"
887
+ down_key = f"{prefix}.ffn.down_proj.weight"
888
+
889
+ if gate_key not in dense_sd:
890
+ continue
891
+
892
+ gate_w = dense_sd[gate_key] # [4096, 2048] (out_features, in_features)
893
+ up_w = dense_sd[up_key] # [4096, 2048]
894
+ down_w = dense_sd[down_key] # [2048, 4096]
895
+
896
+ # Copy into GroupedExperts: w1 [32, 4096, 2048], w2 [32, 2048, 4096], w3 [32, 4096, 2048]
897
+ for w_name, dense_w in [("w1", gate_w), ("w3", up_w), ("w2", down_w)]:
898
+ expert_key = f"{prefix}.moe.experts.{w_name}"
899
+ if expert_key in moe_sd:
900
+ # Expand dense weight across all 32 experts
901
+ expanded = dense_w.unsqueeze(0).expand(moe_sd[expert_key].shape[0], -1, -1).contiguous()
902
+ moe_sd[expert_key].copy_(expanded.to(moe_sd[expert_key].dtype))
903
+
904
+ # Copy into shared expert FeedForward
905
+ for w_name, dense_w in [("w1", gate_w), ("w3", up_w), ("w2", down_w)]:
906
+ shared_key = f"{prefix}.moe.shared_experts.{w_name}.weight"
907
+ if shared_key in moe_sd:
908
+ moe_sd[shared_key].copy_(dense_w.to(moe_sd[shared_key].dtype))
909
+
910
+ # Load the converted state dict back
911
+ model.load_state_dict(moe_sd, strict=False)
912
+
913
+ # Initialize router gates with small normal noise (break symmetry)
914
+ for name, module in model.named_modules():
915
+ if hasattr(module, 'router') and hasattr(module.router, 'gate'):
916
+ nn.init.normal_(module.router.gate.weight, mean=0.0, std=0.02)
917
+ if module.router.gate.bias is not None:
918
+ nn.init.zeros_(module.router.gate.bias)
919
+
920
+ return model
921
+
922
+
923
+ # ---------------------------------------------------------------------------
924
+ # Main
925
+ # ---------------------------------------------------------------------------
926
+
927
+ def main():
928
+ # ------------------------------------------------------------------
929
+ # Distributed init
930
+ # ------------------------------------------------------------------
931
+ ddp = int(os.environ.get("RANK", -1)) != -1
932
+ if ddp:
933
+ dist.init_process_group("nccl")
934
+ rank = int(os.environ["RANK"])
935
+ local_rank = int(os.environ["LOCAL_RANK"])
936
+ world_size = int(os.environ["WORLD_SIZE"])
937
+ device = f"cuda:{local_rank}"
938
+ torch.cuda.set_device(device)
939
+ else:
940
+ rank = local_rank = 0
941
+ world_size = 1
942
+ device = "cuda" if torch.cuda.is_available() else "cpu"
943
+ master = rank == 0
944
+
945
+ # ------------------------------------------------------------------
946
+ # Tokenizer
947
+ # ------------------------------------------------------------------
948
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
949
+ tokenizer.pad_token = tokenizer.eos_token
950
+ vocab_size = tokenizer.vocab_size
951
+ if master:
952
+ logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
953
+
954
+ # ------------------------------------------------------------------
955
+ # Hyperparameters
956
+ # ------------------------------------------------------------------
957
+ seq_len = int(os.environ.get("SEQ_LEN", "2048"))
958
+ micro_batch = int(os.environ.get("MICRO_BATCH", "32"))
959
+ target_tokens = int(os.environ.get("TARGET_TOKENS", "50_000_000"))
960
+ grad_accum = int(os.environ.get("GRAD_ACCUM", "1"))
961
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
962
+ total_steps = target_tokens // global_batch_tok
963
+ warmup_steps = 200
964
+ lr = 3e-4
965
+ wd = 0.1
966
+ log_every = 10
967
+ ckpt_every = int(os.environ.get("CKPT_EVERY", "500"))
968
+ ckpt_dir = "checkpoints-moe"
969
+ dataset_subset = "sample-10BT"
970
+ dense_ckpt = os.environ.get("DENSE_CKPT", "")
971
+
972
+ if master:
973
+ logger.info(
974
+ f"[MOE MLA+Engram] hidden=2048 | layers=6 | experts=32 | top-2 | "
975
+ f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
976
+ f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
977
+ )
978
+ logger.info(
979
+ "Attention: MLA (sliding_window disabled) | "
980
+ "Engram: layers [1,4] | Context: 32k | "
981
+ "Gradient checkpointing: enabled"
982
+ )
983
+
984
+ # ------------------------------------------------------------------
985
+ # Model
986
+ # ------------------------------------------------------------------
987
+ cfg = SpiderPortalConfig(
988
+ hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
989
+ num_key_value_heads=4, intermediate_size=4096,
990
+ num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
991
+ router_aux_loss_coef=0.05, max_loop_iters=2,
992
+ prelude_layers=2, coda_layers=2, lora_rank=128,
993
+ rope_theta=10000000.0,
994
+ rope_scaling=None,
995
+ max_position_embeddings=32768,
996
+ sliding_window=0,
997
+ tie_word_embeddings=True,
998
+ kv_lora_rank=128, q_lora_rank=256,
999
+ qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
1000
+ engram_layers=[1, 4],
1001
+ engram_ngram_orders=(2, 3),
1002
+ engram_hash_heads=4,
1003
+ engram_table_size=65537,
1004
+ engram_dim=128,
1005
+ )
1006
+ cfg.vocab_size = vocab_size
1007
+
1008
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
1009
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
1010
+
1011
+ model = SpiderPortalForConditionalGeneration(cfg).to(torch.bfloat16)
1012
+
1013
+ if ddp:
1014
+ mp_policy = MixedPrecision(
1015
+ param_dtype=amp_dtype,
1016
+ reduce_dtype=amp_dtype,
1017
+ buffer_dtype=amp_dtype,
1018
+ )
1019
+ wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
1020
+ model = FSDP(
1021
+ model,
1022
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
1023
+ mixed_precision=mp_policy,
1024
+ auto_wrap_policy=wrap_policy,
1025
+ device_id=local_rank,
1026
+ )
1027
+
1028
+ else:
1029
+ model = model.to(device)
1030
+
1031
+ if master:
1032
+ logger.info("MoE mode: using native bf16 (MXFP8 disabled)")
1033
+
1034
+ # Dense → MoE checkpoint conversion
1035
+ if dense_ckpt and os.path.exists(dense_ckpt):
1036
+ if master:
1037
+ logger.info(f"Loading dense checkpoint and converting to MoE: {dense_ckpt}")
1038
+ model = convert_dense_to_moe(model, dense_ckpt, device)
1039
+ if master:
1040
+ logger.success("Dense → MoE conversion complete")
1041
+ elif master:
1042
+ logger.info("MoE from scratch: no dense checkpoint loaded")
1043
+
1044
+ # Triton compilation for MoE modules
1045
+ TRITON_COMPILE = os.environ.get("TRITON_COMPILE", "0") == "1"
1046
+ if TRITON_COMPILE:
1047
+ if master:
1048
+ logger.info("Applying torch.compile to MoE modules (default mode)")
1049
+ for layer in model.model.recurrent_layers:
1050
+ if hasattr(layer, 'moe'):
1051
+ layer.moe = torch.compile(layer.moe, dynamic=True)
1052
+ if master:
1053
+ logger.success("torch.compile applied to all MoE layers")
1054
+
1055
+ # Optimizer states will be bf16 (matching param dtype) via foreach=True — saves ~21GB vs fp32
1056
+ if master:
1057
+ logger.info("Optimizer: AdamW(foreach=True, bf16 states) — saves ~21GB VRAM over fp32")
1058
+
1059
+ amp_ctx = (
1060
+ torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
1061
+ if "cuda" in device
1062
+ else nullcontext()
1063
+ )
1064
+
1065
+ amp_ctx = nullcontext() if ddp else amp_ctx
1066
+
1067
+ # Enable SDPA best kernels when available.
1068
+ try:
1069
+ from torch.nn.attention import sdpa_kernel
1070
+
1071
+ sdpa_ctx = sdpa_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True)
1072
+ except Exception:
1073
+ sdpa_ctx = nullcontext()
1074
+
1075
+ if master:
1076
+ n_params = sum(p.numel() for p in model.parameters())
1077
+ engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
1078
+ mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
1079
+ embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
1080
+ moe_params = sum(p.numel() for n, p in model.named_parameters() if 'moe' in n)
1081
+ router_params = sum(p.numel() for n, p in model.named_parameters() if 'router' in n)
1082
+ other_params = n_params - engram_params - mla_params - embed_params - moe_params
1083
+ logger.info(
1084
+ f"Parameters: {n_params:,} | "
1085
+ f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
1086
+ f"MoE: {moe_params:,} | Router: {router_params:,} | "
1087
+ f"Engram: {engram_params:,} | Other: {other_params:,} | AMP dtype: {amp_dtype}"
1088
+ )
1089
+
1090
+ # ------------------------------------------------------------------
1091
+ # Optimizer — dual optimizer for Engram embeddings
1092
+ # ------------------------------------------------------------------
1093
+ engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed' in n and 'proj' not in n]
1094
+ backbone_params = [p for n, p in model.named_parameters() if not ('engram' in n and 'embed' in n and 'proj' not in n)]
1095
+
1096
+ optimizer = torch.optim.AdamW(
1097
+ backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=False, foreach=True, eps=1e-8
1098
+ )
1099
+ if engram_params_list:
1100
+ engram_optimizer = torch.optim.Adam(
1101
+ engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
1102
+ )
1103
+ else:
1104
+ engram_optimizer = None
1105
+
1106
+ # ------------------------------------------------------------------
1107
+ # Resume from latest checkpoint
1108
+ # ------------------------------------------------------------------
1109
+ start_step = 0
1110
+ start_epoch = 1
1111
+ best_loss = float("inf")
1112
+ existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-moe-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
1113
+ if existing_ckpts:
1114
+ latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
1115
+ if master:
1116
+ logger.info(f"Resuming from checkpoint: {latest}")
1117
+ start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
1118
+ if master:
1119
+ logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
1120
+
1121
+ # ------------------------------------------------------------------
1122
+ # Dataset + DataLoader
1123
+ # ------------------------------------------------------------------
1124
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
1125
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True, prefetch_factor=1)
1126
+
1127
+ # ------------------------------------------------------------------
1128
+ # Training loop
1129
+ # ------------------------------------------------------------------
1130
+ if master:
1131
+ os.makedirs(ckpt_dir, exist_ok=True)
1132
+
1133
+ model.train()
1134
+ data_iter = iter(loader)
1135
+ t0 = time.perf_counter()
1136
+ step = start_step
1137
+ epoch = start_epoch
1138
+ step_ckpt_files = []
1139
+ tokens_in_epoch = 0
1140
+ tokens_per_epoch = target_tokens
1141
+
1142
+ # Allow env override for quick debugging.
1143
+ max_steps_override = os.environ.get("MAX_STEPS", None)
1144
+ while step < total_steps:
1145
+ if max_steps_override is not None and step >= int(max_steps_override):
1146
+ break
1147
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
1148
+ for g in optimizer.param_groups:
1149
+ g["lr"] = cur_lr
1150
+ if engram_optimizer:
1151
+ for g in engram_optimizer.param_groups:
1152
+ g["lr"] = cur_lr * 5
1153
+
1154
+ optimizer.zero_grad()
1155
+ if engram_optimizer:
1156
+ engram_optimizer.zero_grad()
1157
+ loss_accum = 0.0
1158
+
1159
+ for micro_step in range(grad_accum):
1160
+ try:
1161
+ x, y = next(data_iter)
1162
+ except StopIteration:
1163
+ data_iter = iter(loader)
1164
+ x, y = next(data_iter)
1165
+
1166
+ x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1167
+ y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
1168
+
1169
+ sync = (
1170
+ nullcontext()
1171
+ if (not ddp or micro_step == grad_accum - 1)
1172
+ else model.no_sync()
1173
+ )
1174
+ with sync, amp_ctx, sdpa_ctx:
1175
+ output = model(x)
1176
+ if master and step == start_step and micro_step == 0:
1177
+ peak_vram = torch.cuda.max_memory_allocated() / 1024**3
1178
+ logger.info(f"Reached first model forward | Peak VRAM: {peak_vram:.1f}GB")
1179
+ if isinstance(output, dict):
1180
+ logits = output["logits"]
1181
+ aux_loss = output.get("aux_loss", 0.0)
1182
+ else:
1183
+ logits = output
1184
+ aux_loss = 0.0
1185
+ loss = nn.functional.cross_entropy(
1186
+ logits.view(-1, vocab_size), y.view(-1)
1187
+ )
1188
+ loss = loss + cfg.router_aux_loss_coef * aux_loss
1189
+ loss = loss / grad_accum
1190
+
1191
+ loss.backward()
1192
+ if master and step == start_step and micro_step == 0:
1193
+ logger.info("Reached first backward")
1194
+ loss_accum += loss.item()
1195
+
1196
+ if ddp:
1197
+ grad_norm = model.clip_grad_norm_(1.0)
1198
+ else:
1199
+ grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
1200
+ optimizer.step()
1201
+ if engram_optimizer:
1202
+ engram_optimizer.step()
1203
+ step += 1
1204
+ tokens_in_epoch += global_batch_tok
1205
+
1206
+ if master and step % log_every == 0:
1207
+ dt = time.perf_counter() - t0
1208
+ tok_per_sec = global_batch_tok * log_every / dt
1209
+ tokens_seen = step * global_batch_tok
1210
+ logger.info(
1211
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
1212
+ f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
1213
+ f"| {tok_per_sec / 1e6:.2f}M tok/s "
1214
+ f"| {tokens_seen / 1e9:.2f}B tokens seen"
1215
+ )
1216
+ t0 = time.perf_counter()
1217
+
1218
+ if step % ckpt_every == 0 and master:
1219
+ ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
1220
+ step_ckpt_files.append(ckpt_path)
1221
+ logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1222
+
1223
+ if tokens_in_epoch >= tokens_per_epoch:
1224
+ epoch_loss = loss_accum
1225
+ if master:
1226
+ epoch_time = (time.perf_counter() - t0) / 60
1227
+ logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
1228
+
1229
+ for f in step_ckpt_files:
1230
+ if os.path.exists(f):
1231
+ os.remove(f)
1232
+ logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
1233
+ step_ckpt_files.clear()
1234
+
1235
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
1236
+ if ckpt_path:
1237
+ logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1238
+
1239
+ if epoch_loss < best_loss:
1240
+ best_loss = epoch_loss
1241
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
1242
+ if ckpt_path:
1243
+ logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1244
+
1245
+ epoch += 1
1246
+ tokens_in_epoch = 0
1247
+
1248
+ if step > start_step and master:
1249
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
1250
+ if ckpt_path:
1251
+ logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1252
+
1253
+ if ddp:
1254
+ dist.barrier()
1255
+ dist.destroy_process_group()
1256
+
1257
+ if master:
1258
+ logger.success("Training complete.")
1259
+
1260
+
1261
+ if __name__ == "__main__":
1262
+ main()