CLIWorks commited on
Commit
9f0cc3b
·
verified ·
1 Parent(s): e066e25

Upload mythos-fineweb-dense.py

Browse files
Files changed (1) hide show
  1. mythos-fineweb-dense.py +1060 -0
mythos-fineweb-dense.py ADDED
@@ -0,0 +1,1060 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
+
5
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
6
+ - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
7
+ - Engram conditional memory at recurrent layers 1 and 4
8
+ - Dense FFN (all params active, MoE conversion in Phase 2)
9
+ - LTI Injection + ACT Halting + LoRA Adapter
10
+ - 32k context (extendable to 256k at inference via YaRN)
11
+
12
+ Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing
13
+
14
+ Single GPU:
15
+ python mythos-fineweb-dense.py
16
+
17
+ Multi-GPU:
18
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
19
+ """
20
+ import os
21
+ import math
22
+ import time
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import torch.distributed as dist
27
+ from loguru import logger
28
+ from torch.distributed.fsdp import (
29
+ FullyShardedDataParallel as FSDP,
30
+ ShardingStrategy,
31
+ MixedPrecision,
32
+ FullStateDictConfig,
33
+ StateDictType,
34
+ )
35
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
36
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
37
+ from contextlib import nullcontext
38
+ from dataclasses import dataclass, field
39
+ from typing import Optional, Tuple, Dict, List
40
+ from torch.nn import CrossEntropyLoss
41
+ from datasets import load_dataset
42
+ from transformers import AutoTokenizer
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # SpiderPortal Model Architecture (Dense + MLA + Engram)
47
+ # ---------------------------------------------------------------------------
48
+
49
+ @dataclass
50
+ class SpiderPortalConfig:
51
+ vocab_size: int = 50257
52
+ hidden_size: int = 2048
53
+ num_hidden_layers: int = 6
54
+ num_attention_heads: int = 16
55
+ num_key_value_heads: int = 4
56
+ intermediate_size: int = 8192
57
+ hidden_act: str = "silu"
58
+ num_experts: int = 32
59
+ num_experts_per_tok: int = 2
60
+ num_shared_experts: int = 1
61
+ router_aux_loss_coef: float = 0.05
62
+ max_loop_iters: int = 4
63
+ act_threshold: float = 0.5
64
+ max_position_embeddings: int = 32768
65
+ rope_theta: float = 10000000.0
66
+ rope_scaling: dict = None
67
+ sliding_window: int = 4096
68
+ attention_dropout: float = 0.0
69
+ rms_norm_eps: float = 1e-6
70
+ initializer_range: float = 0.02
71
+ use_cache: bool = True
72
+ tie_word_embeddings: bool = True
73
+ prelude_layers: int = 2
74
+ coda_layers: int = 2
75
+ lora_rank: int = 128
76
+ loop_embed_dim: int = 128
77
+ vision_hidden_size: int = 2048
78
+ audio_hidden_size: int = 512
79
+ vision_num_frames: int = 60
80
+ vision_tokens_per_frame: int = 256
81
+ vision_temporal_tokens: int = 64
82
+ vision_temporal_layers: int = 2
83
+ model_type: str = "spiderportal"
84
+ torch_dtype: str = "bfloat16"
85
+
86
+ # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
87
+ kv_lora_rank: int = 128
88
+ q_lora_rank: int = 256
89
+ qk_rope_head_dim: int = 64
90
+ qk_nope_head_dim: int = 64
91
+ v_head_dim: int = 64
92
+
93
+ # Engram parameters (DeepSeek conditional memory)
94
+ engram_layers: List[int] = field(default_factory=lambda: [1, 4])
95
+ engram_ngram_orders: Tuple[int, ...] = (2, 3)
96
+ engram_hash_heads: int = 4
97
+ engram_table_size: int = 65537 # prime number for hash table
98
+ engram_conv_kernel: int = 4
99
+ engram_conv_dilation: int = 3
100
+ engram_dim: int = 128 # per-head embedding dimension
101
+
102
+
103
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
104
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
105
+ angles = loop_t * freqs
106
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
107
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
108
+ emb_full[:loop_dim] = emb
109
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
110
+
111
+
112
+ class SpiderPortalRMSNorm(nn.Module):
113
+ def __init__(self, hidden_size, eps=1e-6):
114
+ super().__init__()
115
+ self.weight = nn.Parameter(torch.ones(hidden_size))
116
+ self.variance_epsilon = eps
117
+ def forward(self, hidden_states):
118
+ input_dtype = hidden_states.dtype
119
+ hidden_states = hidden_states.to(torch.float32)
120
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
121
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
123
+
124
+
125
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
126
+ dim = head_dim
127
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
128
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
129
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
130
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
131
+ return orig_inv_freq * scale
132
+
133
+
134
+ # ---------------------------------------------------------------------------
135
+ # MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
136
+ # ---------------------------------------------------------------------------
137
+
138
+ class SpiderPortalMLA(nn.Module):
139
+ """Multi-Latent Attention with compressed KV cache and sliding window.
140
+
141
+ For hidden_size=2048, num_heads=16:
142
+ - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
143
+ - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
144
+ - v_head_dim=64 → value projection
145
+ - sliding_window=4096 → local attention range
146
+ """
147
+ def __init__(self, config):
148
+ super().__init__()
149
+ self.config = config
150
+ self.hidden_size = config.hidden_size
151
+ self.num_heads = config.num_attention_heads
152
+ self.kv_lora_rank = config.kv_lora_rank
153
+ self.q_lora_rank = config.q_lora_rank
154
+ self.qk_rope_head_dim = config.qk_rope_head_dim
155
+ self.qk_nope_head_dim = config.qk_nope_head_dim
156
+ self.v_head_dim = config.v_head_dim
157
+ self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
158
+ self.sliding_window = getattr(config, 'sliding_window', None)
159
+
160
+ # Q projection: optional low-rank → full Q
161
+ if self.q_lora_rank > 0:
162
+ self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
163
+ self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
164
+ self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
165
+ else:
166
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
167
+
168
+ # KV compression: hidden → kv_lora_rank (shared latent)
169
+ self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
170
+ self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
171
+ # Decompress: kv_lora_rank → nope heads + v heads
172
+ self.kv_b_proj = nn.Linear(
173
+ self.kv_lora_rank,
174
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
175
+ bias=False,
176
+ )
177
+ # Output projection
178
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
179
+
180
+ # RoPE frequencies
181
+ rope_scaling = getattr(config, 'rope_scaling', None)
182
+ if rope_scaling and rope_scaling.get("type") == "yarn":
183
+ factor = rope_scaling.get("factor", 1.0)
184
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
185
+ inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
186
+ else:
187
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
188
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
189
+
190
+ def _rotate_half(self, x):
191
+ x1 = x[..., :x.shape[-1] // 2]
192
+ x2 = x[..., x.shape[-1] // 2:]
193
+ return torch.cat((-x2, x1), dim=-1)
194
+
195
+ def _apply_rotary(self, x, cos, sin):
196
+ return (x * cos) + (self._rotate_half(x) * sin)
197
+
198
+ def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
199
+ """Create a sliding window causal mask."""
200
+ if self.sliding_window is None or self.sliding_window <= 0:
201
+ return None
202
+ mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, device=device, dtype=dtype)
203
+ for i in range(q_len):
204
+ start = max(0, i - self.sliding_window + 1)
205
+ mask[i, start:i + 1] = 0.0
206
+ return mask
207
+
208
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
209
+ bsz, q_len, _ = hidden_states.size()
210
+
211
+ # Q projection
212
+ if self.q_lora_rank > 0:
213
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
214
+ else:
215
+ q = self.q_proj(hidden_states)
216
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
217
+ q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
218
+
219
+ # KV: compress to latent, then decompress
220
+ kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
221
+ kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
222
+ kv_latent_norm = self.kv_a_layernorm(kv_latent)
223
+ kv_b_out = self.kv_b_proj(kv_latent_norm)
224
+ k_nope, v = torch.split(kv_b_out, [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim], dim=-1)
225
+
226
+ k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
227
+ v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
228
+ k_rope = k_rope.unsqueeze(1)
229
+
230
+ # RoPE on Q and K rope parts
231
+ if position_ids is None:
232
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
233
+ max_pos = position_ids.max().item() + 1
234
+ seq_len = max(max_pos, q_len)
235
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
236
+ freqs = torch.outer(t, self.inv_freq)
237
+ emb = torch.cat((freqs, freqs), dim=-1)
238
+ cos, sin = emb.cos(), emb.sin()
239
+ cos = cos[position_ids].unsqueeze(1)
240
+ sin = sin[position_ids].unsqueeze(1)
241
+
242
+ q_rope = self._apply_rotary(q_rope, cos, sin)
243
+ k_rope = self._apply_rotary(k_rope, cos, sin)
244
+
245
+ # Assemble full K
246
+ k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
247
+ k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
248
+ q_full = torch.cat([q_nope, q_rope], dim=-1)
249
+
250
+ # KV cache
251
+ if past_key_value is not None:
252
+ k_full = torch.cat([past_key_value[0], k_full], dim=2)
253
+ v = torch.cat([past_key_value[1], v], dim=2)
254
+ past_kv = (k_full, v) if use_cache else None
255
+
256
+ # Build attention mask: user mask + sliding window
257
+ final_mask = attention_mask
258
+ if self.sliding_window is not None and self.sliding_window > 0:
259
+ kv_len = k_full.size(2)
260
+ sw_mask = self._make_sliding_window_mask(q_len, kv_len, hidden_states.device, hidden_states.dtype)
261
+ if final_mask is not None:
262
+ final_mask = final_mask + sw_mask
263
+ else:
264
+ final_mask = sw_mask
265
+
266
+ # Attention with SDPA
267
+ attn_output = F.scaled_dot_product_attention(
268
+ q_full, k_full, v,
269
+ attn_mask=final_mask,
270
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
271
+ is_causal=(final_mask is None),
272
+ )
273
+ attn_output = attn_output.transpose(1, 2).contiguous()
274
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
275
+ return self.o_proj(attn_output), past_kv
276
+
277
+
278
+ # ---------------------------------------------------------------------------
279
+ # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
280
+ # ---------------------------------------------------------------------------
281
+
282
+ def _tokenizer_compress(token_ids, vocab_size=50257):
283
+ """Simulate NFKC + lowercase canonical ID projection."""
284
+ return token_ids % (vocab_size * 77 // 100)
285
+
286
+
287
+ class SpiderPortalEngram(nn.Module):
288
+ """Conditional memory module via NN-gram lookup.
289
+
290
+ Applied only at specific recurrent layers (config.engram_layers).
291
+ """
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.config = config
295
+ self.ngram_orders = config.engram_ngram_orders
296
+ self.num_heads = config.engram_hash_heads
297
+ self.table_size = config.engram_table_size
298
+ self.d_mem = config.engram_dim
299
+
300
+ self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem
301
+
302
+ self.embed_tables = nn.ParameterDict()
303
+ for n in self.ngram_orders:
304
+ for h in range(self.num_heads):
305
+ key = f"e_{n}_{h}"
306
+ self.embed_tables[key] = nn.Parameter(
307
+ torch.randn(self.table_size, self.d_mem) * 0.02
308
+ )
309
+
310
+ self.register_buffer("hash_seeds", torch.tensor([
311
+ (h + 1) * 2654435761
312
+ for _ in self.ngram_orders
313
+ for h in range(self.num_heads)
314
+ ], dtype=torch.int64))
315
+
316
+ self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
317
+ self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
318
+
319
+ self.conv = nn.Conv1d(
320
+ config.hidden_size, config.hidden_size,
321
+ kernel_size=config.engram_conv_kernel,
322
+ padding=config.engram_conv_kernel - 1,
323
+ groups=config.hidden_size,
324
+ )
325
+ self.conv_dilation = config.engram_conv_dilation
326
+
327
+ with torch.no_grad():
328
+ self.conv.weight.zero_()
329
+ if self.conv.bias is not None:
330
+ self.conv.bias.zero_()
331
+
332
+ self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
333
+ self.k_norm = SpiderPortalRMSNorm(config.hidden_size)
334
+
335
+ def _compute_indices(self, compressed_ids, n, head_idx):
336
+ """Vectorized NN-gram hash indices for a single (order, head)."""
337
+ bsz, seq_len = compressed_ids.shape
338
+ pad = torch.zeros(bsz, n - 1, dtype=compressed_ids.dtype, device=compressed_ids.device)
339
+ padded = torch.cat([pad, compressed_ids], dim=1)
340
+
341
+ indices_list = []
342
+ for i in range(n):
343
+ indices_list.append(padded[:, i:i + seq_len])
344
+ ngrams = torch.stack(indices_list, dim=-1)
345
+
346
+ seed = int(self.hash_seeds[head_idx].item())
347
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
348
+ for i in range(n):
349
+ h_val = h_val * 31 + ngrams[:, :, i]
350
+ h_val = h_val % self.table_size
351
+ h_val = (h_val * seed) % self.table_size
352
+ return h_val
353
+
354
+ def _retrieve(self, token_ids):
355
+ """Retrieve memory vectors for a batch of token sequences."""
356
+ bsz, seq_len = token_ids.shape
357
+ compressed = _tokenizer_compress(token_ids)
358
+
359
+ all_parts = []
360
+ head_counter = 0
361
+ for n in self.ngram_orders:
362
+ for h in range(self.num_heads):
363
+ key = f"e_{n}_{h}"
364
+ table = self.embed_tables[key]
365
+ indices = self._compute_indices(compressed, n, head_counter)
366
+ emb = table[indices.view(-1)]
367
+ all_parts.append(emb.view(bsz, seq_len, self.d_mem))
368
+ head_counter += 1
369
+
370
+ memory = torch.cat(all_parts, dim=-1)
371
+ return memory
372
+
373
+ def forward(self, hidden_states, token_ids):
374
+ mem = self._retrieve(token_ids)
375
+
376
+ q = hidden_states
377
+ k = self.W_k(mem)
378
+ v = self.W_v(mem)
379
+
380
+ q_norm = self.q_norm(q)
381
+ k_norm = self.k_norm(k)
382
+
383
+ alpha = torch.sigmoid(
384
+ (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
385
+ )
386
+
387
+ v_gated = alpha * v
388
+
389
+ v_gated_t = v_gated.transpose(1, 2)
390
+ conv_out = self.conv(v_gated_t)
391
+ conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
392
+ conv_out = conv_out.transpose(1, 2)
393
+
394
+ y = F.silu(conv_out) + v_gated
395
+
396
+ return y
397
+
398
+
399
+ # ---------------------------------------------------------------------------
400
+ # FFN Expert (dense)
401
+ # ---------------------------------------------------------------------------
402
+
403
+ class SpiderPortalExpert(nn.Module):
404
+ def __init__(self, config, intermediate_size=None):
405
+ super().__init__()
406
+ inter_size = intermediate_size or config.intermediate_size
407
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
408
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
409
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
410
+ self.act_fn = nn.SiLU()
411
+ def forward(self, hidden_states):
412
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
413
+
414
+
415
+ # ---------------------------------------------------------------------------
416
+ # Prelude/Coda Dense Layer (uses MLA)
417
+ # ---------------------------------------------------------------------------
418
+
419
+ class SpiderPortalDenseLayer(nn.Module):
420
+ """Prelude/coda dense layer with MLA attention."""
421
+ def __init__(self, config):
422
+ super().__init__()
423
+ self.self_attn = SpiderPortalMLA(config)
424
+ dense_intermediate = config.hidden_size * 4 // 3
425
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
426
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
427
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
429
+ attn_input = self.input_layernorm(hidden_states)
430
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
431
+ hidden_states = hidden_states + attn_output
432
+ ffn_input = self.post_attention_layernorm(hidden_states)
433
+ ffn_output = self.ffn(ffn_input)
434
+ hidden_states = hidden_states + ffn_output
435
+ return hidden_states, past_kv
436
+
437
+
438
+ # ---------------------------------------------------------------------------
439
+ # Recurrent Dense Layer (uses MLA + optional Engram)
440
+ # ---------------------------------------------------------------------------
441
+
442
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
443
+ """Recurrent layer with MLA attention and optional Engram memory."""
444
+ def __init__(self, config, layer_idx, has_engram=False):
445
+ super().__init__()
446
+ self.layer_idx = layer_idx
447
+ self.has_engram = has_engram
448
+ self.self_attn = SpiderPortalMLA(config)
449
+ if has_engram:
450
+ self.engram = SpiderPortalEngram(config)
451
+ self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
452
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
453
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
454
+ self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None
455
+ def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
456
+ attn_input = self.input_layernorm(hidden_states)
457
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
458
+ hidden_states = hidden_states + attn_output
459
+
460
+ if self.has_engram and token_ids is not None:
461
+ engram_out = self.engram(hidden_states, token_ids)
462
+ hidden_states = hidden_states + engram_out
463
+ if self.post_engram_layernorm is not None:
464
+ hidden_states = self.post_engram_layernorm(hidden_states)
465
+
466
+ ffn_input = self.post_attention_layernorm(hidden_states)
467
+ ffn_output = self.ffn(ffn_input)
468
+ hidden_states = hidden_states + ffn_output
469
+ return hidden_states, 0.0, past_kv
470
+
471
+
472
+ # ---------------------------------------------------------------------------
473
+ # LTI Injection, ACT Halting, LoRA Adapter
474
+ # ---------------------------------------------------------------------------
475
+
476
+ class LTIInjection(nn.Module):
477
+ def __init__(self, config):
478
+ super().__init__()
479
+ self.hidden_size = config.hidden_size
480
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
481
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
482
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
483
+ with torch.no_grad():
484
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
485
+ def get_A(self):
486
+ return -torch.exp(self.log_A)
487
+ def forward(self, h_t, e):
488
+ A = self.get_A()
489
+ return A * h_t + self.B(e)
490
+
491
+
492
+ class ACTHalting(nn.Module):
493
+ def __init__(self, config):
494
+ super().__init__()
495
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
496
+ self.threshold = config.act_threshold
497
+ def forward(self, hidden_states):
498
+ return torch.sigmoid(self.halt_predictor(hidden_states))
499
+
500
+
501
+ class LoRAAdapter(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ rank = config.lora_rank
505
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
506
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
507
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
508
+ with torch.no_grad():
509
+ self.scale.weight.data.zero_()
510
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
511
+ def forward(self, x, loop_t):
512
+ max_t = self.scale.num_embeddings - 1
513
+ t_idx = min(loop_t, max_t)
514
+ s = self.scale(torch.tensor(t_idx, device=x.device))
515
+ down = self.down(x) * s
516
+ return down @ self.B
517
+
518
+
519
+ def checkpoint(func, *args, **kwargs):
520
+ """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
521
+ if torch.is_grad_enabled():
522
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
523
+ return func(*args, **kwargs)
524
+
525
+
526
+ # ---------------------------------------------------------------------------
527
+ # Full Model
528
+ # ---------------------------------------------------------------------------
529
+
530
+ class SpiderPortalDenseModel(nn.Module):
531
+ """Full RDT model with MLA attention + Engram memory at layers 1,4.
532
+
533
+ Architecture:
534
+ 2x Prelude (MLA + dense FFN)
535
+ 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing
536
+ 2x Coda (MLA + dense FFN)
537
+ LTI Injection + ACT Halting + LoRA Adapter
538
+ """
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.config = config
542
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
543
+ self.recurrent_layers = nn.ModuleList([
544
+ SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers))
545
+ for i in range(config.num_hidden_layers)
546
+ ])
547
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
548
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
549
+ self.injection = LTIInjection(config)
550
+ self.act_halting = ACTHalting(config)
551
+ self.lora_adapter = LoRAAdapter(config)
552
+ self.loop_embed_dim = config.loop_embed_dim
553
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None):
554
+ n_loops = n_loops or self.config.max_loop_iters
555
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
556
+ for layer in self.prelude_layers:
557
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
558
+ e = hidden_states.clone()
559
+ B, T_seq, D = hidden_states.shape
560
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
561
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
562
+ h_out = torch.zeros_like(hidden_states)
563
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
564
+ for t in range(n_loops):
565
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
566
+ if t > 0:
567
+ injection = self.injection(hidden_states, input_embedding)
568
+ hidden_states = hidden_states + injection
569
+ new_past_key_values = []
570
+ for i, layer in enumerate(self.recurrent_layers):
571
+ hidden_states, aux_loss, past_kv = checkpoint(
572
+ layer, hidden_states,
573
+ token_ids=token_ids,
574
+ attention_mask=attention_mask,
575
+ position_ids=position_ids,
576
+ past_key_value=past_key_values[i] if t == 0 else None,
577
+ use_cache=use_cache
578
+ )
579
+ new_past_key_values.append(past_kv)
580
+ lora_delta = self.lora_adapter(hidden_states, t)
581
+ hidden_states = hidden_states + lora_delta
582
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
583
+ still_running = ~halted
584
+ remainder = (1.0 - cumulative_p).clamp(min=0)
585
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
586
+ weight = weight * still_running.to(hidden_states.dtype)
587
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
588
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
589
+ halted = halted | (cumulative_p >= self.config.act_threshold)
590
+ if halted.all() and not self.training:
591
+ break
592
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
593
+ hidden_states = h_out + never_halted * hidden_states
594
+ for layer in self.coda_layers:
595
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
596
+ hidden_states = self.norm(hidden_states)
597
+ return hidden_states, 0.0, new_past_key_values
598
+
599
+
600
+ class SpiderPortalForConditionalGeneration(nn.Module):
601
+ def __init__(self, config):
602
+ super().__init__()
603
+ self.config = config
604
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
605
+ self.model = SpiderPortalDenseModel(config)
606
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
607
+ if config.tie_word_embeddings:
608
+ self.lm_head.weight = self.embed_tokens.weight
609
+ self.apply(self._init_weights)
610
+ def _init_weights(self, module):
611
+ if isinstance(module, nn.Linear):
612
+ if hasattr(self, 'model') and module is self.model.injection.B:
613
+ return
614
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
615
+ if module.bias is not None:
616
+ module.bias.data.zero_()
617
+ elif isinstance(module, nn.Embedding):
618
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
619
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
620
+ hidden_states = self.embed_tokens(input_ids)
621
+ model_dtype = next(self.model.parameters()).dtype
622
+ hidden_states = hidden_states.to(model_dtype)
623
+ input_embedding = hidden_states.clone()
624
+ if attention_mask is None:
625
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
626
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
627
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
628
+ causal_mask = causal_mask.triu(1)
629
+ hidden_states, aux_loss, past_kv = self.model(
630
+ hidden_states, input_embedding=input_embedding,
631
+ attention_mask=causal_mask, position_ids=position_ids,
632
+ use_cache=use_cache, n_loops=n_loops, token_ids=input_ids
633
+ )
634
+ logits = self.lm_head(hidden_states)
635
+ loss = None
636
+ if labels is not None:
637
+ shift_logits = logits[..., :-1, :].contiguous()
638
+ shift_labels = labels[..., 1:].contiguous()
639
+ loss_fct = CrossEntropyLoss()
640
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
641
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
642
+ def get_num_params(self):
643
+ total = sum(p.numel() for p in self.parameters())
644
+ return {"total": total, "trainable": total}
645
+
646
+
647
+ # ---------------------------------------------------------------------------
648
+ # Dataset
649
+ # ---------------------------------------------------------------------------
650
+
651
+ class FineWebEduDataset(IterableDataset):
652
+ def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
653
+ self.tokenizer = tokenizer
654
+ self.seq_len = seq_len
655
+ self.subset = subset
656
+ self.rank = rank
657
+ self.world_size = world_size
658
+ def __iter__(self):
659
+ worker = get_worker_info()
660
+ num_workers = worker.num_workers if worker else 1
661
+ worker_id = worker.id if worker else 0
662
+ total_shards = self.world_size * num_workers
663
+ shard_index = self.rank * num_workers + worker_id
664
+ ds = load_dataset(
665
+ "HuggingFaceFW/fineweb-edu",
666
+ name=self.subset,
667
+ split="train",
668
+ streaming=True,
669
+ ).shard(num_shards=total_shards, index=shard_index)
670
+ buf = []
671
+ for sample in ds:
672
+ buf.extend(self.tokenizer.encode(sample["text"]))
673
+ while len(buf) >= self.seq_len + 1:
674
+ chunk = buf[: self.seq_len + 1]
675
+ buf = buf[self.seq_len + 1 :]
676
+ yield (
677
+ torch.tensor(chunk[:-1], dtype=torch.long),
678
+ torch.tensor(chunk[1:], dtype=torch.long),
679
+ )
680
+
681
+
682
+ # ---------------------------------------------------------------------------
683
+ # LR schedule
684
+ # ---------------------------------------------------------------------------
685
+
686
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
687
+ if step < warmup:
688
+ return max_lr * step / warmup
689
+ if step >= total:
690
+ return min_lr
691
+ decay = (step - warmup) / (total - warmup)
692
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
693
+
694
+
695
+ # ---------------------------------------------------------------------------
696
+ # Checkpointing
697
+ # ---------------------------------------------------------------------------
698
+
699
+ def save_weights_only(model, step, epoch, ckpt_dir, ddp):
700
+ if ddp:
701
+ with FSDP.state_dict_type(
702
+ model,
703
+ StateDictType.FULL_STATE_DICT,
704
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
705
+ ):
706
+ model_state = model.state_dict()
707
+ else:
708
+ model_state = model.state_dict()
709
+ ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
710
+ tmp_path = ckpt_path + ".tmp"
711
+ torch.save(model_state, tmp_path)
712
+ os.replace(tmp_path, ckpt_path)
713
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
714
+ return ckpt_path, size_mb
715
+
716
+
717
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
718
+ if ddp:
719
+ with FSDP.state_dict_type(
720
+ model,
721
+ StateDictType.FULL_STATE_DICT,
722
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
723
+ ):
724
+ model_state = model.state_dict()
725
+ optim_state = FSDP.optim_state_dict(model, optimizer)
726
+ else:
727
+ model_state = model.state_dict()
728
+ optim_state = optimizer.state_dict()
729
+ if not master:
730
+ return None, 0
731
+ os.makedirs(ckpt_dir, exist_ok=True)
732
+ final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
733
+ tmp_path = final_path + ".tmp"
734
+ torch.save(
735
+ {
736
+ "step": step,
737
+ "epoch": epoch,
738
+ "model_state_dict": model_state,
739
+ "optimizer_state_dict": optim_state,
740
+ "cfg": cfg,
741
+ "vocab_size": vocab_size,
742
+ },
743
+ tmp_path,
744
+ )
745
+ os.replace(tmp_path, final_path)
746
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
747
+ return final_path, size_mb
748
+
749
+
750
+ def load_checkpoint(model, optimizer, path, ddp):
751
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
752
+ if ddp:
753
+ with FSDP.state_dict_type(
754
+ model,
755
+ StateDictType.FULL_STATE_DICT,
756
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
757
+ ):
758
+ model.load_state_dict(ckpt["model_state_dict"])
759
+ optim_state = FSDP.optim_state_dict_to_load(
760
+ model=model,
761
+ optim=optimizer,
762
+ optim_state_dict=ckpt["optimizer_state_dict"],
763
+ )
764
+ optimizer.load_state_dict(optim_state)
765
+ else:
766
+ model.load_state_dict(ckpt["model_state_dict"])
767
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
768
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0))
769
+
770
+
771
+ # ---------------------------------------------------------------------------
772
+ # Main
773
+ # ---------------------------------------------------------------------------
774
+
775
+ def main():
776
+ # ------------------------------------------------------------------
777
+ # Distributed init
778
+ # ------------------------------------------------------------------
779
+ ddp = int(os.environ.get("RANK", -1)) != -1
780
+ if ddp:
781
+ dist.init_process_group("nccl")
782
+ rank = int(os.environ["RANK"])
783
+ local_rank = int(os.environ["LOCAL_RANK"])
784
+ world_size = int(os.environ["WORLD_SIZE"])
785
+ device = f"cuda:{local_rank}"
786
+ torch.cuda.set_device(device)
787
+ else:
788
+ rank = local_rank = 0
789
+ world_size = 1
790
+ device = "cuda" if torch.cuda.is_available() else "cpu"
791
+ master = rank == 0
792
+ if master:
793
+ logger.info(
794
+ f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
795
+ )
796
+
797
+ # ------------------------------------------------------------------
798
+ # Tokenizer
799
+ # ------------------------------------------------------------------
800
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
801
+ tokenizer.pad_token = tokenizer.eos_token
802
+ vocab_size = tokenizer.vocab_size
803
+ if master:
804
+ logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
805
+
806
+ # ------------------------------------------------------------------
807
+ # Hyperparameters
808
+ # ------------------------------------------------------------------
809
+ seq_len = 2048
810
+ micro_batch = 16
811
+ target_tokens = 20_000_000_000
812
+ grad_accum = 2
813
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
814
+ total_steps = target_tokens // global_batch_tok
815
+ warmup_steps = 200
816
+ lr = 3e-4
817
+ wd = 0.1
818
+ log_every = 10
819
+ ckpt_every = 500
820
+ ckpt_dir = "checkpoints-dense"
821
+ dataset_subset = "sample-10BT"
822
+
823
+ if master:
824
+ logger.info(
825
+ f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
826
+ f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
827
+ )
828
+ logger.info(
829
+ f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | "
830
+ f"Engram: layers [1,4] | Context: 32k | "
831
+ f"Gradient checkpointing: enabled"
832
+ )
833
+
834
+ # ------------------------------------------------------------------
835
+ # Model
836
+ # ------------------------------------------------------------------
837
+ cfg = SpiderPortalConfig(
838
+ hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
839
+ num_key_value_heads=4, intermediate_size=8192,
840
+ num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
841
+ router_aux_loss_coef=0.05, max_loop_iters=4,
842
+ prelude_layers=2, coda_layers=2, lora_rank=128,
843
+ rope_theta=10000000.0,
844
+ rope_scaling=None,
845
+ max_position_embeddings=32768, sliding_window=4096,
846
+ tie_word_embeddings=True,
847
+ kv_lora_rank=128, q_lora_rank=256,
848
+ qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64,
849
+ engram_layers=[1, 4],
850
+ engram_ngram_orders=(2, 3),
851
+ engram_hash_heads=4,
852
+ engram_table_size=65537,
853
+ engram_dim=128,
854
+ )
855
+ cfg.vocab_size = vocab_size
856
+
857
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
858
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
859
+
860
+ model = SpiderPortalForConditionalGeneration(cfg)
861
+
862
+ if ddp:
863
+ mp_policy = MixedPrecision(
864
+ param_dtype=amp_dtype,
865
+ reduce_dtype=amp_dtype,
866
+ buffer_dtype=amp_dtype,
867
+ )
868
+ wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
869
+ model = FSDP(
870
+ model,
871
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
872
+ mixed_precision=mp_policy,
873
+ auto_wrap_policy=wrap_policy,
874
+ device_id=local_rank,
875
+ )
876
+ else:
877
+ model = model.to(device)
878
+ amp_ctx = (
879
+ torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
880
+ if "cuda" in device
881
+ else nullcontext()
882
+ )
883
+
884
+ amp_ctx = nullcontext() if ddp else amp_ctx
885
+
886
+ if master:
887
+ n_params = sum(p.numel() for p in model.parameters())
888
+ engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n)
889
+ mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n)
890
+ embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n)
891
+ ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n)
892
+ other_params = n_params - engram_params - mla_params - embed_params - ffn_params
893
+ logger.info(
894
+ f"Parameters: {n_params:,} (all active) | "
895
+ f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | "
896
+ f"FFN: {ffn_params:,} | Engram: {engram_params:,} | "
897
+ f"Other: {other_params:,} | AMP dtype: {amp_dtype}"
898
+ )
899
+
900
+ # ------------------------------------------------------------------
901
+ # Optimizer — dual optimizer for Engram embeddings
902
+ # ------------------------------------------------------------------
903
+ engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed_tables' in n]
904
+ backbone_params = [p for n, p in model.named_parameters() if 'engram' not in n or 'embed_tables' not in n]
905
+
906
+ optimizer = torch.optim.AdamW(
907
+ backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
908
+ )
909
+ if engram_params_list:
910
+ engram_optimizer = torch.optim.Adam(
911
+ engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8
912
+ )
913
+ else:
914
+ engram_optimizer = None
915
+
916
+ # ------------------------------------------------------------------
917
+ # Resume from latest checkpoint
918
+ # ------------------------------------------------------------------
919
+ start_step = 0
920
+ start_epoch = 1
921
+ best_loss = float("inf")
922
+ existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
923
+ if existing_ckpts:
924
+ latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
925
+ if master:
926
+ logger.info(f"Resuming from checkpoint: {latest}")
927
+ start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
928
+ if master:
929
+ logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
930
+
931
+ # ------------------------------------------------------------------
932
+ # Dataset + DataLoader
933
+ # ------------------------------------------------------------------
934
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
935
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=8, pin_memory=True, prefetch_factor=2)
936
+
937
+ # ------------------------------------------------------------------
938
+ # Training loop
939
+ # ------------------------------------------------------------------
940
+ if master:
941
+ os.makedirs(ckpt_dir, exist_ok=True)
942
+
943
+ model.train()
944
+ data_iter = iter(loader)
945
+ t0 = time.perf_counter()
946
+ step = start_step
947
+ epoch = start_epoch
948
+ step_ckpt_files = []
949
+ tokens_in_epoch = 0
950
+ tokens_per_epoch = target_tokens
951
+
952
+ while step < total_steps:
953
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
954
+ for g in optimizer.param_groups:
955
+ g["lr"] = cur_lr
956
+ if engram_optimizer:
957
+ for g in engram_optimizer.param_groups:
958
+ g["lr"] = cur_lr * 5
959
+
960
+ optimizer.zero_grad()
961
+ if engram_optimizer:
962
+ engram_optimizer.zero_grad()
963
+ loss_accum = 0.0
964
+
965
+ for micro_step in range(grad_accum):
966
+ try:
967
+ x, y = next(data_iter)
968
+ except StopIteration:
969
+ data_iter = iter(loader)
970
+ x, y = next(data_iter)
971
+
972
+ x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
973
+ y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
974
+
975
+ sync = (
976
+ nullcontext()
977
+ if (not ddp or micro_step == grad_accum - 1)
978
+ else model.no_sync()
979
+ )
980
+ with sync, amp_ctx:
981
+ output = model(x)
982
+ if isinstance(output, dict):
983
+ logits = output["logits"]
984
+ else:
985
+ logits = output
986
+ loss = nn.functional.cross_entropy(
987
+ logits.view(-1, vocab_size), y.view(-1)
988
+ )
989
+ loss = loss / grad_accum
990
+
991
+ loss.backward()
992
+ loss_accum += loss.item()
993
+
994
+ if ddp:
995
+ grad_norm = model.clip_grad_norm_(1.0)
996
+ else:
997
+ grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
998
+ optimizer.step()
999
+ if engram_optimizer:
1000
+ engram_optimizer.step()
1001
+ step += 1
1002
+ tokens_in_epoch += global_batch_tok
1003
+
1004
+ if master and step % log_every == 0:
1005
+ dt = time.perf_counter() - t0
1006
+ tok_per_sec = global_batch_tok * log_every / dt
1007
+ tokens_seen = step * global_batch_tok
1008
+ logger.info(
1009
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
1010
+ f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
1011
+ f"| {tok_per_sec / 1e6:.2f}M tok/s "
1012
+ f"| {tokens_seen / 1e9:.2f}B tokens seen"
1013
+ )
1014
+ t0 = time.perf_counter()
1015
+
1016
+ if step % ckpt_every == 0 and master:
1017
+ ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
1018
+ step_ckpt_files.append(ckpt_path)
1019
+ logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1020
+
1021
+ if tokens_in_epoch >= tokens_per_epoch:
1022
+ epoch_loss = loss_accum
1023
+ if master:
1024
+ epoch_time = (time.perf_counter() - t0) / 60
1025
+ logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
1026
+
1027
+ for f in step_ckpt_files:
1028
+ if os.path.exists(f):
1029
+ os.remove(f)
1030
+ logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
1031
+ step_ckpt_files.clear()
1032
+
1033
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
1034
+ if ckpt_path:
1035
+ logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1036
+
1037
+ if epoch_loss < best_loss:
1038
+ best_loss = epoch_loss
1039
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
1040
+ if ckpt_path:
1041
+ logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1042
+
1043
+ epoch += 1
1044
+ tokens_in_epoch = 0
1045
+
1046
+ if step > start_step and master:
1047
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
1048
+ if ckpt_path:
1049
+ logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
1050
+
1051
+ if ddp:
1052
+ dist.barrier()
1053
+ dist.destroy_process_group()
1054
+
1055
+ if master:
1056
+ logger.success("Training complete.")
1057
+
1058
+
1059
+ if __name__ == "__main__":
1060
+ main()