CLIWorks commited on
Commit
e9af599
·
verified ·
1 Parent(s): d1baae9

Upload train_single_gpu.py

Browse files
Files changed (1) hide show
  1. train_single_gpu.py +581 -0
train_single_gpu.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """SpiderPortal v5 — Single-GPU Optimized Training.
3
+
4
+ For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
5
+ No DDP, maximal batch size, torch.compile, pre-tokenized data.
6
+
7
+ Usage:
8
+ python train_single_gpu.py
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ import os
16
+ import json
17
+ import gc
18
+ import random
19
+ import time
20
+ from pathlib import Path
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple, Dict, List
23
+ from torch.nn import CrossEntropyLoss
24
+
25
+ @dataclass
26
+ class SpiderPortalConfig:
27
+ vocab_size: int = 50278
28
+ hidden_size: int = 384
29
+ num_hidden_layers: int = 8
30
+ num_attention_heads: int = 8
31
+ num_key_value_heads: int = 2
32
+ intermediate_size: int = 1024
33
+ hidden_act: str = "silu"
34
+ num_experts: int = 64
35
+ num_experts_per_tok: int = 1
36
+ num_shared_experts: int = 1
37
+ router_aux_loss_coef: float = 0.05
38
+ max_loop_iters: int = 4
39
+ act_threshold: float = 0.5
40
+ max_position_embeddings: int = 131072
41
+ rope_theta: float = 10000000.0
42
+ rope_scaling: dict = None
43
+ sliding_window: int = 4096
44
+ attention_dropout: float = 0.0
45
+ rms_norm_eps: float = 1e-6
46
+ initializer_range: float = 0.02
47
+ use_cache: bool = True
48
+ tie_word_embeddings: bool = True
49
+ prelude_layers: int = 2
50
+ coda_layers: int = 2
51
+ lora_rank: int = 32
52
+ loop_embed_dim: int = 48
53
+ vision_hidden_size: int = 384
54
+ audio_hidden_size: int = 512
55
+ vision_num_frames: int = 60
56
+ vision_tokens_per_frame: int = 256
57
+ vision_temporal_tokens: int = 64
58
+ vision_temporal_layers: int = 2
59
+ model_type: str = "spiderportal"
60
+ torch_dtype: str = "bfloat16"
61
+
62
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
63
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
64
+ angles = loop_t * freqs
65
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
66
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
67
+ emb_full[:loop_dim] = emb
68
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
69
+
70
+ class SpiderPortalRMSNorm(nn.Module):
71
+ def __init__(self, hidden_size, eps=1e-6):
72
+ super().__init__()
73
+ self.weight = nn.Parameter(torch.ones(hidden_size))
74
+ self.variance_epsilon = eps
75
+ def forward(self, hidden_states):
76
+ input_dtype = hidden_states.dtype
77
+ hidden_states = hidden_states.to(torch.float32)
78
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
81
+
82
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
83
+ dim = head_dim
84
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
85
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
86
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
87
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
88
+ return orig_inv_freq * scale
89
+
90
+ class SpiderPortalGQA(nn.Module):
91
+ def __init__(self, config):
92
+ super().__init__()
93
+ self.config = config
94
+ self.hidden_size = config.hidden_size
95
+ self.num_heads = config.num_attention_heads
96
+ self.num_kv_heads = config.num_key_value_heads
97
+ self.head_dim = config.hidden_size // config.num_attention_heads
98
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
99
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
100
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
101
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
102
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
+ self.attention_dropout = config.attention_dropout
104
+ rope_scaling = getattr(config, 'rope_scaling', None)
105
+ if rope_scaling and rope_scaling.get("type") == "yarn":
106
+ factor = rope_scaling.get("factor", 1.0)
107
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
108
+ inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
109
+ else:
110
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
111
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
112
+ def _rotate_half(self, x):
113
+ x1 = x[..., :x.shape[-1] // 2]
114
+ x2 = x[..., x.shape[-1] // 2:]
115
+ return torch.cat((-x2, x1), dim=-1)
116
+ def _apply_rotary(self, x, cos, sin):
117
+ return (x * cos) + (self._rotate_half(x) * sin)
118
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
119
+ bsz, q_len, _ = hidden_states.size()
120
+ query_states = self.q_proj(hidden_states)
121
+ key_states = self.k_proj(hidden_states)
122
+ value_states = self.v_proj(hidden_states)
123
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
124
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
125
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
126
+ if position_ids is None:
127
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
128
+ max_pos = position_ids.max().item() + 1
129
+ seq_len = max(max_pos, q_len)
130
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ emb = torch.cat((freqs, freqs), dim=-1)
133
+ cos, sin = emb.cos(), emb.sin()
134
+ cos = cos[position_ids].unsqueeze(1)
135
+ sin = sin[position_ids].unsqueeze(1)
136
+ query_states = self._apply_rotary(query_states, cos, sin)
137
+ key_states = self._apply_rotary(key_states, cos, sin)
138
+ if past_key_value is not None:
139
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
140
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
141
+ past_kv = (key_states, value_states) if use_cache else None
142
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
143
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
144
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
145
+ if attention_mask is not None:
146
+ attn_weights = attn_weights + attention_mask
147
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
148
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
149
+ attn_output = torch.matmul(attn_weights, value_states)
150
+ attn_output = attn_output.transpose(1, 2).contiguous()
151
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
152
+ return self.o_proj(attn_output), past_kv
153
+
154
+ class SpiderPortalExpert(nn.Module):
155
+ def __init__(self, config, intermediate_size=None):
156
+ super().__init__()
157
+ inter_size = intermediate_size or config.intermediate_size
158
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
159
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
160
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
161
+ self.act_fn = nn.SiLU()
162
+ def forward(self, hidden_states):
163
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
164
+
165
+ class SpiderPortalRouter(nn.Module):
166
+ def __init__(self, config):
167
+ super().__init__()
168
+ self.num_experts = config.num_experts
169
+ self.num_experts_per_tok = config.num_experts_per_tok
170
+ self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
171
+ self.register_buffer("router_bias", torch.zeros(config.num_experts))
172
+ def forward(self, hidden_states):
173
+ router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
174
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
175
+ biased_logits = router_logits + self.router_bias
176
+ biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
177
+ top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
178
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
179
+ top_weights = top_weights.to(hidden_states.dtype)
180
+ mean_probs = routing_weights.mean(dim=0)
181
+ aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
182
+ return top_weights, top_indices, aux_loss
183
+
184
+ class SpiderPortalMoE(nn.Module):
185
+ def __init__(self, config):
186
+ super().__init__()
187
+ self.config = config
188
+ self.num_experts = config.num_experts
189
+ self.num_experts_per_tok = config.num_experts_per_tok
190
+ self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
191
+ self.shared_expert = SpiderPortalExpert(config)
192
+ self.router = SpiderPortalRouter(config)
193
+ def forward(self, hidden_states):
194
+ batch_size, seq_len, hidden_dim = hidden_states.shape
195
+ top_weights, top_indices, aux_loss = self.router(hidden_states)
196
+ flat_hidden = hidden_states.view(-1, hidden_dim)
197
+ final_output = torch.zeros_like(flat_hidden)
198
+ for expert_idx in range(self.num_experts_per_tok):
199
+ expert_ids = top_indices[:, expert_idx]
200
+ expert_weights = top_weights[:, expert_idx:expert_idx+1]
201
+ for e in range(self.num_experts):
202
+ mask = expert_ids == e
203
+ if mask.any():
204
+ expert_output = self.experts[e](flat_hidden[mask])
205
+ final_output[mask] += expert_output * expert_weights[mask]
206
+ shared_output = self.shared_expert(flat_hidden)
207
+ final_output = final_output + shared_output
208
+ return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
209
+
210
+ class SpiderPortalDenseLayer(nn.Module):
211
+ def __init__(self, config):
212
+ super().__init__()
213
+ self.self_attn = SpiderPortalGQA(config)
214
+ dense_intermediate = config.hidden_size * 4 // 3
215
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
216
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
217
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
218
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
219
+ attn_input = self.input_layernorm(hidden_states)
220
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
221
+ hidden_states = hidden_states + attn_output
222
+ ffn_input = self.post_attention_layernorm(hidden_states)
223
+ ffn_output = self.ffn(ffn_input)
224
+ hidden_states = hidden_states + ffn_output
225
+ return hidden_states, past_kv
226
+
227
+ class SpiderPortalMoELayer(nn.Module):
228
+ def __init__(self, config, layer_idx):
229
+ super().__init__()
230
+ self.layer_idx = layer_idx
231
+ self.self_attn = SpiderPortalGQA(config)
232
+ self.moe = SpiderPortalMoE(config)
233
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
234
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
235
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
236
+ attn_input = self.input_layernorm(hidden_states)
237
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
238
+ hidden_states = hidden_states + attn_output
239
+ moe_input = self.post_attention_layernorm(hidden_states)
240
+ moe_output, aux_loss = self.moe(moe_input)
241
+ hidden_states = hidden_states + moe_output
242
+ return hidden_states, aux_loss, past_kv
243
+
244
+ class LTIInjection(nn.Module):
245
+ def __init__(self, config):
246
+ super().__init__()
247
+ self.hidden_size = config.hidden_size
248
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
249
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
250
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
251
+ with torch.no_grad():
252
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
253
+ def get_A(self):
254
+ return -torch.exp(self.log_A)
255
+ def forward(self, h_t, e):
256
+ A = self.get_A()
257
+ return A * h_t + self.B(e)
258
+
259
+ class ACTHalting(nn.Module):
260
+ def __init__(self, config):
261
+ super().__init__()
262
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
+ self.threshold = config.act_threshold
264
+ def forward(self, hidden_states):
265
+ return torch.sigmoid(self.halt_predictor(hidden_states))
266
+
267
+ class LoRAAdapter(nn.Module):
268
+ def __init__(self, config):
269
+ super().__init__()
270
+ rank = config.lora_rank
271
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
272
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
273
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
274
+ with torch.no_grad():
275
+ self.scale.weight.data.zero_()
276
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
277
+ def forward(self, x, loop_t):
278
+ max_t = self.scale.num_embeddings - 1
279
+ t_idx = min(loop_t, max_t)
280
+ s = self.scale(torch.tensor(t_idx, device=x.device))
281
+ down = self.down(x) * s
282
+ return down @ self.B
283
+
284
+ class SpiderPortalMoEModel(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.config = config
288
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
289
+ self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
290
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
291
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
292
+ self.injection = LTIInjection(config)
293
+ self.act_halting = ACTHalting(config)
294
+ self.lora_adapter = LoRAAdapter(config)
295
+ self.loop_embed_dim = config.loop_embed_dim
296
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
297
+ n_loops = n_loops or self.config.max_loop_iters
298
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
299
+ total_aux_loss = 0.0
300
+ for layer in self.prelude_layers:
301
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
302
+ e = hidden_states.clone()
303
+ B, T_seq, D = hidden_states.shape
304
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
305
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
306
+ h_out = torch.zeros_like(hidden_states)
307
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
308
+ for t in range(n_loops):
309
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
310
+ if t > 0:
311
+ injection = self.injection(hidden_states, input_embedding)
312
+ hidden_states = hidden_states + injection
313
+ new_past_key_values = []
314
+ for i, layer in enumerate(self.recurrent_layers):
315
+ hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
316
+ total_aux_loss = total_aux_loss + aux_loss
317
+ new_past_key_values.append(past_kv)
318
+ lora_delta = self.lora_adapter(hidden_states, t)
319
+ hidden_states = hidden_states + lora_delta
320
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
321
+ still_running = ~halted
322
+ remainder = (1.0 - cumulative_p).clamp(min=0)
323
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
324
+ weight = weight * still_running.to(hidden_states.dtype)
325
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
326
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
327
+ halted = halted | (cumulative_p >= self.config.act_threshold)
328
+ if halted.all() and not self.training:
329
+ break
330
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
331
+ hidden_states = h_out + never_halted * hidden_states
332
+ for layer in self.coda_layers:
333
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
334
+ hidden_states = self.norm(hidden_states)
335
+ return hidden_states, total_aux_loss, new_past_key_values
336
+
337
+ class SpiderPortalForConditionalGeneration(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.config = config
341
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
342
+ self.model = SpiderPortalMoEModel(config)
343
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
344
+ if config.tie_word_embeddings:
345
+ self.lm_head.weight = self.embed_tokens.weight
346
+ self.apply(self._init_weights)
347
+ def _init_weights(self, module):
348
+ if isinstance(module, nn.Linear):
349
+ if hasattr(self, 'model') and module is self.model.injection.B:
350
+ return
351
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
352
+ if module.bias is not None:
353
+ module.bias.data.zero_()
354
+ elif isinstance(module, nn.Embedding):
355
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
356
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
357
+ hidden_states = self.embed_tokens(input_ids)
358
+ model_dtype = next(self.model.parameters()).dtype
359
+ hidden_states = hidden_states.to(model_dtype)
360
+ input_embedding = hidden_states.clone()
361
+ if attention_mask is None:
362
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
363
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
364
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
365
+ causal_mask = causal_mask.triu(1)
366
+ hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
367
+ logits = self.lm_head(hidden_states)
368
+ loss = None
369
+ if labels is not None:
370
+ shift_logits = logits[..., :-1, :].contiguous()
371
+ shift_labels = labels[..., 1:].contiguous()
372
+ loss_fct = CrossEntropyLoss()
373
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
374
+ loss = loss + self.config.router_aux_loss_coef * aux_loss
375
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
376
+ def get_num_params(self):
377
+ total = sum(p.numel() for p in self.parameters())
378
+ return {"total": total, "trainable": total}
379
+
380
+ def train_single_gpu():
381
+ device = torch.device("cuda")
382
+ gpu_name = torch.cuda.get_device_name(0)
383
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
384
+ print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
385
+
386
+ config = SpiderPortalConfig(
387
+ hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
388
+ num_key_value_heads=2, intermediate_size=1024,
389
+ num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
390
+ router_aux_loss_coef=0.05, max_loop_iters=2,
391
+ prelude_layers=2, coda_layers=2, lora_rank=32,
392
+ rope_theta=10000000.0,
393
+ rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
394
+ max_position_embeddings=131072, sliding_window=4096,
395
+ tie_word_embeddings=True,
396
+ )
397
+
398
+ print("Building model...")
399
+ model = SpiderPortalForConditionalGeneration(config)
400
+ model = model.to(torch.bfloat16).to(device)
401
+
402
+ params = model.get_num_params()
403
+ print(f"Model: {params['total']/1e6:.1f}M params")
404
+ print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
405
+
406
+ BASE_LR = 1e-3
407
+ WARMUP_STEPS = 500
408
+ optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01, betas=(0.9, 0.95))
409
+
410
+ import pandas as pd
411
+ data_dir = Path(__file__).parent / "data"
412
+ all_records = []
413
+ pkl_file = data_dir / "spiderportal_combined.pkl"
414
+ if pkl_file.exists():
415
+ print(f"Loading dataset from {pkl_file}...")
416
+ df = pd.read_pickle(pkl_file)
417
+ all_records = df.to_dict("records")
418
+ else:
419
+ print(f"No dataset found at {pkl_file}, creating synthetic data...")
420
+ all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
421
+
422
+ print(f"Loaded {len(all_records):,} samples")
423
+
424
+ from transformers import AutoTokenizer
425
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
426
+ tokenizer.pad_token = tokenizer.eos_token
427
+
428
+ BATCH_SIZE = 128
429
+ MAX_LEN = 256
430
+ EPOCHS = 3
431
+ N_LOOPS = 2
432
+
433
+ print(f"Batch size: {BATCH_SIZE} (no grad accum)")
434
+ print(f"Effective batch: {BATCH_SIZE}")
435
+ print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup (high LR for recurrent MoE)")
436
+ print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")
437
+
438
+ def build_prompt(sample):
439
+ instruction = str(sample.get("instruction", "")).strip()
440
+ inp = str(sample.get("input", "")).strip()
441
+ output = str(sample.get("output", "")).strip()
442
+ if inp:
443
+ return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
444
+ return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
445
+
446
+ print("Pre-tokenizing dataset...")
447
+ prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
448
+ mask_len = len(prefix_ids)
449
+
450
+ pre_tokenized = []
451
+ for i, sample in enumerate(all_records):
452
+ instruction = str(sample.get("instruction", "")).strip()
453
+ inp = str(sample.get("input", "")).strip()
454
+ output = str(sample.get("output", "")).strip()
455
+ if inp:
456
+ text = f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n" + tokenizer.eos_token
457
+ else:
458
+ text = f"Question: Instruction: {instruction}\nAnswer: {output}\n" + tokenizer.eos_token
459
+ enc = tokenizer(text, truncation=True, max_length=MAX_LEN, padding="max_length")
460
+ input_ids = enc["input_ids"]
461
+ labels = input_ids[:]
462
+ for j in range(min(mask_len, len(labels))):
463
+ labels[j] = -100
464
+ pre_tokenized.append((input_ids, labels))
465
+ if (i + 1) % 50000 == 0:
466
+ print(f" Tokenized {i+1:,}/{len(all_records):,}")
467
+
468
+ print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
469
+ del all_records
470
+ gc.collect()
471
+
472
+ global_step = 0
473
+ best_loss = float('inf')
474
+ start_time = time.time()
475
+ checkpoint_dir = Path("checkpoints")
476
+ checkpoint_dir.mkdir(exist_ok=True)
477
+ step_ckpt_files = []
478
+
479
+ for epoch in range(1, EPOCHS + 1):
480
+ if epoch > 1:
481
+ for f in step_ckpt_files:
482
+ if f.exists():
483
+ f.unlink()
484
+ print(f" Deleted old step checkpoint: {f.name}")
485
+ step_ckpt_files.clear()
486
+ gc.collect()
487
+
488
+ indices = list(range(len(pre_tokenized)))
489
+ random.shuffle(indices)
490
+ total_loss = 0
491
+ num_batches = 0
492
+ optimizer.zero_grad()
493
+
494
+ for batch_start in range(0, len(indices), BATCH_SIZE):
495
+ batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
496
+ if len(batch_indices) < BATCH_SIZE:
497
+ continue
498
+
499
+ if global_step < WARMUP_STEPS:
500
+ lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
501
+ for param_group in optimizer.param_groups:
502
+ param_group['lr'] = lr
503
+
504
+ batch_input_ids = []
505
+ batch_labels = []
506
+ for idx in batch_indices:
507
+ input_ids, labels = pre_tokenized[idx]
508
+ batch_input_ids.append(input_ids)
509
+ batch_labels.append(labels)
510
+
511
+ input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=device)
512
+ labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
513
+
514
+ if global_step == 0:
515
+ print(" [First forward pass - compiling...]")
516
+
517
+ outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
518
+ loss = outputs["loss"]
519
+ loss.backward()
520
+
521
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
522
+ optimizer.step()
523
+ optimizer.zero_grad()
524
+ global_step += 1
525
+
526
+ total_loss += loss.item()
527
+ num_batches += 1
528
+
529
+ if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
530
+ avg_loss = total_loss / max(num_batches, 1)
531
+ elapsed = time.time() - start_time
532
+ steps_per_hour = (global_step + 1) / elapsed * 3600 if elapsed > 0 else 0
533
+ current_lr = optimizer.param_groups[0]['lr']
534
+ samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
535
+ print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")
536
+
537
+ if global_step > 0 and global_step % 500 == 0:
538
+ ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
539
+ state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
540
+ torch.save(state_dict, ckpt_path)
541
+ step_ckpt_files.append(ckpt_path)
542
+ size_mb = ckpt_path.stat().st_size / (1024 * 1024)
543
+ print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
544
+
545
+ avg_loss = total_loss / max(num_batches, 1)
546
+ epoch_time = (time.time() - start_time) / 60
547
+ print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")
548
+
549
+ ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
550
+ torch.save({
551
+ "step": global_step,
552
+ "epoch": epoch,
553
+ "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
554
+ "optimizer_state_dict": optimizer.state_dict(),
555
+ "config": config.__dict__,
556
+ }, ckpt_path)
557
+ size_mb = ckpt_path.stat().st_size / (1024 * 1024)
558
+ print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
559
+
560
+ if avg_loss < best_loss:
561
+ best_loss = avg_loss
562
+ best_path = checkpoint_dir / "spiderportal-v5-best.pt"
563
+ torch.save({
564
+ "step": global_step,
565
+ "epoch": epoch,
566
+ "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
567
+ "optimizer_state_dict": optimizer.state_dict(),
568
+ "config": config.__dict__,
569
+ }, best_path)
570
+ size_mb = best_path.stat().st_size / (1024 * 1024)
571
+ print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")
572
+
573
+ total_time = (time.time() - start_time) / 3600
574
+ print(f"\nTraining complete!")
575
+ print(f"Best loss: {best_loss:.4f}")
576
+ print(f"Total time: {total_time:.2f} hours")
577
+ print(f"Total steps: {global_step}")
578
+ print(f"Checkpoints saved to: {checkpoint_dir}")
579
+
580
+ if __name__ == "__main__":
581
+ train_single_gpu()