CLIWorks commited on
Commit
2ae4a50
·
verified ·
1 Parent(s): 057bcd8

Upload 2 files

Browse files
Files changed (2) hide show
  1. eval_dense.py +673 -0
  2. mythos-fineweb-dense.py +791 -0
eval_dense.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluate SpiderPortal v5-Dense checkpoint with side-by-side MoE comparison.
4
+
5
+ Usage:
6
+ python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-final-ep1.pt --moe checkpoints/spiderportal-v5-final-ep1.pt --all
7
+ python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-ep1-step1000.pt --prompts "The cat sat on the"
8
+ """
9
+
10
+ import argparse
11
+ import math
12
+ import sys
13
+ import time
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from dataclasses import dataclass
18
+ from transformers import AutoTokenizer
19
+
20
+
21
+ @dataclass
22
+ class SpiderPortalConfig:
23
+ vocab_size: int = 50257
24
+ hidden_size: int = 384
25
+ num_hidden_layers: int = 8
26
+ num_attention_heads: int = 8
27
+ num_key_value_heads: int = 2
28
+ intermediate_size: int = 1024
29
+ num_experts: int = 64
30
+ num_experts_per_tok: int = 1
31
+ router_aux_loss_coef: float = 0.05
32
+ max_loop_iters: int = 1
33
+ act_threshold: float = 0.5
34
+ max_position_embeddings: int = 131072
35
+ rope_theta: float = 10000000.0
36
+ rope_scaling: dict = None
37
+ sliding_window: int = 4096
38
+ attention_dropout: float = 0.0
39
+ rms_norm_eps: float = 1e-6
40
+ initializer_range: float = 0.02
41
+ tie_word_embeddings: bool = True
42
+ prelude_layers: int = 2
43
+ coda_layers: int = 2
44
+ lora_rank: int = 32
45
+ loop_embed_dim: int = 48
46
+
47
+
48
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
49
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
50
+ angles = loop_t * freqs
51
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
52
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
53
+ emb_full[:loop_dim] = emb
54
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
55
+
56
+
57
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
58
+ dim = head_dim
59
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
60
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
61
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
62
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
63
+ return orig_inv_freq * scale
64
+
65
+
66
+ class SpiderPortalRMSNorm(nn.Module):
67
+ def __init__(self, hidden_size, eps=1e-6):
68
+ super().__init__()
69
+ self.weight = nn.Parameter(torch.ones(hidden_size))
70
+ self.variance_epsilon = eps
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
77
+
78
+
79
+ class SpiderPortalGQA(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.config = config
83
+ self.hidden_size = config.hidden_size
84
+ self.num_heads = config.num_attention_heads
85
+ self.num_kv_heads = config.num_key_value_heads
86
+ self.head_dim = config.hidden_size // config.num_attention_heads
87
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
88
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
89
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
90
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
91
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
92
+ self.attention_dropout = config.attention_dropout
93
+ rope_scaling = getattr(config, 'rope_scaling', None)
94
+ if rope_scaling and rope_scaling.get("type") == "yarn":
95
+ factor = rope_scaling.get("factor", 1.0)
96
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
97
+ inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
98
+ else:
99
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
100
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
101
+ def _rotate_half(self, x):
102
+ x1 = x[..., :x.shape[-1] // 2]
103
+ x2 = x[..., x.shape[-1] // 2:]
104
+ return torch.cat((-x2, x1), dim=-1)
105
+ def _apply_rotary(self, x, cos, sin):
106
+ return (x * cos) + (self._rotate_half(x) * sin)
107
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
108
+ bsz, q_len, _ = hidden_states.size()
109
+ query_states = self.q_proj(hidden_states)
110
+ key_states = self.k_proj(hidden_states)
111
+ value_states = self.v_proj(hidden_states)
112
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
113
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
114
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
115
+ if position_ids is None:
116
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
117
+ max_pos = position_ids.max().item() + 1
118
+ seq_len = max(max_pos, q_len)
119
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
120
+ freqs = torch.outer(t, self.inv_freq)
121
+ emb = torch.cat((freqs, freqs), dim=-1)
122
+ cos, sin = emb.cos(), emb.sin()
123
+ cos = cos[position_ids].unsqueeze(1)
124
+ sin = sin[position_ids].unsqueeze(1)
125
+ query_states = self._apply_rotary(query_states, cos, sin)
126
+ key_states = self._apply_rotary(key_states, cos, sin)
127
+ if past_key_value is not None:
128
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
129
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
130
+ past_kv = (key_states, value_states) if use_cache else None
131
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
132
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
133
+ attn_output = F.scaled_dot_product_attention(
134
+ query_states, key_states, value_states,
135
+ attn_mask=attention_mask,
136
+ dropout_p=self.attention_dropout if self.training else 0.0,
137
+ is_causal=attention_mask is None
138
+ )
139
+ attn_output = attn_output.transpose(1, 2).contiguous()
140
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
141
+ return self.o_proj(attn_output), past_kv
142
+
143
+
144
+ class SpiderPortalExpert(nn.Module):
145
+ def __init__(self, config, intermediate_size=None):
146
+ super().__init__()
147
+ inter_size = intermediate_size or config.intermediate_size
148
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
149
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
150
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
151
+ self.act_fn = nn.SiLU()
152
+ def forward(self, hidden_states):
153
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
154
+
155
+
156
+ class SpiderPortalDenseLayer(nn.Module):
157
+ def __init__(self, config):
158
+ super().__init__()
159
+ self.self_attn = SpiderPortalGQA(config)
160
+ dense_intermediate = config.hidden_size * 4 // 3
161
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
162
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
163
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
164
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
165
+ attn_input = self.input_layernorm(hidden_states)
166
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
167
+ hidden_states = hidden_states + attn_output
168
+ ffn_input = self.post_attention_layernorm(hidden_states)
169
+ ffn_output = self.ffn(ffn_input)
170
+ hidden_states = hidden_states + ffn_output
171
+ return hidden_states, past_kv
172
+
173
+
174
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
175
+ """Dense recurrent layer — matches checkpoint keys."""
176
+ def __init__(self, config, layer_idx):
177
+ super().__init__()
178
+ self.layer_idx = layer_idx
179
+ self.self_attn = SpiderPortalGQA(config)
180
+ self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
181
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
182
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
183
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
184
+ attn_input = self.input_layernorm(hidden_states)
185
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
186
+ hidden_states = hidden_states + attn_output
187
+ ffn_input = self.post_attention_layernorm(hidden_states)
188
+ ffn_output = self.ffn(ffn_input)
189
+ hidden_states = hidden_states + ffn_output
190
+ return hidden_states, 0.0, past_kv
191
+
192
+
193
+ # MoE layer for comparison model
194
+ class SpiderPortalRouter(nn.Module):
195
+ def __init__(self, config):
196
+ super().__init__()
197
+ self.num_experts = config.num_experts
198
+ self.num_experts_per_tok = config.num_experts_per_tok
199
+ self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
200
+ self.register_buffer("router_bias", torch.zeros(config.num_experts))
201
+ def forward(self, hidden_states):
202
+ router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
203
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
204
+ biased_logits = router_logits + self.router_bias
205
+ biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
206
+ top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
207
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
208
+ top_weights = top_weights.to(hidden_states.dtype)
209
+ mean_probs = routing_weights.mean(dim=0)
210
+ aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
211
+ return top_weights, top_indices, aux_loss
212
+
213
+
214
+ class SpiderPortalMoE(nn.Module):
215
+ def __init__(self, config):
216
+ super().__init__()
217
+ self.config = config
218
+ self.num_experts = config.num_experts
219
+ self.num_experts_per_tok = config.num_experts_per_tok
220
+ self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
221
+ self.shared_expert = SpiderPortalExpert(config)
222
+ self.router = SpiderPortalRouter(config)
223
+ def forward(self, hidden_states):
224
+ batch_size, seq_len, hidden_dim = hidden_states.shape
225
+ top_weights, top_indices, aux_loss = self.router(hidden_states)
226
+ flat_hidden = hidden_states.view(-1, hidden_dim)
227
+ final_output = torch.zeros_like(flat_hidden)
228
+ for expert_idx in range(self.num_experts_per_tok):
229
+ expert_ids = top_indices[:, expert_idx]
230
+ expert_weights = top_weights[:, expert_idx:expert_idx+1]
231
+ for e in range(self.num_experts):
232
+ mask = expert_ids == e
233
+ if mask.any():
234
+ expert_output = self.experts[e](flat_hidden[mask])
235
+ final_output[mask] += expert_output * expert_weights[mask]
236
+ shared_output = self.shared_expert(flat_hidden)
237
+ final_output = final_output + shared_output
238
+ return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
239
+
240
+
241
+ class SpiderPortalMoELayer(nn.Module):
242
+ def __init__(self, config, layer_idx):
243
+ super().__init__()
244
+ self.layer_idx = layer_idx
245
+ self.self_attn = SpiderPortalGQA(config)
246
+ self.moe = SpiderPortalMoE(config)
247
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
248
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
250
+ attn_input = self.input_layernorm(hidden_states)
251
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
252
+ hidden_states = hidden_states + attn_output
253
+ moe_input = self.post_attention_layernorm(hidden_states)
254
+ moe_output, aux_loss = self.moe(moe_input)
255
+ hidden_states = hidden_states + moe_output
256
+ return hidden_states, aux_loss, past_kv
257
+
258
+
259
+ class LTIInjection(nn.Module):
260
+ def __init__(self, config):
261
+ super().__init__()
262
+ self.hidden_size = config.hidden_size
263
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
264
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
265
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
266
+ with torch.no_grad():
267
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
268
+ def get_A(self):
269
+ return -torch.exp(self.log_A)
270
+ def forward(self, h_t, e):
271
+ A = self.get_A()
272
+ return A * h_t + self.B(e)
273
+
274
+
275
+ class ACTHalting(nn.Module):
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
279
+ self.threshold = config.act_threshold
280
+ def forward(self, hidden_states):
281
+ return torch.sigmoid(self.halt_predictor(hidden_states))
282
+
283
+
284
+ class LoRAAdapter(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ rank = config.lora_rank
288
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
289
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
290
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
291
+ with torch.no_grad():
292
+ self.scale.weight.data.zero_()
293
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
294
+ def forward(self, x, loop_t):
295
+ max_t = self.scale.num_embeddings - 1
296
+ t_idx = min(loop_t, max_t)
297
+ s = self.scale(torch.tensor(t_idx, device=x.device))
298
+ down = self.down(x) * s
299
+ return down @ self.B
300
+
301
+
302
+ class SpiderPortalDenseModel(nn.Module):
303
+ def __init__(self, config):
304
+ super().__init__()
305
+ self.config = config
306
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
307
+ self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
308
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
309
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
310
+ self.injection = LTIInjection(config)
311
+ self.act_halting = ACTHalting(config)
312
+ self.lora_adapter = LoRAAdapter(config)
313
+ self.loop_embed_dim = config.loop_embed_dim
314
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
315
+ n_loops = n_loops or self.config.max_loop_iters
316
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
317
+ for layer in self.prelude_layers:
318
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
319
+ e = hidden_states.clone()
320
+ B, T_seq, D = hidden_states.shape
321
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
322
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
323
+ h_out = torch.zeros_like(hidden_states)
324
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
325
+ for t in range(n_loops):
326
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
327
+ if t > 0:
328
+ injection = self.injection(hidden_states, input_embedding)
329
+ hidden_states = hidden_states + injection
330
+ new_past_key_values = []
331
+ for i, layer in enumerate(self.recurrent_layers):
332
+ hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
333
+ new_past_key_values.append(past_kv)
334
+ lora_delta = self.lora_adapter(hidden_states, t)
335
+ hidden_states = hidden_states + lora_delta
336
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
337
+ still_running = ~halted
338
+ remainder = (1.0 - cumulative_p).clamp(min=0)
339
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
340
+ weight = weight * still_running.to(hidden_states.dtype)
341
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
342
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
343
+ halted = halted | (cumulative_p >= self.config.act_threshold)
344
+ if halted.all() and not self.training:
345
+ break
346
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
347
+ hidden_states = h_out + never_halted * hidden_states
348
+ for layer in self.coda_layers:
349
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
350
+ hidden_states = self.norm(hidden_states)
351
+ return hidden_states, 0.0, new_past_key_values
352
+
353
+
354
+ class SpiderPortalMoEModel(nn.Module):
355
+ def __init__(self, config):
356
+ super().__init__()
357
+ self.config = config
358
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
359
+ self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
360
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
361
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
362
+ self.injection = LTIInjection(config)
363
+ self.act_halting = ACTHalting(config)
364
+ self.lora_adapter = LoRAAdapter(config)
365
+ self.loop_embed_dim = config.loop_embed_dim
366
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
367
+ n_loops = n_loops or self.config.max_loop_iters
368
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
369
+ total_aux_loss = 0.0
370
+ for layer in self.prelude_layers:
371
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
372
+ e = hidden_states.clone()
373
+ B, T_seq, D = hidden_states.shape
374
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
375
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
376
+ h_out = torch.zeros_like(hidden_states)
377
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
378
+ for t in range(n_loops):
379
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
380
+ if t > 0:
381
+ injection = self.injection(hidden_states, input_embedding)
382
+ hidden_states = hidden_states + injection
383
+ new_past_key_values = []
384
+ for i, layer in enumerate(self.recurrent_layers):
385
+ hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
386
+ total_aux_loss = total_aux_loss + aux_loss
387
+ new_past_key_values.append(past_kv)
388
+ lora_delta = self.lora_adapter(hidden_states, t)
389
+ hidden_states = hidden_states + lora_delta
390
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
391
+ still_running = ~halted
392
+ remainder = (1.0 - cumulative_p).clamp(min=0)
393
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
394
+ weight = weight * still_running.to(hidden_states.dtype)
395
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
396
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
397
+ halted = halted | (cumulative_p >= self.config.act_threshold)
398
+ if halted.all() and not self.training:
399
+ break
400
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
401
+ hidden_states = h_out + never_halted * hidden_states
402
+ for layer in self.coda_layers:
403
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
404
+ hidden_states = self.norm(hidden_states)
405
+ return hidden_states, total_aux_loss, new_past_key_values
406
+
407
+
408
+ class SpiderPortalForConditionalGeneration(nn.Module):
409
+ def __init__(self, config, model_class=SpiderPortalDenseModel):
410
+ super().__init__()
411
+ self.config = config
412
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
413
+ self.model = model_class(config)
414
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
415
+ if config.tie_word_embeddings:
416
+ self.lm_head.weight = self.embed_tokens.weight
417
+ self.apply(self._init_weights)
418
+ def _init_weights(self, module):
419
+ if isinstance(module, nn.Linear):
420
+ if hasattr(self, 'model') and module is self.model.injection.B:
421
+ return
422
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
423
+ if module.bias is not None:
424
+ module.bias.data.zero_()
425
+ elif isinstance(module, nn.Embedding):
426
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
427
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
428
+ hidden_states = self.embed_tokens(input_ids)
429
+ model_dtype = next(self.model.parameters()).dtype
430
+ hidden_states = hidden_states.to(model_dtype)
431
+ input_embedding = hidden_states.clone()
432
+ if attention_mask is None:
433
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
434
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
435
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
436
+ causal_mask = causal_mask.triu(1)
437
+ hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
438
+ logits = self.lm_head(hidden_states)
439
+ return {"loss": None, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
440
+
441
+
442
+ DEFAULT_PROMPTS = [
443
+ "The cat sat on the",
444
+ "The capital of France is",
445
+ "If I have 3 apples and eat 1, I have",
446
+ "Once upon a time, there was a",
447
+ "Python is a programming language that",
448
+ "Two plus two equals",
449
+ "When it rains, the ground gets",
450
+ "The door opened slowly and",
451
+ "What is the meaning of life? The",
452
+ "def fibonacci(n):\n if n <= 1:\n return",
453
+ ]
454
+
455
+
456
+ def load_model(checkpoint_path, device="cpu", model_class=SpiderPortalDenseModel):
457
+ print(f"Loading checkpoint: {checkpoint_path}")
458
+ ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
459
+
460
+ cfg = ckpt.get("cfg")
461
+ vocab_size = ckpt.get("vocab_size", 50257)
462
+
463
+ if cfg is None:
464
+ cfg = SpiderPortalConfig(
465
+ hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
466
+ num_key_value_heads=2, intermediate_size=1024,
467
+ num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
468
+ router_aux_loss_coef=0.05, max_loop_iters=1,
469
+ prelude_layers=2, coda_layers=2, lora_rank=32,
470
+ rope_theta=10000000.0,
471
+ rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
472
+ max_position_embeddings=131072, sliding_window=4096,
473
+ tie_word_embeddings=True,
474
+ )
475
+ cfg.vocab_size = vocab_size
476
+
477
+ model_state = ckpt.get("model_state_dict", ckpt)
478
+ model = SpiderPortalForConditionalGeneration(cfg, model_class=model_class)
479
+
480
+ missing, unexpected = model.load_state_dict(model_state, strict=False)
481
+ if missing:
482
+ print(f" Missing keys ({len(missing)}): {missing[:3]}...")
483
+ if unexpected:
484
+ print(f" Unexpected keys ({len(unexpected)}): {unexpected[:3]}...")
485
+ if not missing and not unexpected:
486
+ print(" All keys matched perfectly")
487
+
488
+ model = model.to(device)
489
+ model.eval()
490
+
491
+ n_params = sum(p.numel() for p in model.parameters())
492
+ print(f" Parameters: {n_params:,} on {device}")
493
+
494
+ return model, cfg
495
+
496
+
497
+ def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
498
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
499
+
500
+ generated = []
501
+ with torch.no_grad():
502
+ for _ in range(max_new_tokens):
503
+ outputs = model(input_ids, use_cache=False)
504
+ logits = outputs["logits"][0, -1, :]
505
+
506
+ if temperature > 0:
507
+ logits = logits / temperature
508
+ probs = F.softmax(logits, dim=-1)
509
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
510
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
511
+ sorted_indices_to_remove = cumulative_probs > top_p
512
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
513
+ sorted_indices_to_remove[0] = False
514
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
515
+ probs[indices_to_remove] = 0.0
516
+ probs = probs / probs.sum()
517
+ next_token = torch.multinomial(probs, 1)
518
+ else:
519
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
520
+
521
+ generated.append(next_token.item())
522
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
523
+
524
+ if next_token.item() == tokenizer.eos_token_id:
525
+ break
526
+
527
+ return tokenizer.decode(generated, skip_special_tokens=True)
528
+
529
+
530
+ def analyze_output(prompt, generated_text):
531
+ full = prompt + generated_text
532
+ words = full.split()
533
+ unique_words = set(w.lower() for w in words)
534
+ vocab_diversity = len(unique_words) / max(len(words), 1)
535
+
536
+ n = 4
537
+ if len(words) >= n:
538
+ ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
539
+ unique_ngrams = set(ngrams)
540
+ repetition_rate = 1.0 - len(unique_ngrams) / max(len(ngrams), 1)
541
+ else:
542
+ repetition_rate = 0.0
543
+
544
+ has_repetition = False
545
+ for pattern in ["... ", "!!!", " and and ", " the the ", " is is "]:
546
+ if pattern in full.lower():
547
+ has_repetition = True
548
+ break
549
+
550
+ english_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ '.,!?;:-\"()")
551
+ char_ratio = sum(1 for c in generated_text if c in english_chars) / max(len(generated_text), 1)
552
+
553
+ return {
554
+ "total_words": len(words),
555
+ "unique_words": len(unique_words),
556
+ "vocab_diversity": vocab_diversity,
557
+ "repetition_rate": repetition_rate,
558
+ "has_obvious_repetition": has_repetition,
559
+ "english_char_ratio": char_ratio,
560
+ }
561
+
562
+
563
+ def main():
564
+ parser = argparse.ArgumentParser(description="Evaluate SpiderPortal Dense vs MoE")
565
+ parser.add_argument("--dense", required=True, help="Path to dense checkpoint")
566
+ parser.add_argument("--moe", default=None, help="Path to MoE checkpoint for comparison")
567
+ parser.add_argument("--prompts", nargs="*", default=None)
568
+ parser.add_argument("--file", default=None, help="File with prompts")
569
+ parser.add_argument("--all", action="store_true", help="Run default prompt suite")
570
+ parser.add_argument("--max-new-tokens", type=int, default=80)
571
+ parser.add_argument("--temperature", type=float, default=0.8)
572
+ parser.add_argument("--top-p", type=float, default=0.9)
573
+ parser.add_argument("--device", default=None)
574
+ args = parser.parse_args()
575
+
576
+ device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
577
+ print(f"Device: {device}")
578
+
579
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
580
+ tokenizer.pad_token = tokenizer.eos_token
581
+
582
+ prompts = []
583
+ if args.all:
584
+ prompts = DEFAULT_PROMPTS
585
+ elif args.prompts:
586
+ prompts = args.prompts
587
+ elif args.file:
588
+ with open(args.file) as f:
589
+ prompts = [line.strip() for line in f if line.strip()]
590
+ else:
591
+ prompts = DEFAULT_PROMPTS[:3]
592
+
593
+ dense_model, _ = load_model(args.dense, device, model_class=SpiderPortalDenseModel)
594
+
595
+ moe_model = None
596
+ if args.moe:
597
+ print()
598
+ moe_model, _ = load_model(args.moe, device, model_class=SpiderPortalMoEModel)
599
+
600
+ print(f"\nRunning {len(prompts)} prompts (max_new_tokens={args.max_new_tokens}, temp={args.temperature})\n")
601
+ print("=" * 80)
602
+
603
+ dense_results = []
604
+ moe_results = []
605
+
606
+ for i, prompt in enumerate(prompts):
607
+ print(f"\n[Prompt {i+1}/{len(prompts)}]: {prompt}")
608
+
609
+ t0 = time.time()
610
+ dense_gen = generate(dense_model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
611
+ dense_elapsed = time.time() - t0
612
+ dense_metrics = analyze_output(prompt, dense_gen)
613
+
614
+ print(f" [DENSE] {dense_gen}")
615
+ print(f" vocab_div={dense_metrics['vocab_diversity']:.2f}, "
616
+ f"repetition={dense_metrics['repetition_rate']:.2f}, "
617
+ f"english={dense_metrics['english_char_ratio']:.2f}, "
618
+ f"tok/s={args.max_new_tokens/max(dense_elapsed,0.001):.1f}")
619
+
620
+ if moe_model:
621
+ t0 = time.time()
622
+ moe_gen = generate(moe_model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
623
+ moe_elapsed = time.time() - t0
624
+ moe_metrics = analyze_output(prompt, moe_gen)
625
+
626
+ print(f" [MoE ] {moe_gen}")
627
+ print(f" vocab_div={moe_metrics['vocab_diversity']:.2f}, "
628
+ f"repetition={moe_metrics['repetition_rate']:.2f}, "
629
+ f"english={moe_metrics['english_char_ratio']:.2f}, "
630
+ f"tok/s={args.max_new_tokens/max(moe_elapsed,0.001):.1f}")
631
+
632
+ moe_results.append({"prompt": prompt, "generated": moe_gen, "metrics": moe_metrics})
633
+
634
+ dense_results.append({"prompt": prompt, "generated": dense_gen, "metrics": dense_metrics})
635
+
636
+ print("\n" + "=" * 80)
637
+ print("SUMMARY")
638
+ print("=" * 80)
639
+
640
+ def print_summary(label, results):
641
+ avg_vocab = sum(r["metrics"]["vocab_diversity"] for r in results) / len(results)
642
+ avg_rep = sum(r["metrics"]["repetition_rate"] for r in results) / len(results)
643
+ avg_eng = sum(r["metrics"]["english_char_ratio"] for r in results) / len(results)
644
+ total_rep = sum(1 for r in results if r["metrics"]["has_obvious_repetition"])
645
+ print(f"\n{label}:")
646
+ print(f" Vocab diversity: {avg_vocab:.2f}")
647
+ print(f" Repetition rate: {avg_rep:.2f}")
648
+ print(f" English chars: {avg_eng:.2f}")
649
+ print(f" Repetition hits: {total_rep}/{len(results)}")
650
+
651
+ print_summary("DENSE", dense_results)
652
+ if moe_results:
653
+ print_summary("MoE ", moe_results)
654
+
655
+ print("\nComparison:")
656
+ d_vocab = sum(r["metrics"]["vocab_diversity"] for r in dense_results) / len(dense_results)
657
+ m_vocab = sum(r["metrics"]["vocab_diversity"] for r in moe_results) / len(moe_results)
658
+ d_eng = sum(r["metrics"]["english_char_ratio"] for r in dense_results) / len(dense_results)
659
+ m_eng = sum(r["metrics"]["english_char_ratio"] for r in moe_results) / len(moe_results)
660
+
661
+ if d_vocab > m_vocab:
662
+ print(f" Dense has better vocabulary diversity (+{d_vocab - m_vocab:.2f})")
663
+ else:
664
+ print(f" MoE has better vocabulary diversity (+{m_vocab - d_vocab:.2f})")
665
+
666
+ if d_eng > m_eng:
667
+ print(f" Dense produces more English-like text (+{d_eng - m_eng:.2f})")
668
+ else:
669
+ print(f" MoE produces more English-like text (+{m_eng - d_eng:.2f})")
670
+
671
+
672
+ if __name__ == "__main__":
673
+ main()
mythos-fineweb-dense.py ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.
4
+
5
+ Optimized dense variant — MoE replaced with single FFN per recurrent layer.
6
+ Same RDT architecture (2 prelude + 8 recurrent + 2 coda) but all parameters
7
+ active every forward pass. Designed for fast convergence on English.
8
+
9
+ Performance optimizations:
10
+ - Dense FFN instead of MoE (eliminates Python expert loop)
11
+ - torch.compile with reduce-overhead mode
12
+ - F.scaled_dot_product_attention (flash attention auto-selected)
13
+ - Gradient checkpointing on recurrent layers (saves ~40% VRAM)
14
+ - Larger micro_batch (128) with minimal grad_accum (2)
15
+
16
+ Single GPU:
17
+ python mythos-fineweb-dense.py
18
+
19
+ Multi-GPU:
20
+ torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
21
+
22
+ Dense-to-MoE conversion (after training):
23
+ Each recurrent layer's ffn weights are split into 64 chunks to initialize
24
+ MoE experts. Attention layers, norms, and loop infrastructure carry over.
25
+ """
26
+ import os
27
+ import math
28
+ import time
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ import torch.distributed as dist
33
+ from loguru import logger
34
+ from torch.distributed.fsdp import (
35
+ FullyShardedDataParallel as FSDP,
36
+ ShardingStrategy,
37
+ MixedPrecision,
38
+ FullStateDictConfig,
39
+ StateDictType,
40
+ )
41
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
42
+ from torch.utils.data import IterableDataset, DataLoader, get_worker_info
43
+ from contextlib import nullcontext
44
+ from dataclasses import dataclass
45
+ from typing import Optional, Tuple, Dict, List
46
+ from torch.nn import CrossEntropyLoss
47
+ from datasets import load_dataset
48
+ from transformers import AutoTokenizer
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # SpiderPortal Model Architecture (Dense variant)
53
+ # ---------------------------------------------------------------------------
54
+
55
+ @dataclass
56
+ class SpiderPortalConfig:
57
+ vocab_size: int = 50257
58
+ hidden_size: int = 384
59
+ num_hidden_layers: int = 8
60
+ num_attention_heads: int = 8
61
+ num_key_value_heads: int = 2
62
+ intermediate_size: int = 1024
63
+ hidden_act: str = "silu"
64
+ num_experts: int = 64
65
+ num_experts_per_tok: int = 1
66
+ num_shared_experts: int = 1
67
+ router_aux_loss_coef: float = 0.05
68
+ max_loop_iters: int = 1
69
+ act_threshold: float = 0.5
70
+ max_position_embeddings: int = 131072
71
+ rope_theta: float = 10000000.0
72
+ rope_scaling: dict = None
73
+ sliding_window: int = 4096
74
+ attention_dropout: float = 0.0
75
+ rms_norm_eps: float = 1e-6
76
+ initializer_range: float = 0.02
77
+ use_cache: bool = True
78
+ tie_word_embeddings: bool = True
79
+ prelude_layers: int = 2
80
+ coda_layers: int = 2
81
+ lora_rank: int = 32
82
+ loop_embed_dim: int = 48
83
+ vision_hidden_size: int = 384
84
+ audio_hidden_size: int = 512
85
+ vision_num_frames: int = 60
86
+ vision_tokens_per_frame: int = 256
87
+ vision_temporal_tokens: int = 64
88
+ vision_temporal_layers: int = 2
89
+ model_type: str = "spiderportal"
90
+ torch_dtype: str = "bfloat16"
91
+
92
+
93
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
94
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
95
+ angles = loop_t * freqs
96
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
97
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
98
+ emb_full[:loop_dim] = emb
99
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
100
+
101
+
102
+ class SpiderPortalRMSNorm(nn.Module):
103
+ def __init__(self, hidden_size, eps=1e-6):
104
+ super().__init__()
105
+ self.weight = nn.Parameter(torch.ones(hidden_size))
106
+ self.variance_epsilon = eps
107
+ def forward(self, hidden_states):
108
+ input_dtype = hidden_states.dtype
109
+ hidden_states = hidden_states.to(torch.float32)
110
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
111
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
112
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
113
+
114
+
115
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
116
+ dim = head_dim
117
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
118
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
119
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
120
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
121
+ return orig_inv_freq * scale
122
+
123
+
124
+ class SpiderPortalGQA(nn.Module):
125
+ def __init__(self, config):
126
+ super().__init__()
127
+ self.config = config
128
+ self.hidden_size = config.hidden_size
129
+ self.num_heads = config.num_attention_heads
130
+ self.num_kv_heads = config.num_key_value_heads
131
+ self.head_dim = config.hidden_size // config.num_attention_heads
132
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
133
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
134
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
135
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
136
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
137
+ self.attention_dropout = config.attention_dropout
138
+ rope_scaling = getattr(config, 'rope_scaling', None)
139
+ if rope_scaling and rope_scaling.get("type") == "yarn":
140
+ factor = rope_scaling.get("factor", 1.0)
141
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
142
+ inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
143
+ else:
144
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
145
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
146
+ def _rotate_half(self, x):
147
+ x1 = x[..., :x.shape[-1] // 2]
148
+ x2 = x[..., x.shape[-1] // 2:]
149
+ return torch.cat((-x2, x1), dim=-1)
150
+ def _apply_rotary(self, x, cos, sin):
151
+ return (x * cos) + (self._rotate_half(x) * sin)
152
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
153
+ bsz, q_len, _ = hidden_states.size()
154
+ query_states = self.q_proj(hidden_states)
155
+ key_states = self.k_proj(hidden_states)
156
+ value_states = self.v_proj(hidden_states)
157
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
158
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
159
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
160
+ if position_ids is None:
161
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
162
+ max_pos = position_ids.max().item() + 1
163
+ seq_len = max(max_pos, q_len)
164
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
165
+ freqs = torch.outer(t, self.inv_freq)
166
+ emb = torch.cat((freqs, freqs), dim=-1)
167
+ cos, sin = emb.cos(), emb.sin()
168
+ cos = cos[position_ids].unsqueeze(1)
169
+ sin = sin[position_ids].unsqueeze(1)
170
+ query_states = self._apply_rotary(query_states, cos, sin)
171
+ key_states = self._apply_rotary(key_states, cos, sin)
172
+ if past_key_value is not None:
173
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
174
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
175
+ past_kv = (key_states, value_states) if use_cache else None
176
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
177
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
178
+ attn_output = F.scaled_dot_product_attention(
179
+ query_states, key_states, value_states,
180
+ attn_mask=attention_mask,
181
+ dropout_p=self.attention_dropout if self.training else 0.0,
182
+ is_causal=attention_mask is None
183
+ )
184
+ attn_output = attn_output.transpose(1, 2).contiguous()
185
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
186
+ return self.o_proj(attn_output), past_kv
187
+
188
+
189
+ class SpiderPortalExpert(nn.Module):
190
+ def __init__(self, config, intermediate_size=None):
191
+ super().__init__()
192
+ inter_size = intermediate_size or config.intermediate_size
193
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
194
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
195
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
196
+ self.act_fn = nn.SiLU()
197
+ def forward(self, hidden_states):
198
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
199
+
200
+
201
+ class SpiderPortalDenseLayer(nn.Module):
202
+ """Prelude/coda dense layer. intermediate_size=512 (4/3 * hidden_size)."""
203
+ def __init__(self, config):
204
+ super().__init__()
205
+ self.self_attn = SpiderPortalGQA(config)
206
+ dense_intermediate = config.hidden_size * 4 // 3
207
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
208
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
209
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
210
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
211
+ attn_input = self.input_layernorm(hidden_states)
212
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
213
+ hidden_states = hidden_states + attn_output
214
+ ffn_input = self.post_attention_layernorm(hidden_states)
215
+ ffn_output = self.ffn(ffn_input)
216
+ hidden_states = hidden_states + ffn_output
217
+ return hidden_states, past_kv
218
+
219
+
220
+ class SpiderPortalRecurrentDenseLayer(nn.Module):
221
+ """Recurrent layer with DENSE FFN (not MoE). intermediate_size=1024.
222
+
223
+ This replaces SpiderPortalMoELayer. After dense training converges,
224
+ the ffn weights can be split into 64 chunks to initialize MoE experts.
225
+ """
226
+ def __init__(self, config, layer_idx):
227
+ super().__init__()
228
+ self.layer_idx = layer_idx
229
+ self.self_attn = SpiderPortalGQA(config)
230
+ self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
231
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
232
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
233
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
234
+ attn_input = self.input_layernorm(hidden_states)
235
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
236
+ hidden_states = hidden_states + attn_output
237
+ ffn_input = self.post_attention_layernorm(hidden_states)
238
+ ffn_output = self.ffn(ffn_input)
239
+ hidden_states = hidden_states + ffn_output
240
+ return hidden_states, 0.0, past_kv
241
+
242
+
243
+ class LTIInjection(nn.Module):
244
+ def __init__(self, config):
245
+ super().__init__()
246
+ self.hidden_size = config.hidden_size
247
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
248
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
249
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
250
+ with torch.no_grad():
251
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
252
+ def get_A(self):
253
+ return -torch.exp(self.log_A)
254
+ def forward(self, h_t, e):
255
+ A = self.get_A()
256
+ return A * h_t + self.B(e)
257
+
258
+
259
+ class ACTHalting(nn.Module):
260
+ def __init__(self, config):
261
+ super().__init__()
262
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
+ self.threshold = config.act_threshold
264
+ def forward(self, hidden_states):
265
+ return torch.sigmoid(self.halt_predictor(hidden_states))
266
+
267
+
268
+ class LoRAAdapter(nn.Module):
269
+ def __init__(self, config):
270
+ super().__init__()
271
+ rank = config.lora_rank
272
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
273
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
274
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
275
+ with torch.no_grad():
276
+ self.scale.weight.data.zero_()
277
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
278
+ def forward(self, x, loop_t):
279
+ max_t = self.scale.num_embeddings - 1
280
+ t_idx = min(loop_t, max_t)
281
+ s = self.scale(torch.tensor(t_idx, device=x.device))
282
+ down = self.down(x) * s
283
+ return down @ self.B
284
+
285
+
286
+ def checkpoint(func, *args, **kwargs):
287
+ """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost."""
288
+ if torch.is_grad_enabled():
289
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
290
+ return func(*args, **kwargs)
291
+
292
+
293
+ class SpiderPortalDenseModel(nn.Module):
294
+ """Full RDT model with DENSE recurrent layers (no MoE).
295
+
296
+ Architecture:
297
+ 2x Prelude (dense, intermediate=512)
298
+ 8x Recurrent (dense FFN, intermediate=1024) — with gradient checkpointing
299
+ 2x Coda (dense, intermediate=512)
300
+ LTI Injection + ACT Halting + LoRA Adapter
301
+ """
302
+ def __init__(self, config):
303
+ super().__init__()
304
+ self.config = config
305
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
306
+ self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
307
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
308
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
309
+ self.injection = LTIInjection(config)
310
+ self.act_halting = ACTHalting(config)
311
+ self.lora_adapter = LoRAAdapter(config)
312
+ self.loop_embed_dim = config.loop_embed_dim
313
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
314
+ n_loops = n_loops or self.config.max_loop_iters
315
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
316
+ for layer in self.prelude_layers:
317
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
318
+ e = hidden_states.clone()
319
+ B, T_seq, D = hidden_states.shape
320
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
321
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
322
+ h_out = torch.zeros_like(hidden_states)
323
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
324
+ for t in range(n_loops):
325
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
326
+ if t > 0:
327
+ injection = self.injection(hidden_states, input_embedding)
328
+ hidden_states = hidden_states + injection
329
+ new_past_key_values = []
330
+ for i, layer in enumerate(self.recurrent_layers):
331
+ hidden_states, aux_loss, past_kv = checkpoint(
332
+ layer, hidden_states,
333
+ attention_mask=attention_mask,
334
+ position_ids=position_ids,
335
+ past_key_value=past_key_values[i] if t == 0 else None,
336
+ use_cache=use_cache
337
+ )
338
+ new_past_key_values.append(past_kv)
339
+ lora_delta = self.lora_adapter(hidden_states, t)
340
+ hidden_states = hidden_states + lora_delta
341
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
342
+ still_running = ~halted
343
+ remainder = (1.0 - cumulative_p).clamp(min=0)
344
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
345
+ weight = weight * still_running.to(hidden_states.dtype)
346
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
347
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
348
+ halted = halted | (cumulative_p >= self.config.act_threshold)
349
+ if halted.all() and not self.training:
350
+ break
351
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
352
+ hidden_states = h_out + never_halted * hidden_states
353
+ for layer in self.coda_layers:
354
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
355
+ hidden_states = self.norm(hidden_states)
356
+ return hidden_states, 0.0, new_past_key_values
357
+
358
+
359
+ class SpiderPortalForConditionalGeneration(nn.Module):
360
+ def __init__(self, config):
361
+ super().__init__()
362
+ self.config = config
363
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
364
+ self.model = SpiderPortalDenseModel(config)
365
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
366
+ if config.tie_word_embeddings:
367
+ self.lm_head.weight = self.embed_tokens.weight
368
+ self.apply(self._init_weights)
369
+ def _init_weights(self, module):
370
+ if isinstance(module, nn.Linear):
371
+ if hasattr(self, 'model') and module is self.model.injection.B:
372
+ return
373
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
374
+ if module.bias is not None:
375
+ module.bias.data.zero_()
376
+ elif isinstance(module, nn.Embedding):
377
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
378
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
379
+ hidden_states = self.embed_tokens(input_ids)
380
+ model_dtype = next(self.model.parameters()).dtype
381
+ hidden_states = hidden_states.to(model_dtype)
382
+ input_embedding = hidden_states.clone()
383
+ if attention_mask is None:
384
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
385
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
386
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
387
+ causal_mask = causal_mask.triu(1)
388
+ hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
389
+ logits = self.lm_head(hidden_states)
390
+ loss = None
391
+ if labels is not None:
392
+ shift_logits = logits[..., :-1, :].contiguous()
393
+ shift_labels = labels[..., 1:].contiguous()
394
+ loss_fct = CrossEntropyLoss()
395
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
396
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
397
+ def get_num_params(self):
398
+ total = sum(p.numel() for p in self.parameters())
399
+ return {"total": total, "trainable": total}
400
+
401
+
402
+ # ---------------------------------------------------------------------------
403
+ # Dataset
404
+ # ---------------------------------------------------------------------------
405
+
406
+ class FineWebEduDataset(IterableDataset):
407
+ def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int):
408
+ self.tokenizer = tokenizer
409
+ self.seq_len = seq_len
410
+ self.subset = subset
411
+ self.rank = rank
412
+ self.world_size = world_size
413
+ def __iter__(self):
414
+ worker = get_worker_info()
415
+ num_workers = worker.num_workers if worker else 1
416
+ worker_id = worker.id if worker else 0
417
+ total_shards = self.world_size * num_workers
418
+ shard_index = self.rank * num_workers + worker_id
419
+ ds = load_dataset(
420
+ "HuggingFaceFW/fineweb-edu",
421
+ name=self.subset,
422
+ split="train",
423
+ streaming=True,
424
+ ).shard(num_shards=total_shards, index=shard_index)
425
+ buf = []
426
+ for sample in ds:
427
+ buf.extend(self.tokenizer.encode(sample["text"]))
428
+ while len(buf) >= self.seq_len + 1:
429
+ chunk = buf[: self.seq_len + 1]
430
+ buf = buf[self.seq_len + 1 :]
431
+ yield (
432
+ torch.tensor(chunk[:-1], dtype=torch.long),
433
+ torch.tensor(chunk[1:], dtype=torch.long),
434
+ )
435
+
436
+
437
+ # ---------------------------------------------------------------------------
438
+ # LR schedule: linear warmup → cosine decay
439
+ # ---------------------------------------------------------------------------
440
+
441
+ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
442
+ if step < warmup:
443
+ return max_lr * step / warmup
444
+ if step >= total:
445
+ return min_lr
446
+ decay = (step - warmup) / (total - warmup)
447
+ return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
448
+
449
+
450
+ # ---------------------------------------------------------------------------
451
+ # Checkpointing
452
+ # ---------------------------------------------------------------------------
453
+
454
+ def save_weights_only(model, step, epoch, ckpt_dir, ddp):
455
+ if ddp:
456
+ with FSDP.state_dict_type(
457
+ model,
458
+ StateDictType.FULL_STATE_DICT,
459
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
460
+ ):
461
+ model_state = model.state_dict()
462
+ else:
463
+ model_state = model.state_dict()
464
+ ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt")
465
+ tmp_path = ckpt_path + ".tmp"
466
+ torch.save(model_state, tmp_path)
467
+ os.replace(tmp_path, ckpt_path)
468
+ size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
469
+ return ckpt_path, size_mb
470
+
471
+
472
+ def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"):
473
+ if ddp:
474
+ with FSDP.state_dict_type(
475
+ model,
476
+ StateDictType.FULL_STATE_DICT,
477
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
478
+ ):
479
+ model_state = model.state_dict()
480
+ optim_state = FSDP.optim_state_dict(model, optimizer)
481
+ else:
482
+ model_state = model.state_dict()
483
+ optim_state = optimizer.state_dict()
484
+ if not master:
485
+ return None, 0
486
+ os.makedirs(ckpt_dir, exist_ok=True)
487
+ final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt")
488
+ tmp_path = final_path + ".tmp"
489
+ torch.save(
490
+ {
491
+ "step": step,
492
+ "epoch": epoch,
493
+ "model_state_dict": model_state,
494
+ "optimizer_state_dict": optim_state,
495
+ "cfg": cfg,
496
+ "vocab_size": vocab_size,
497
+ },
498
+ tmp_path,
499
+ )
500
+ os.replace(tmp_path, final_path)
501
+ size_mb = os.path.getsize(final_path) / (1024 * 1024)
502
+ return final_path, size_mb
503
+
504
+
505
+ def load_checkpoint(model, optimizer, path, ddp):
506
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
507
+ if ddp:
508
+ with FSDP.state_dict_type(
509
+ model,
510
+ StateDictType.FULL_STATE_DICT,
511
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
512
+ ):
513
+ model.load_state_dict(ckpt["model_state_dict"])
514
+ optim_state = FSDP.optim_state_dict_to_load(
515
+ model=model,
516
+ optim=optimizer,
517
+ optim_state_dict=ckpt["optimizer_state_dict"],
518
+ )
519
+ optimizer.load_state_dict(optim_state)
520
+ else:
521
+ model.load_state_dict(ckpt["model_state_dict"])
522
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
523
+ return int(ckpt["step"]), int(ckpt.get("epoch", 0))
524
+
525
+
526
+ # ---------------------------------------------------------------------------
527
+ # Main
528
+ # ---------------------------------------------------------------------------
529
+
530
+ def main():
531
+ # ------------------------------------------------------------------
532
+ # Distributed init
533
+ # ------------------------------------------------------------------
534
+ ddp = int(os.environ.get("RANK", -1)) != -1
535
+ if ddp:
536
+ dist.init_process_group("nccl")
537
+ rank = int(os.environ["RANK"])
538
+ local_rank = int(os.environ["LOCAL_RANK"])
539
+ world_size = int(os.environ["WORLD_SIZE"])
540
+ device = f"cuda:{local_rank}"
541
+ torch.cuda.set_device(device)
542
+ else:
543
+ rank = local_rank = 0
544
+ world_size = 1
545
+ device = "cuda" if torch.cuda.is_available() else "cpu"
546
+ master = rank == 0
547
+ if master:
548
+ logger.info(
549
+ f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
550
+ )
551
+
552
+ # ------------------------------------------------------------------
553
+ # Tokenizer
554
+ # ------------------------------------------------------------------
555
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
556
+ tokenizer.pad_token = tokenizer.eos_token
557
+ vocab_size = tokenizer.vocab_size
558
+ if master:
559
+ logger.info(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}")
560
+
561
+ # ------------------------------------------------------------------
562
+ # Hyperparameters — OPTIMIZED for speed
563
+ # ------------------------------------------------------------------
564
+ seq_len = 2048
565
+ micro_batch = 128 # Increased from 32 — RTX 6000 has 96GB VRAM
566
+ target_tokens = 20_000_000_000 # 20B tokens (2 epochs on 10BT)
567
+ grad_accum = 2 # Reduced from 4 — fewer backward passes
568
+ global_batch_tok = world_size * micro_batch * grad_accum * seq_len
569
+ total_steps = target_tokens // global_batch_tok
570
+ warmup_steps = 200
571
+ lr = 3e-4
572
+ wd = 0.1
573
+ log_every = 10
574
+ ckpt_every = 500
575
+ ckpt_dir = "checkpoints-dense"
576
+ dataset_subset = "sample-10BT"
577
+
578
+ if master:
579
+ logger.info(
580
+ f"[DENSE OPTIMIZED] seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
581
+ f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
582
+ )
583
+ logger.info(
584
+ f"Gradient checkpointing: enabled | torch.compile: enabled | "
585
+ f"SDPA attention: enabled"
586
+ )
587
+
588
+ # ------------------------------------------------------------------
589
+ # Model — Dense variant
590
+ # ------------------------------------------------------------------
591
+ cfg = SpiderPortalConfig(
592
+ hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
593
+ num_key_value_heads=2, intermediate_size=1024,
594
+ num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
595
+ router_aux_loss_coef=0.05, max_loop_iters=1,
596
+ prelude_layers=2, coda_layers=2, lora_rank=32,
597
+ rope_theta=10000000.0,
598
+ rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
599
+ max_position_embeddings=131072, sliding_window=4096,
600
+ tie_word_embeddings=True,
601
+ )
602
+ cfg.vocab_size = vocab_size
603
+
604
+ bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
605
+ amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
606
+
607
+ model = SpiderPortalForConditionalGeneration(cfg)
608
+
609
+ if ddp:
610
+ mp_policy = MixedPrecision(
611
+ param_dtype=amp_dtype,
612
+ reduce_dtype=amp_dtype,
613
+ buffer_dtype=amp_dtype,
614
+ )
615
+ wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer})
616
+ model = FSDP(
617
+ model,
618
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
619
+ mixed_precision=mp_policy,
620
+ auto_wrap_policy=wrap_policy,
621
+ device_id=local_rank,
622
+ )
623
+ else:
624
+ model = model.to(device)
625
+ amp_ctx = (
626
+ torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
627
+ if "cuda" in device
628
+ else nullcontext()
629
+ )
630
+
631
+ amp_ctx = nullcontext() if ddp else amp_ctx
632
+
633
+ if master:
634
+ n_params = sum(p.numel() for p in model.parameters())
635
+ n_active = n_params # Dense = all params active
636
+ logger.info(f"Parameters: {n_params:,} (all active) | AMP dtype: {amp_dtype}")
637
+
638
+ # Compile — ACTUAL torch.compile this time
639
+ try:
640
+ model = torch.compile(model, mode="reduce-overhead")
641
+ if master:
642
+ logger.info("torch.compile: enabled (reduce-overhead)")
643
+ except Exception as e:
644
+ if master:
645
+ logger.warning(f"torch.compile failed ({e}), using eager mode")
646
+
647
+ # ------------------------------------------------------------------
648
+ # Optimizer
649
+ # ------------------------------------------------------------------
650
+ optimizer = torch.optim.AdamW(
651
+ model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
652
+ )
653
+
654
+ # ------------------------------------------------------------------
655
+ # Resume from latest checkpoint (if any)
656
+ # ------------------------------------------------------------------
657
+ start_step = 0
658
+ start_epoch = 1
659
+ best_loss = float("inf")
660
+ existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else []
661
+ if existing_ckpts:
662
+ latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1])
663
+ if master:
664
+ logger.info(f"Resuming from checkpoint: {latest}")
665
+ start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp)
666
+ if master:
667
+ logger.success(f"Resumed at step {start_step}, epoch {start_epoch}")
668
+
669
+ # ------------------------------------------------------------------
670
+ # Dataset + DataLoader
671
+ # ------------------------------------------------------------------
672
+ dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size)
673
+ loader = DataLoader(dataset, batch_size=micro_batch, num_workers=8, pin_memory=True, prefetch_factor=2)
674
+
675
+ # ------------------------------------------------------------------
676
+ # Training loop
677
+ # ------------------------------------------------------------------
678
+ if master:
679
+ os.makedirs(ckpt_dir, exist_ok=True)
680
+
681
+ model.train()
682
+ data_iter = iter(loader)
683
+ t0 = time.perf_counter()
684
+ step = start_step
685
+ epoch = start_epoch
686
+ step_ckpt_files = []
687
+ tokens_in_epoch = 0
688
+ tokens_per_epoch = target_tokens
689
+
690
+ while step < total_steps:
691
+ cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
692
+ for g in optimizer.param_groups:
693
+ g["lr"] = cur_lr
694
+
695
+ optimizer.zero_grad()
696
+ loss_accum = 0.0
697
+
698
+ for micro_step in range(grad_accum):
699
+ try:
700
+ x, y = next(data_iter)
701
+ except StopIteration:
702
+ data_iter = iter(loader)
703
+ x, y = next(data_iter)
704
+
705
+ x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
706
+ y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
707
+
708
+ sync = (
709
+ nullcontext()
710
+ if (not ddp or micro_step == grad_accum - 1)
711
+ else model.no_sync()
712
+ )
713
+ with sync, amp_ctx:
714
+ output = model(x)
715
+ if isinstance(output, dict):
716
+ logits = output["logits"]
717
+ else:
718
+ logits = output
719
+ loss = nn.functional.cross_entropy(
720
+ logits.view(-1, vocab_size), y.view(-1)
721
+ )
722
+ loss = loss / grad_accum
723
+
724
+ loss.backward()
725
+ loss_accum += loss.item()
726
+
727
+ if ddp:
728
+ grad_norm = model.clip_grad_norm_(1.0)
729
+ else:
730
+ grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
731
+ optimizer.step()
732
+ step += 1
733
+ tokens_in_epoch += global_batch_tok
734
+
735
+ if master and step % log_every == 0:
736
+ dt = time.perf_counter() - t0
737
+ tok_per_sec = global_batch_tok * log_every / dt
738
+ tokens_seen = step * global_batch_tok
739
+ logger.info(
740
+ f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
741
+ f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
742
+ f"| {tok_per_sec / 1e6:.2f}M tok/s "
743
+ f"| {tokens_seen / 1e9:.2f}B tokens seen"
744
+ )
745
+ t0 = time.perf_counter()
746
+
747
+ if step % ckpt_every == 0 and master:
748
+ ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp)
749
+ step_ckpt_files.append(ckpt_path)
750
+ logger.info(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
751
+
752
+ if tokens_in_epoch >= tokens_per_epoch:
753
+ epoch_loss = loss_accum
754
+ if master:
755
+ epoch_time = (time.perf_counter() - t0) / 60
756
+ logger.info(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min")
757
+
758
+ for f in step_ckpt_files:
759
+ if os.path.exists(f):
760
+ os.remove(f)
761
+ logger.info(f" Deleted step checkpoint: {os.path.basename(f)}")
762
+ step_ckpt_files.clear()
763
+
764
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}")
765
+ if ckpt_path:
766
+ logger.info(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
767
+
768
+ if epoch_loss < best_loss:
769
+ best_loss = epoch_loss
770
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best")
771
+ if ckpt_path:
772
+ logger.info(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
773
+
774
+ epoch += 1
775
+ tokens_in_epoch = 0
776
+
777
+ if step > start_step and master:
778
+ ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}")
779
+ if ckpt_path:
780
+ logger.info(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)")
781
+
782
+ if ddp:
783
+ dist.barrier()
784
+ dist.destroy_process_group()
785
+
786
+ if master:
787
+ logger.success("Training complete.")
788
+
789
+
790
+ if __name__ == "__main__":
791
+ main()