CLIWorks commited on
Commit
c0eb78b
·
verified ·
1 Parent(s): e9af599

Delete train_single_gpu.py

Browse files
Files changed (1) hide show
  1. train_single_gpu.py +0 -581
train_single_gpu.py DELETED
@@ -1,581 +0,0 @@
1
- #!/usr/bin/env python3
2
- """SpiderPortal v5 — Single-GPU Optimized Training.
3
-
4
- For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
5
- No DDP, maximal batch size, torch.compile, pre-tokenized data.
6
-
7
- Usage:
8
- python train_single_gpu.py
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- import math
15
- import os
16
- import json
17
- import gc
18
- import random
19
- import time
20
- from pathlib import Path
21
- from dataclasses import dataclass
22
- from typing import Optional, Tuple, Dict, List
23
- from torch.nn import CrossEntropyLoss
24
-
25
- @dataclass
26
- class SpiderPortalConfig:
27
- vocab_size: int = 50278
28
- hidden_size: int = 384
29
- num_hidden_layers: int = 8
30
- num_attention_heads: int = 8
31
- num_key_value_heads: int = 2
32
- intermediate_size: int = 1024
33
- hidden_act: str = "silu"
34
- num_experts: int = 64
35
- num_experts_per_tok: int = 1
36
- num_shared_experts: int = 1
37
- router_aux_loss_coef: float = 0.05
38
- max_loop_iters: int = 4
39
- act_threshold: float = 0.5
40
- max_position_embeddings: int = 131072
41
- rope_theta: float = 10000000.0
42
- rope_scaling: dict = None
43
- sliding_window: int = 4096
44
- attention_dropout: float = 0.0
45
- rms_norm_eps: float = 1e-6
46
- initializer_range: float = 0.02
47
- use_cache: bool = True
48
- tie_word_embeddings: bool = True
49
- prelude_layers: int = 2
50
- coda_layers: int = 2
51
- lora_rank: int = 32
52
- loop_embed_dim: int = 48
53
- vision_hidden_size: int = 384
54
- audio_hidden_size: int = 512
55
- vision_num_frames: int = 60
56
- vision_tokens_per_frame: int = 256
57
- vision_temporal_tokens: int = 64
58
- vision_temporal_layers: int = 2
59
- model_type: str = "spiderportal"
60
- torch_dtype: str = "bfloat16"
61
-
62
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
63
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
64
- angles = loop_t * freqs
65
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
66
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
67
- emb_full[:loop_dim] = emb
68
- return h + emb_full.unsqueeze(0).unsqueeze(0)
69
-
70
- class SpiderPortalRMSNorm(nn.Module):
71
- def __init__(self, hidden_size, eps=1e-6):
72
- super().__init__()
73
- self.weight = nn.Parameter(torch.ones(hidden_size))
74
- self.variance_epsilon = eps
75
- def forward(self, hidden_states):
76
- input_dtype = hidden_states.dtype
77
- hidden_states = hidden_states.to(torch.float32)
78
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
81
-
82
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
83
- dim = head_dim
84
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
85
- pos_freqs = torch.arange(0, dim, 2).float() / dim
86
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
87
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
88
- return orig_inv_freq * scale
89
-
90
- class SpiderPortalGQA(nn.Module):
91
- def __init__(self, config):
92
- super().__init__()
93
- self.config = config
94
- self.hidden_size = config.hidden_size
95
- self.num_heads = config.num_attention_heads
96
- self.num_kv_heads = config.num_key_value_heads
97
- self.head_dim = config.hidden_size // config.num_attention_heads
98
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
99
- self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
100
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
101
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
102
- self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
- self.attention_dropout = config.attention_dropout
104
- rope_scaling = getattr(config, 'rope_scaling', None)
105
- if rope_scaling and rope_scaling.get("type") == "yarn":
106
- factor = rope_scaling.get("factor", 1.0)
107
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
108
- inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
109
- else:
110
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
111
- self.register_buffer("inv_freq", inv_freq, persistent=False)
112
- def _rotate_half(self, x):
113
- x1 = x[..., :x.shape[-1] // 2]
114
- x2 = x[..., x.shape[-1] // 2:]
115
- return torch.cat((-x2, x1), dim=-1)
116
- def _apply_rotary(self, x, cos, sin):
117
- return (x * cos) + (self._rotate_half(x) * sin)
118
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
119
- bsz, q_len, _ = hidden_states.size()
120
- query_states = self.q_proj(hidden_states)
121
- key_states = self.k_proj(hidden_states)
122
- value_states = self.v_proj(hidden_states)
123
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
124
- key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
125
- value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
126
- if position_ids is None:
127
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
128
- max_pos = position_ids.max().item() + 1
129
- seq_len = max(max_pos, q_len)
130
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
131
- freqs = torch.outer(t, self.inv_freq)
132
- emb = torch.cat((freqs, freqs), dim=-1)
133
- cos, sin = emb.cos(), emb.sin()
134
- cos = cos[position_ids].unsqueeze(1)
135
- sin = sin[position_ids].unsqueeze(1)
136
- query_states = self._apply_rotary(query_states, cos, sin)
137
- key_states = self._apply_rotary(key_states, cos, sin)
138
- if past_key_value is not None:
139
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
140
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
141
- past_kv = (key_states, value_states) if use_cache else None
142
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
143
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
144
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
145
- if attention_mask is not None:
146
- attn_weights = attn_weights + attention_mask
147
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
148
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
149
- attn_output = torch.matmul(attn_weights, value_states)
150
- attn_output = attn_output.transpose(1, 2).contiguous()
151
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
152
- return self.o_proj(attn_output), past_kv
153
-
154
- class SpiderPortalExpert(nn.Module):
155
- def __init__(self, config, intermediate_size=None):
156
- super().__init__()
157
- inter_size = intermediate_size or config.intermediate_size
158
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
159
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
160
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
161
- self.act_fn = nn.SiLU()
162
- def forward(self, hidden_states):
163
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
164
-
165
- class SpiderPortalRouter(nn.Module):
166
- def __init__(self, config):
167
- super().__init__()
168
- self.num_experts = config.num_experts
169
- self.num_experts_per_tok = config.num_experts_per_tok
170
- self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
171
- self.register_buffer("router_bias", torch.zeros(config.num_experts))
172
- def forward(self, hidden_states):
173
- router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
174
- routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
175
- biased_logits = router_logits + self.router_bias
176
- biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
177
- top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
178
- top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
179
- top_weights = top_weights.to(hidden_states.dtype)
180
- mean_probs = routing_weights.mean(dim=0)
181
- aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
182
- return top_weights, top_indices, aux_loss
183
-
184
- class SpiderPortalMoE(nn.Module):
185
- def __init__(self, config):
186
- super().__init__()
187
- self.config = config
188
- self.num_experts = config.num_experts
189
- self.num_experts_per_tok = config.num_experts_per_tok
190
- self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
191
- self.shared_expert = SpiderPortalExpert(config)
192
- self.router = SpiderPortalRouter(config)
193
- def forward(self, hidden_states):
194
- batch_size, seq_len, hidden_dim = hidden_states.shape
195
- top_weights, top_indices, aux_loss = self.router(hidden_states)
196
- flat_hidden = hidden_states.view(-1, hidden_dim)
197
- final_output = torch.zeros_like(flat_hidden)
198
- for expert_idx in range(self.num_experts_per_tok):
199
- expert_ids = top_indices[:, expert_idx]
200
- expert_weights = top_weights[:, expert_idx:expert_idx+1]
201
- for e in range(self.num_experts):
202
- mask = expert_ids == e
203
- if mask.any():
204
- expert_output = self.experts[e](flat_hidden[mask])
205
- final_output[mask] += expert_output * expert_weights[mask]
206
- shared_output = self.shared_expert(flat_hidden)
207
- final_output = final_output + shared_output
208
- return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
209
-
210
- class SpiderPortalDenseLayer(nn.Module):
211
- def __init__(self, config):
212
- super().__init__()
213
- self.self_attn = SpiderPortalGQA(config)
214
- dense_intermediate = config.hidden_size * 4 // 3
215
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
216
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
217
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
218
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
219
- attn_input = self.input_layernorm(hidden_states)
220
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
221
- hidden_states = hidden_states + attn_output
222
- ffn_input = self.post_attention_layernorm(hidden_states)
223
- ffn_output = self.ffn(ffn_input)
224
- hidden_states = hidden_states + ffn_output
225
- return hidden_states, past_kv
226
-
227
- class SpiderPortalMoELayer(nn.Module):
228
- def __init__(self, config, layer_idx):
229
- super().__init__()
230
- self.layer_idx = layer_idx
231
- self.self_attn = SpiderPortalGQA(config)
232
- self.moe = SpiderPortalMoE(config)
233
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
234
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
235
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
236
- attn_input = self.input_layernorm(hidden_states)
237
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
238
- hidden_states = hidden_states + attn_output
239
- moe_input = self.post_attention_layernorm(hidden_states)
240
- moe_output, aux_loss = self.moe(moe_input)
241
- hidden_states = hidden_states + moe_output
242
- return hidden_states, aux_loss, past_kv
243
-
244
- class LTIInjection(nn.Module):
245
- def __init__(self, config):
246
- super().__init__()
247
- self.hidden_size = config.hidden_size
248
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
249
- self.delta_t = nn.Parameter(torch.tensor(1.0))
250
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
251
- with torch.no_grad():
252
- self.B.weight.data.normal_(mean=0.0, std=0.01)
253
- def get_A(self):
254
- return -torch.exp(self.log_A)
255
- def forward(self, h_t, e):
256
- A = self.get_A()
257
- return A * h_t + self.B(e)
258
-
259
- class ACTHalting(nn.Module):
260
- def __init__(self, config):
261
- super().__init__()
262
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
- self.threshold = config.act_threshold
264
- def forward(self, hidden_states):
265
- return torch.sigmoid(self.halt_predictor(hidden_states))
266
-
267
- class LoRAAdapter(nn.Module):
268
- def __init__(self, config):
269
- super().__init__()
270
- rank = config.lora_rank
271
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
272
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
273
- self.scale = nn.Embedding(config.max_loop_iters, rank)
274
- with torch.no_grad():
275
- self.scale.weight.data.zero_()
276
- self.down.weight.data.normal_(mean=0.0, std=0.001)
277
- def forward(self, x, loop_t):
278
- max_t = self.scale.num_embeddings - 1
279
- t_idx = min(loop_t, max_t)
280
- s = self.scale(torch.tensor(t_idx, device=x.device))
281
- down = self.down(x) * s
282
- return down @ self.B
283
-
284
- class SpiderPortalMoEModel(nn.Module):
285
- def __init__(self, config):
286
- super().__init__()
287
- self.config = config
288
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
289
- self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
290
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
291
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
292
- self.injection = LTIInjection(config)
293
- self.act_halting = ACTHalting(config)
294
- self.lora_adapter = LoRAAdapter(config)
295
- self.loop_embed_dim = config.loop_embed_dim
296
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
297
- n_loops = n_loops or self.config.max_loop_iters
298
- input_embedding = input_embedding if input_embedding is not None else hidden_states
299
- total_aux_loss = 0.0
300
- for layer in self.prelude_layers:
301
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
302
- e = hidden_states.clone()
303
- B, T_seq, D = hidden_states.shape
304
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
305
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
306
- h_out = torch.zeros_like(hidden_states)
307
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
308
- for t in range(n_loops):
309
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
310
- if t > 0:
311
- injection = self.injection(hidden_states, input_embedding)
312
- hidden_states = hidden_states + injection
313
- new_past_key_values = []
314
- for i, layer in enumerate(self.recurrent_layers):
315
- hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
316
- total_aux_loss = total_aux_loss + aux_loss
317
- new_past_key_values.append(past_kv)
318
- lora_delta = self.lora_adapter(hidden_states, t)
319
- hidden_states = hidden_states + lora_delta
320
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
321
- still_running = ~halted
322
- remainder = (1.0 - cumulative_p).clamp(min=0)
323
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
324
- weight = weight * still_running.to(hidden_states.dtype)
325
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
326
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
327
- halted = halted | (cumulative_p >= self.config.act_threshold)
328
- if halted.all() and not self.training:
329
- break
330
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
331
- hidden_states = h_out + never_halted * hidden_states
332
- for layer in self.coda_layers:
333
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
334
- hidden_states = self.norm(hidden_states)
335
- return hidden_states, total_aux_loss, new_past_key_values
336
-
337
- class SpiderPortalForConditionalGeneration(nn.Module):
338
- def __init__(self, config):
339
- super().__init__()
340
- self.config = config
341
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
342
- self.model = SpiderPortalMoEModel(config)
343
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
344
- if config.tie_word_embeddings:
345
- self.lm_head.weight = self.embed_tokens.weight
346
- self.apply(self._init_weights)
347
- def _init_weights(self, module):
348
- if isinstance(module, nn.Linear):
349
- if hasattr(self, 'model') and module is self.model.injection.B:
350
- return
351
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
352
- if module.bias is not None:
353
- module.bias.data.zero_()
354
- elif isinstance(module, nn.Embedding):
355
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
356
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
357
- hidden_states = self.embed_tokens(input_ids)
358
- model_dtype = next(self.model.parameters()).dtype
359
- hidden_states = hidden_states.to(model_dtype)
360
- input_embedding = hidden_states.clone()
361
- if attention_mask is None:
362
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
363
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
364
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
365
- causal_mask = causal_mask.triu(1)
366
- hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
367
- logits = self.lm_head(hidden_states)
368
- loss = None
369
- if labels is not None:
370
- shift_logits = logits[..., :-1, :].contiguous()
371
- shift_labels = labels[..., 1:].contiguous()
372
- loss_fct = CrossEntropyLoss()
373
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
374
- loss = loss + self.config.router_aux_loss_coef * aux_loss
375
- return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
376
- def get_num_params(self):
377
- total = sum(p.numel() for p in self.parameters())
378
- return {"total": total, "trainable": total}
379
-
380
- def train_single_gpu():
381
- device = torch.device("cuda")
382
- gpu_name = torch.cuda.get_device_name(0)
383
- gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
384
- print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
385
-
386
- config = SpiderPortalConfig(
387
- hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
388
- num_key_value_heads=2, intermediate_size=1024,
389
- num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
390
- router_aux_loss_coef=0.05, max_loop_iters=2,
391
- prelude_layers=2, coda_layers=2, lora_rank=32,
392
- rope_theta=10000000.0,
393
- rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
394
- max_position_embeddings=131072, sliding_window=4096,
395
- tie_word_embeddings=True,
396
- )
397
-
398
- print("Building model...")
399
- model = SpiderPortalForConditionalGeneration(config)
400
- model = model.to(torch.bfloat16).to(device)
401
-
402
- params = model.get_num_params()
403
- print(f"Model: {params['total']/1e6:.1f}M params")
404
- print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
405
-
406
- BASE_LR = 1e-3
407
- WARMUP_STEPS = 500
408
- optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01, betas=(0.9, 0.95))
409
-
410
- import pandas as pd
411
- data_dir = Path(__file__).parent / "data"
412
- all_records = []
413
- pkl_file = data_dir / "spiderportal_combined.pkl"
414
- if pkl_file.exists():
415
- print(f"Loading dataset from {pkl_file}...")
416
- df = pd.read_pickle(pkl_file)
417
- all_records = df.to_dict("records")
418
- else:
419
- print(f"No dataset found at {pkl_file}, creating synthetic data...")
420
- all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
421
-
422
- print(f"Loaded {len(all_records):,} samples")
423
-
424
- from transformers import AutoTokenizer
425
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
426
- tokenizer.pad_token = tokenizer.eos_token
427
-
428
- BATCH_SIZE = 128
429
- MAX_LEN = 256
430
- EPOCHS = 3
431
- N_LOOPS = 2
432
-
433
- print(f"Batch size: {BATCH_SIZE} (no grad accum)")
434
- print(f"Effective batch: {BATCH_SIZE}")
435
- print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup (high LR for recurrent MoE)")
436
- print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")
437
-
438
- def build_prompt(sample):
439
- instruction = str(sample.get("instruction", "")).strip()
440
- inp = str(sample.get("input", "")).strip()
441
- output = str(sample.get("output", "")).strip()
442
- if inp:
443
- return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
444
- return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
445
-
446
- print("Pre-tokenizing dataset...")
447
- prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
448
- mask_len = len(prefix_ids)
449
-
450
- pre_tokenized = []
451
- for i, sample in enumerate(all_records):
452
- instruction = str(sample.get("instruction", "")).strip()
453
- inp = str(sample.get("input", "")).strip()
454
- output = str(sample.get("output", "")).strip()
455
- if inp:
456
- text = f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n" + tokenizer.eos_token
457
- else:
458
- text = f"Question: Instruction: {instruction}\nAnswer: {output}\n" + tokenizer.eos_token
459
- enc = tokenizer(text, truncation=True, max_length=MAX_LEN, padding="max_length")
460
- input_ids = enc["input_ids"]
461
- labels = input_ids[:]
462
- for j in range(min(mask_len, len(labels))):
463
- labels[j] = -100
464
- pre_tokenized.append((input_ids, labels))
465
- if (i + 1) % 50000 == 0:
466
- print(f" Tokenized {i+1:,}/{len(all_records):,}")
467
-
468
- print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
469
- del all_records
470
- gc.collect()
471
-
472
- global_step = 0
473
- best_loss = float('inf')
474
- start_time = time.time()
475
- checkpoint_dir = Path("checkpoints")
476
- checkpoint_dir.mkdir(exist_ok=True)
477
- step_ckpt_files = []
478
-
479
- for epoch in range(1, EPOCHS + 1):
480
- if epoch > 1:
481
- for f in step_ckpt_files:
482
- if f.exists():
483
- f.unlink()
484
- print(f" Deleted old step checkpoint: {f.name}")
485
- step_ckpt_files.clear()
486
- gc.collect()
487
-
488
- indices = list(range(len(pre_tokenized)))
489
- random.shuffle(indices)
490
- total_loss = 0
491
- num_batches = 0
492
- optimizer.zero_grad()
493
-
494
- for batch_start in range(0, len(indices), BATCH_SIZE):
495
- batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
496
- if len(batch_indices) < BATCH_SIZE:
497
- continue
498
-
499
- if global_step < WARMUP_STEPS:
500
- lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
501
- for param_group in optimizer.param_groups:
502
- param_group['lr'] = lr
503
-
504
- batch_input_ids = []
505
- batch_labels = []
506
- for idx in batch_indices:
507
- input_ids, labels = pre_tokenized[idx]
508
- batch_input_ids.append(input_ids)
509
- batch_labels.append(labels)
510
-
511
- input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=device)
512
- labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
513
-
514
- if global_step == 0:
515
- print(" [First forward pass - compiling...]")
516
-
517
- outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
518
- loss = outputs["loss"]
519
- loss.backward()
520
-
521
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
522
- optimizer.step()
523
- optimizer.zero_grad()
524
- global_step += 1
525
-
526
- total_loss += loss.item()
527
- num_batches += 1
528
-
529
- if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
530
- avg_loss = total_loss / max(num_batches, 1)
531
- elapsed = time.time() - start_time
532
- steps_per_hour = (global_step + 1) / elapsed * 3600 if elapsed > 0 else 0
533
- current_lr = optimizer.param_groups[0]['lr']
534
- samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
535
- print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")
536
-
537
- if global_step > 0 and global_step % 500 == 0:
538
- ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
539
- state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
540
- torch.save(state_dict, ckpt_path)
541
- step_ckpt_files.append(ckpt_path)
542
- size_mb = ckpt_path.stat().st_size / (1024 * 1024)
543
- print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
544
-
545
- avg_loss = total_loss / max(num_batches, 1)
546
- epoch_time = (time.time() - start_time) / 60
547
- print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")
548
-
549
- ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
550
- torch.save({
551
- "step": global_step,
552
- "epoch": epoch,
553
- "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
554
- "optimizer_state_dict": optimizer.state_dict(),
555
- "config": config.__dict__,
556
- }, ckpt_path)
557
- size_mb = ckpt_path.stat().st_size / (1024 * 1024)
558
- print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
559
-
560
- if avg_loss < best_loss:
561
- best_loss = avg_loss
562
- best_path = checkpoint_dir / "spiderportal-v5-best.pt"
563
- torch.save({
564
- "step": global_step,
565
- "epoch": epoch,
566
- "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
567
- "optimizer_state_dict": optimizer.state_dict(),
568
- "config": config.__dict__,
569
- }, best_path)
570
- size_mb = best_path.stat().st_size / (1024 * 1024)
571
- print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")
572
-
573
- total_time = (time.time() - start_time) / 3600
574
- print(f"\nTraining complete!")
575
- print(f"Best loss: {best_loss:.4f}")
576
- print(f"Total time: {total_time:.2f} hours")
577
- print(f"Total steps: {global_step}")
578
- print(f"Checkpoints saved to: {checkpoint_dir}")
579
-
580
- if __name__ == "__main__":
581
- train_single_gpu()