CLIWorks commited on
Commit
91a6fe7
·
verified ·
1 Parent(s): 9a53bd6

Delete AI-Transfer

Browse files
AI-Transfer/README.md DELETED
@@ -1,81 +0,0 @@
1
- # SpiderPortal v5 — GPU Training Transfer Package
2
-
3
- ## Contents
4
- - `scripts/train_single_gpu.py` — Optimized single-GPU training script
5
- - `scripts/train_ddp.py` — Original DDP script (for reference)
6
- - `data/spiderportal_combined.pkl` — Training dataset (491K samples, 537MB)
7
- - `notebooks/spiderportal_gpu.ipynb` — Original Kaggle notebook (for reference)
8
-
9
- ## Quick Start
10
-
11
- ### 1. Setup
12
- ```bash
13
- pip install torch transformers pandas
14
- ```
15
-
16
- ### 2. Run Training
17
- ```bash
18
- python scripts/train_single_gpu.py
19
- ```
20
-
21
- ## Configuration (edit in `train_single_gpu.py`)
22
-
23
- | Parameter | Default | RTX 6000 (96GB) |
24
- |-----------|---------|-----------------|
25
- | `BATCH_SIZE` | 64 | 64-128 |
26
- | `MAX_LEN` | 256 | 256-512 |
27
- | `EPOCHS` | 3 | 3 |
28
- | `N_LOOPS` | 2 | 2 |
29
- | `BASE_LR` | 2e-5 | 2e-5 |
30
- | `WARMUP_STEPS` | 1000 | 1000 |
31
-
32
- ## Expected Performance (RTX PRO 6000, 96GB)
33
-
34
- | Metric | Value |
35
- |--------|-------|
36
- | Model size | 659M params |
37
- | Active params (MoE) | ~59M |
38
- | VRAM usage | ~15-25GB |
39
- | Batch size | 64 |
40
- | Steps per epoch | ~7,672 |
41
- | Time per epoch | 40-100 min |
42
- | **Total (3 epochs)** | **2-5 hours** |
43
-
44
- ## Checkpoints
45
-
46
- Saved to `checkpoints/` directory:
47
-
48
- | Checkpoint type | What's saved | Size | Purpose |
49
- |----------------|--------------|------|---------|
50
- | Every 500 steps | Model weights only | ~1.3GB | Testing, transfer, inference |
51
- | End of epoch | Model + optimizer | ~6.6GB | Resume training |
52
- | Best loss | Model + optimizer | ~6.6GB | Resume training |
53
-
54
- Step checkpoints are auto-deleted at the start of each new epoch to free disk space.
55
-
56
- ### Loading a weights-only checkpoint (inference)
57
- ```python
58
- state_dict = torch.load("checkpoints/spiderportal-v5-ep1-step500.pt", map_location="cpu")
59
- model.load_state_dict(state_dict)
60
- ```
61
-
62
- ### Resuming training from an epoch checkpoint
63
- ```python
64
- ckpt = torch.load("checkpoints/spiderportal-v5-ep1.pt", map_location="cpu")
65
- model.load_state_dict(ckpt["model_state_dict"])
66
- optimizer.load_state_dict(ckpt["optimizer_state_dict"])
67
- start_epoch = ckpt["epoch"]
68
- ```
69
-
70
- ### Peak disk usage during training: ~20GB
71
-
72
- ## Model Architecture
73
-
74
- - **Type**: MoE-RDT (Mixture of Experts + Recurrent Depth Transformer)
75
- - **Total params**: 659M
76
- - **Active params**: ~59M (1 routed expert + 1 shared expert per token)
77
- - **Experts**: 64 routed + 1 shared
78
- - **Layers**: 2 prelude → 8 recurrent (MoE) → 2 coda
79
- - **Attention**: GQA (8 heads, 2 KV heads)
80
- - **Context**: 131K tokens (YaRN RoPE scaling)
81
- - **Loop**: ACT halting + LTI injection + LoRA adapters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
AI-Transfer/data/spiderportal_combined.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:706bb0870f785241df36c3dcc1925deb8a47ce6003dfe506b6c0751624c34e63
3
- size 563067610
 
 
 
 
AI-Transfer/notebooks/spiderportal_gpu.ipynb DELETED
@@ -1,466 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# SpiderPortal MoE-RDT v5 \u2014 Multi-GPU Training (DDP)\n",
8
- "\n",
9
- "Optimized for 2\u00d7 T4 GPUs (32GB total VRAM).\n",
10
- "- **Total params**: ~659M (64 experts)\n",
11
- "- **Active params**: ~59M per token\n",
12
- "- **Training**: bf16, DDP, gradient accumulation\n",
13
- "- **Expected time**: ~1-1.5 hours for 3 epochs"
14
- ]
15
- },
16
- {
17
- "cell_type": "code",
18
- "execution_count": null,
19
- "metadata": {},
20
- "outputs": [],
21
- "source": [
22
- "!pip install -q transformers pandas safetensors\n",
23
- "\n",
24
- "import torch\n",
25
- "import torch.nn as nn\n",
26
- "import torch.nn.functional as F\n",
27
- "import math, os, json, gc, random, time\n",
28
- "from pathlib import Path\n",
29
- "from dataclasses import dataclass\n",
30
- "from typing import Optional, Tuple, Dict, List\n",
31
- "from torch.nn import CrossEntropyLoss\n",
32
- "\n",
33
- "print(f\"PyTorch: {torch.__version__}\")\n",
34
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
35
- "print(f\"GPU count: {torch.cuda.device_count()}\")\n",
36
- "for i in range(torch.cuda.device_count()):\n",
37
- " p = torch.cuda.get_device_properties(i)\n",
38
- " print(f\" GPU {i}: {torch.cuda.get_device_name(i)} ({p.total_memory / 1e9:.1f}GB)\")"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": null,
44
- "metadata": {},
45
- "outputs": [],
46
- "source": [
47
- "@dataclass\n",
48
- "class SpiderPortalConfig:\n",
49
- " vocab_size: int = 50278\n",
50
- " hidden_size: int = 384\n",
51
- " num_hidden_layers: int = 8\n",
52
- " num_attention_heads: int = 8\n",
53
- " num_key_value_heads: int = 2\n",
54
- " intermediate_size: int = 1024\n",
55
- " hidden_act: str = \"silu\"\n",
56
- " num_experts: int = 64\n",
57
- " num_experts_per_tok: int = 1\n",
58
- " num_shared_experts: int = 1\n",
59
- " router_aux_loss_coef: float = 0.05\n",
60
- " max_loop_iters: int = 4\n",
61
- " act_threshold: float = 0.5\n",
62
- " max_position_embeddings: int = 131072\n",
63
- " rope_theta: float = 10000000.0\n",
64
- " rope_scaling: dict = None\n",
65
- " sliding_window: int = 4096\n",
66
- " attention_dropout: float = 0.0\n",
67
- " rms_norm_eps: float = 1e-6\n",
68
- " initializer_range: float = 0.02\n",
69
- " use_cache: bool = True\n",
70
- " tie_word_embeddings: bool = True\n",
71
- " prelude_layers: int = 2\n",
72
- " coda_layers: int = 2\n",
73
- " lora_rank: int = 32\n",
74
- " loop_embed_dim: int = 48\n",
75
- " vision_hidden_size: int = 384\n",
76
- " audio_hidden_size: int = 512\n",
77
- " vision_num_frames: int = 60\n",
78
- " vision_tokens_per_frame: int = 256\n",
79
- " vision_temporal_tokens: int = 64\n",
80
- " vision_temporal_layers: int = 2\n",
81
- " model_type: str = \"spiderportal\"\n",
82
- " torch_dtype: str = \"bfloat16\"\n",
83
- "\n",
84
- "def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):\n",
85
- " freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))\n",
86
- " angles = loop_t * freqs\n",
87
- " emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]\n",
88
- " emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)\n",
89
- " emb_full[:loop_dim] = emb\n",
90
- " return h + emb_full.unsqueeze(0).unsqueeze(0)\n",
91
- "\n",
92
- "class SpiderPortalRMSNorm(nn.Module):\n",
93
- " def __init__(self, hidden_size, eps=1e-6):\n",
94
- " super().__init__()\n",
95
- " self.weight = nn.Parameter(torch.ones(hidden_size))\n",
96
- " self.variance_epsilon = eps\n",
97
- " def forward(self, hidden_states):\n",
98
- " input_dtype = hidden_states.dtype\n",
99
- " hidden_states = hidden_states.to(torch.float32)\n",
100
- " variance = hidden_states.pow(2).mean(-1, keepdim=True)\n",
101
- " hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n",
102
- " return self.weight.to(input_dtype) * hidden_states.to(input_dtype)\n",
103
- "\n",
104
- "def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):\n",
105
- " dim = head_dim\n",
106
- " orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))\n",
107
- " pos_freqs = torch.arange(0, dim, 2).float() / dim\n",
108
- " beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))\n",
109
- " scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))\n",
110
- " return orig_inv_freq * scale\n",
111
- "\n",
112
- "class SpiderPortalGQA(nn.Module):\n",
113
- " def __init__(self, config):\n",
114
- " super().__init__()\n",
115
- " self.config = config\n",
116
- " self.hidden_size = config.hidden_size\n",
117
- " self.num_heads = config.num_attention_heads\n",
118
- " self.num_kv_heads = config.num_key_value_heads\n",
119
- " self.head_dim = config.hidden_size // config.num_attention_heads\n",
120
- " self.num_key_value_groups = self.num_heads // self.num_kv_heads\n",
121
- " self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
122
- " self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n",
123
- " self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n",
124
- " self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
125
- " self.attention_dropout = config.attention_dropout\n",
126
- " rope_scaling = getattr(config, 'rope_scaling', None)\n",
127
- " if rope_scaling and rope_scaling.get(\"type\") == \"yarn\":\n",
128
- " factor = rope_scaling.get(\"factor\", 1.0)\n",
129
- " orig_max_pos = rope_scaling.get(\"original_max_position_embeddings\", config.max_position_embeddings)\n",
130
- " inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)\n",
131
- " else:\n",
132
- " inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))\n",
133
- " self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n",
134
- " def _rotate_half(self, x):\n",
135
- " x1 = x[..., :x.shape[-1] // 2]\n",
136
- " x2 = x[..., x.shape[-1] // 2:]\n",
137
- " return torch.cat((-x2, x1), dim=-1)\n",
138
- " def _apply_rotary(self, x, cos, sin):\n",
139
- " return (x * cos) + (self._rotate_half(x) * sin)\n",
140
- " def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
141
- " bsz, q_len, _ = hidden_states.size()\n",
142
- " query_states = self.q_proj(hidden_states)\n",
143
- " key_states = self.k_proj(hidden_states)\n",
144
- " value_states = self.v_proj(hidden_states)\n",
145
- " query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
146
- " key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n",
147
- " value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n",
148
- " if position_ids is None:\n",
149
- " position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)\n",
150
- " max_pos = position_ids.max().item() + 1\n",
151
- " seq_len = max(max_pos, q_len)\n",
152
- " t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)\n",
153
- " freqs = torch.outer(t, self.inv_freq)\n",
154
- " emb = torch.cat((freqs, freqs), dim=-1)\n",
155
- " cos, sin = emb.cos(), emb.sin()\n",
156
- " cos = cos[position_ids].unsqueeze(1)\n",
157
- " sin = sin[position_ids].unsqueeze(1)\n",
158
- " query_states = self._apply_rotary(query_states, cos, sin)\n",
159
- " key_states = self._apply_rotary(key_states, cos, sin)\n",
160
- " if past_key_value is not None:\n",
161
- " key_states = torch.cat([past_key_value[0], key_states], dim=2)\n",
162
- " value_states = torch.cat([past_key_value[1], value_states], dim=2)\n",
163
- " past_kv = (key_states, value_states) if use_cache else None\n",
164
- " key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
165
- " value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
166
- " attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
167
- " if attention_mask is not None:\n",
168
- " attn_weights = attn_weights + attention_mask\n",
169
- " attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)\n",
170
- " attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n",
171
- " attn_output = torch.matmul(attn_weights, value_states)\n",
172
- " attn_output = attn_output.transpose(1, 2).contiguous()\n",
173
- " attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
174
- " return self.o_proj(attn_output), past_kv\n",
175
- "\n",
176
- "class SpiderPortalExpert(nn.Module):\n",
177
- " def __init__(self, config, intermediate_size=None):\n",
178
- " super().__init__()\n",
179
- " inter_size = intermediate_size or config.intermediate_size\n",
180
- " self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)\n",
181
- " self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)\n",
182
- " self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)\n",
183
- " self.act_fn = nn.SiLU()\n",
184
- " def forward(self, hidden_states):\n",
185
- " return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))\n",
186
- "\n",
187
- "class SpiderPortalRouter(nn.Module):\n",
188
- " def __init__(self, config):\n",
189
- " super().__init__()\n",
190
- " self.num_experts = config.num_experts\n",
191
- " self.num_experts_per_tok = config.num_experts_per_tok\n",
192
- " self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)\n",
193
- " self.register_buffer(\"router_bias\", torch.zeros(config.num_experts))\n",
194
- " def forward(self, hidden_states):\n",
195
- " router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight\n",
196
- " routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)\n",
197
- " biased_logits = router_logits + self.router_bias\n",
198
- " biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)\n",
199
- " top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)\n",
200
- " top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)\n",
201
- " top_weights = top_weights.to(hidden_states.dtype)\n",
202
- " mean_probs = routing_weights.mean(dim=0)\n",
203
- " aux_loss = self.num_experts * (mean_probs * mean_probs).sum()\n",
204
- " return top_weights, top_indices, aux_loss\n",
205
- "\n",
206
- "class SpiderPortalMoE(nn.Module):\n",
207
- " def __init__(self, config):\n",
208
- " super().__init__()\n",
209
- " self.config = config\n",
210
- " self.num_experts = config.num_experts\n",
211
- " self.num_experts_per_tok = config.num_experts_per_tok\n",
212
- " self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])\n",
213
- " self.shared_expert = SpiderPortalExpert(config)\n",
214
- " self.router = SpiderPortalRouter(config)\n",
215
- " def forward(self, hidden_states):\n",
216
- " batch_size, seq_len, hidden_dim = hidden_states.shape\n",
217
- " top_weights, top_indices, aux_loss = self.router(hidden_states)\n",
218
- " flat_hidden = hidden_states.view(-1, hidden_dim)\n",
219
- " final_output = torch.zeros_like(flat_hidden)\n",
220
- " for expert_idx in range(self.num_experts_per_tok):\n",
221
- " expert_ids = top_indices[:, expert_idx]\n",
222
- " expert_weights = top_weights[:, expert_idx:expert_idx+1]\n",
223
- " for e in range(self.num_experts):\n",
224
- " mask = expert_ids == e\n",
225
- " if mask.any():\n",
226
- " expert_output = self.experts[e](flat_hidden[mask])\n",
227
- " final_output[mask] += expert_output * expert_weights[mask]\n",
228
- " shared_output = self.shared_expert(flat_hidden)\n",
229
- " final_output = final_output + shared_output\n",
230
- " return final_output.view(batch_size, seq_len, hidden_dim), aux_loss\n",
231
- "\n",
232
- "class SpiderPortalDenseLayer(nn.Module):\n",
233
- " def __init__(self, config):\n",
234
- " super().__init__()\n",
235
- " self.self_attn = SpiderPortalGQA(config)\n",
236
- " dense_intermediate = config.hidden_size * 4 // 3\n",
237
- " self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)\n",
238
- " self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
239
- " self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
240
- " def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
241
- " attn_input = self.input_layernorm(hidden_states)\n",
242
- " attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)\n",
243
- " hidden_states = hidden_states + attn_output\n",
244
- " ffn_input = self.post_attention_layernorm(hidden_states)\n",
245
- " ffn_output = self.ffn(ffn_input)\n",
246
- " hidden_states = hidden_states + ffn_output\n",
247
- " return hidden_states, past_kv\n",
248
- "\n",
249
- "class SpiderPortalMoELayer(nn.Module):\n",
250
- " def __init__(self, config, layer_idx):\n",
251
- " super().__init__()\n",
252
- " self.layer_idx = layer_idx\n",
253
- " self.self_attn = SpiderPortalGQA(config)\n",
254
- " self.moe = SpiderPortalMoE(config)\n",
255
- " self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
256
- " self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
257
- " def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
258
- " attn_input = self.input_layernorm(hidden_states)\n",
259
- " attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)\n",
260
- " hidden_states = hidden_states + attn_output\n",
261
- " moe_input = self.post_attention_layernorm(hidden_states)\n",
262
- " moe_output, aux_loss = self.moe(moe_input)\n",
263
- " hidden_states = hidden_states + moe_output\n",
264
- " return hidden_states, aux_loss, past_kv\n",
265
- "\n",
266
- "class LTIInjection(nn.Module):\n",
267
- " def __init__(self, config):\n",
268
- " super().__init__()\n",
269
- " self.hidden_size = config.hidden_size\n",
270
- " self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))\n",
271
- " self.delta_t = nn.Parameter(torch.tensor(1.0))\n",
272
- " self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
273
- " with torch.no_grad():\n",
274
- " self.B.weight.data.normal_(mean=0.0, std=0.01)\n",
275
- " def get_A(self):\n",
276
- " return -torch.exp(self.log_A)\n",
277
- " def forward(self, h_t, e):\n",
278
- " A = self.get_A()\n",
279
- " return A * h_t + self.B(e)\n",
280
- "\n",
281
- "class ACTHalting(nn.Module):\n",
282
- " def __init__(self, config):\n",
283
- " super().__init__()\n",
284
- " self.halt_predictor = nn.Linear(config.hidden_size, 1)\n",
285
- " self.threshold = config.act_threshold\n",
286
- " def forward(self, hidden_states):\n",
287
- " return torch.sigmoid(self.halt_predictor(hidden_states))\n",
288
- "\n",
289
- "class LoRAAdapter(nn.Module):\n",
290
- " def __init__(self, config):\n",
291
- " super().__init__()\n",
292
- " rank = config.lora_rank\n",
293
- " self.down = nn.Linear(config.hidden_size, rank, bias=False)\n",
294
- " self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)\n",
295
- " self.scale = nn.Embedding(config.max_loop_iters, rank)\n",
296
- " with torch.no_grad():\n",
297
- " self.scale.weight.data.zero_()\n",
298
- " self.down.weight.data.normal_(mean=0.0, std=0.001)\n",
299
- " def forward(self, x, loop_t):\n",
300
- " max_t = self.scale.num_embeddings - 1\n",
301
- " t_idx = min(loop_t, max_t)\n",
302
- " s = self.scale(torch.tensor(t_idx, device=x.device))\n",
303
- " down = self.down(x) * s\n",
304
- " return down @ self.B\n",
305
- "\n",
306
- "class SpiderPortalMoEModel(nn.Module):\n",
307
- " def __init__(self, config):\n",
308
- " super().__init__()\n",
309
- " self.config = config\n",
310
- " self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])\n",
311
- " self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])\n",
312
- " self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])\n",
313
- " self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
314
- " self.injection = LTIInjection(config)\n",
315
- " self.act_halting = ACTHalting(config)\n",
316
- " self.lora_adapter = LoRAAdapter(config)\n",
317
- " self.loop_embed_dim = config.loop_embed_dim\n",
318
- " def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):\n",
319
- " n_loops = n_loops or self.config.max_loop_iters\n",
320
- " input_embedding = input_embedding if input_embedding is not None else hidden_states\n",
321
- " total_aux_loss = 0.0\n",
322
- " for layer in self.prelude_layers:\n",
323
- " hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)\n",
324
- " e = hidden_states.clone()\n",
325
- " B, T_seq, D = hidden_states.shape\n",
326
- " halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)\n",
327
- " cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)\n",
328
- " h_out = torch.zeros_like(hidden_states)\n",
329
- " past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)\n",
330
- " for t in range(n_loops):\n",
331
- " h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)\n",
332
- " if t > 0:\n",
333
- " injection = self.injection(hidden_states, input_embedding)\n",
334
- " hidden_states = hidden_states + injection\n",
335
- " new_past_key_values = []\n",
336
- " for i, layer in enumerate(self.recurrent_layers):\n",
337
- " hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)\n",
338
- " total_aux_loss = total_aux_loss + aux_loss\n",
339
- " new_past_key_values.append(past_kv)\n",
340
- " lora_delta = self.lora_adapter(hidden_states, t)\n",
341
- " hidden_states = hidden_states + lora_delta\n",
342
- " halt_prob = self.act_halting(hidden_states).squeeze(-1)\n",
343
- " still_running = ~halted\n",
344
- " remainder = (1.0 - cumulative_p).clamp(min=0)\n",
345
- " weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)\n",
346
- " weight = weight * still_running.to(hidden_states.dtype)\n",
347
- " h_out = h_out + weight.unsqueeze(-1) * hidden_states\n",
348
- " cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)\n",
349
- " halted = halted | (cumulative_p >= self.config.act_threshold)\n",
350
- " if halted.all() and not self.training:\n",
351
- " break\n",
352
- " never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)\n",
353
- " hidden_states = h_out + never_halted * hidden_states\n",
354
- " for layer in self.coda_layers:\n",
355
- " hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)\n",
356
- " hidden_states = self.norm(hidden_states)\n",
357
- " return hidden_states, total_aux_loss, new_past_key_values\n",
358
- "\n",
359
- "class SpiderPortalForConditionalGeneration(nn.Module):\n",
360
- " def __init__(self, config):\n",
361
- " super().__init__()\n",
362
- " self.config = config\n",
363
- " self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n",
364
- " self.model = SpiderPortalMoEModel(config)\n",
365
- " self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
366
- " if config.tie_word_embeddings:\n",
367
- " self.lm_head.weight = self.embed_tokens.weight\n",
368
- " self.apply(self._init_weights)\n",
369
- " def _init_weights(self, module):\n",
370
- " if isinstance(module, nn.Linear):\n",
371
- " if hasattr(self, 'model') and module is self.model.injection.B:\n",
372
- " return\n",
373
- " module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
374
- " if module.bias is not None:\n",
375
- " module.bias.data.zero_()\n",
376
- " elif isinstance(module, nn.Embedding):\n",
377
- " module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
378
- " def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):\n",
379
- " hidden_states = self.embed_tokens(input_ids)\n",
380
- " model_dtype = next(self.model.parameters()).dtype\n",
381
- " hidden_states = hidden_states.to(model_dtype)\n",
382
- " input_embedding = hidden_states.clone()\n",
383
- " if attention_mask is None:\n",
384
- " attention_mask = torch.ones_like(input_ids, dtype=torch.bool)\n",
385
- " causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)\n",
386
- " causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)\n",
387
- " causal_mask = causal_mask.triu(1)\n",
388
- " hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)\n",
389
- " logits = self.lm_head(hidden_states)\n",
390
- " loss = None\n",
391
- " if labels is not None:\n",
392
- " shift_logits = logits[..., :-1, :].contiguous()\n",
393
- " shift_labels = labels[..., 1:].contiguous()\n",
394
- " loss_fct = CrossEntropyLoss()\n",
395
- " loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
396
- " loss = loss + self.config.router_aux_loss_coef * aux_loss\n",
397
- " return {\"loss\": loss, \"logits\": logits, \"aux_loss\": aux_loss, \"past_key_values\": past_kv}\n",
398
- " def get_num_params(self):\n",
399
- " total = sum(p.numel() for p in self.parameters())\n",
400
- " return {\"total\": total, \"trainable\": total}"
401
- ]
402
- },
403
- {
404
- "cell_type": "code",
405
- "execution_count": null,
406
- "metadata": {},
407
- "outputs": [],
408
- "source": [
409
- "# Multi-GPU DDP Training\n",
410
- "# DDP requires a standalone script (mp.spawn doesn't work in notebooks)\n",
411
- "# The script is at scripts/train_ddp.py\n",
412
- "!python scripts/train_ddp.py"
413
- ]
414
- },
415
- {
416
- "cell_type": "code",
417
- "execution_count": null,
418
- "metadata": {},
419
- "outputs": [],
420
- "source": [
421
- "# Test generation (after training)\n",
422
- "device = torch.device(\"cuda:0\")\n",
423
- "config = SpiderPortalConfig(\n",
424
- " prelude_layers=2, coda_layers=2, lora_rank=32,\n",
425
- " rope_theta=10000000.0, tie_word_embeddings=True,\n",
426
- ")\n",
427
- "model = SpiderPortalForConditionalGeneration(config)\n",
428
- "\n",
429
- "checkpoint = torch.load(\"/kaggle/working/spiderportal-v5-ep1/model.pt\", map_location=\"cpu\")\n",
430
- "model.load_state_dict(checkpoint)\n",
431
- "model = model.to(torch.bfloat16).to(device)\n",
432
- "model.eval()\n",
433
- "\n",
434
- "from transformers import AutoTokenizer\n",
435
- "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
436
- "tokenizer.pad_token = tokenizer.eos_token\n",
437
- "\n",
438
- "with torch.no_grad():\n",
439
- " prompt = \"Question: Instruction: What is the capital of France?\\nAnswer:\"\n",
440
- " input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(device)\n",
441
- " generated = input_ids.clone()\n",
442
- " for _ in range(32):\n",
443
- " outputs = model(generated, n_loops=1)\n",
444
- " next_token = torch.argmax(outputs[\"logits\"][:, -1, :], dim=-1, keepdim=True)\n",
445
- " generated = torch.cat([generated, next_token], dim=1)\n",
446
- " if next_token.item() == tokenizer.eos_token_id:\n",
447
- " break\n",
448
- " print(f\"Generated: {tokenizer.decode(generated[0])}\")"
449
- ]
450
- }
451
- ],
452
- "metadata": {
453
- "kernelspec": {
454
- "display_name": "Python 3",
455
- "language": "python",
456
- "name": "python3"
457
- },
458
- "language_info": {
459
- "name": "python",
460
- "version": "3.10.0"
461
- },
462
- "accelerator": "GPU"
463
- },
464
- "nbformat": 4,
465
- "nbformat_minor": 4
466
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
AI-Transfer/scripts/train_ddp.py DELETED
@@ -1,601 +0,0 @@
1
- #!/usr/bin/env python3
2
- """SpiderPortal v5 — Multi-GPU DDP Training.
3
-
4
- Run from Kaggle notebook cell:
5
- !python train_ddp.py
6
-
7
- Or directly:
8
- python -m torch.distributed.run --nproc_per_node=2 train_ddp.py
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- import torch.distributed as dist
15
- import math
16
- import os
17
- import json
18
- import gc
19
- import random
20
- import time
21
- import subprocess
22
- from pathlib import Path
23
- from dataclasses import dataclass
24
- from typing import Optional, Tuple, Dict, List
25
- from torch.nn import CrossEntropyLoss
26
- from torch.nn.parallel import DistributedDataParallel as DDP
27
-
28
- @dataclass
29
- class SpiderPortalConfig:
30
- vocab_size: int = 50278
31
- hidden_size: int = 384
32
- num_hidden_layers: int = 8
33
- num_attention_heads: int = 8
34
- num_key_value_heads: int = 2
35
- intermediate_size: int = 1024
36
- hidden_act: str = "silu"
37
- num_experts: int = 64
38
- num_experts_per_tok: int = 1
39
- num_shared_experts: int = 1
40
- router_aux_loss_coef: float = 0.05
41
- max_loop_iters: int = 4
42
- act_threshold: float = 0.5
43
- max_position_embeddings: int = 131072
44
- rope_theta: float = 10000000.0
45
- rope_scaling: dict = None
46
- sliding_window: int = 4096
47
- attention_dropout: float = 0.0
48
- rms_norm_eps: float = 1e-6
49
- initializer_range: float = 0.02
50
- use_cache: bool = True
51
- tie_word_embeddings: bool = True
52
- prelude_layers: int = 2
53
- coda_layers: int = 2
54
- lora_rank: int = 32
55
- loop_embed_dim: int = 48
56
- vision_hidden_size: int = 384
57
- audio_hidden_size: int = 512
58
- vision_num_frames: int = 60
59
- vision_tokens_per_frame: int = 256
60
- vision_temporal_tokens: int = 64
61
- vision_temporal_layers: int = 2
62
- model_type: str = "spiderportal"
63
- torch_dtype: str = "bfloat16"
64
-
65
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
66
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
67
- angles = loop_t * freqs
68
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
69
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
70
- emb_full[:loop_dim] = emb
71
- return h + emb_full.unsqueeze(0).unsqueeze(0)
72
-
73
- class SpiderPortalRMSNorm(nn.Module):
74
- def __init__(self, hidden_size, eps=1e-6):
75
- super().__init__()
76
- self.weight = nn.Parameter(torch.ones(hidden_size))
77
- self.variance_epsilon = eps
78
- def forward(self, hidden_states):
79
- input_dtype = hidden_states.dtype
80
- hidden_states = hidden_states.to(torch.float32)
81
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
82
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
83
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
84
-
85
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
86
- dim = head_dim
87
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
88
- pos_freqs = torch.arange(0, dim, 2).float() / dim
89
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
90
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
91
- return orig_inv_freq * scale
92
-
93
- class SpiderPortalGQA(nn.Module):
94
- def __init__(self, config):
95
- super().__init__()
96
- self.config = config
97
- self.hidden_size = config.hidden_size
98
- self.num_heads = config.num_attention_heads
99
- self.num_kv_heads = config.num_key_value_heads
100
- self.head_dim = config.hidden_size // config.num_attention_heads
101
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
102
- self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
104
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
105
- self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
106
- self.attention_dropout = config.attention_dropout
107
- rope_scaling = getattr(config, 'rope_scaling', None)
108
- if rope_scaling and rope_scaling.get("type") == "yarn":
109
- factor = rope_scaling.get("factor", 1.0)
110
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
111
- inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
112
- else:
113
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
114
- self.register_buffer("inv_freq", inv_freq, persistent=False)
115
- def _rotate_half(self, x):
116
- x1 = x[..., :x.shape[-1] // 2]
117
- x2 = x[..., x.shape[-1] // 2:]
118
- return torch.cat((-x2, x1), dim=-1)
119
- def _apply_rotary(self, x, cos, sin):
120
- return (x * cos) + (self._rotate_half(x) * sin)
121
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
122
- bsz, q_len, _ = hidden_states.size()
123
- query_states = self.q_proj(hidden_states)
124
- key_states = self.k_proj(hidden_states)
125
- value_states = self.v_proj(hidden_states)
126
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
127
- key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
128
- value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
129
- if position_ids is None:
130
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
131
- max_pos = position_ids.max().item() + 1
132
- seq_len = max(max_pos, q_len)
133
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
134
- freqs = torch.outer(t, self.inv_freq)
135
- emb = torch.cat((freqs, freqs), dim=-1)
136
- cos, sin = emb.cos(), emb.sin()
137
- cos = cos[position_ids].unsqueeze(1)
138
- sin = sin[position_ids].unsqueeze(1)
139
- query_states = self._apply_rotary(query_states, cos, sin)
140
- key_states = self._apply_rotary(key_states, cos, sin)
141
- if past_key_value is not None:
142
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
143
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
144
- past_kv = (key_states, value_states) if use_cache else None
145
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
146
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
147
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
148
- if attention_mask is not None:
149
- attn_weights = attn_weights + attention_mask
150
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
151
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
152
- attn_output = torch.matmul(attn_weights, value_states)
153
- attn_output = attn_output.transpose(1, 2).contiguous()
154
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
155
- return self.o_proj(attn_output), past_kv
156
-
157
- class SpiderPortalExpert(nn.Module):
158
- def __init__(self, config, intermediate_size=None):
159
- super().__init__()
160
- inter_size = intermediate_size or config.intermediate_size
161
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
162
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
163
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
164
- self.act_fn = nn.SiLU()
165
- def forward(self, hidden_states):
166
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
167
-
168
- class SpiderPortalRouter(nn.Module):
169
- def __init__(self, config):
170
- super().__init__()
171
- self.num_experts = config.num_experts
172
- self.num_experts_per_tok = config.num_experts_per_tok
173
- self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
174
- self.register_buffer("router_bias", torch.zeros(config.num_experts))
175
- def forward(self, hidden_states):
176
- router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
177
- routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
178
- biased_logits = router_logits + self.router_bias
179
- biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
180
- top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
181
- top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
182
- top_weights = top_weights.to(hidden_states.dtype)
183
- mean_probs = routing_weights.mean(dim=0)
184
- aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
185
- return top_weights, top_indices, aux_loss
186
-
187
- class SpiderPortalMoE(nn.Module):
188
- def __init__(self, config):
189
- super().__init__()
190
- self.config = config
191
- self.num_experts = config.num_experts
192
- self.num_experts_per_tok = config.num_experts_per_tok
193
- self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
194
- self.shared_expert = SpiderPortalExpert(config)
195
- self.router = SpiderPortalRouter(config)
196
- def forward(self, hidden_states):
197
- batch_size, seq_len, hidden_dim = hidden_states.shape
198
- top_weights, top_indices, aux_loss = self.router(hidden_states)
199
- flat_hidden = hidden_states.view(-1, hidden_dim)
200
- final_output = torch.zeros_like(flat_hidden)
201
- for expert_idx in range(self.num_experts_per_tok):
202
- expert_ids = top_indices[:, expert_idx]
203
- expert_weights = top_weights[:, expert_idx:expert_idx+1]
204
- for e in range(self.num_experts):
205
- mask = expert_ids == e
206
- if mask.any():
207
- expert_output = self.experts[e](flat_hidden[mask])
208
- final_output[mask] += expert_output * expert_weights[mask]
209
- shared_output = self.shared_expert(flat_hidden)
210
- final_output = final_output + shared_output
211
- return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
212
-
213
- class SpiderPortalDenseLayer(nn.Module):
214
- def __init__(self, config):
215
- super().__init__()
216
- self.self_attn = SpiderPortalGQA(config)
217
- dense_intermediate = config.hidden_size * 4 // 3
218
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
219
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
220
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
221
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
222
- attn_input = self.input_layernorm(hidden_states)
223
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
224
- hidden_states = hidden_states + attn_output
225
- ffn_input = self.post_attention_layernorm(hidden_states)
226
- ffn_output = self.ffn(ffn_input)
227
- hidden_states = hidden_states + ffn_output
228
- return hidden_states, past_kv
229
-
230
- class SpiderPortalMoELayer(nn.Module):
231
- def __init__(self, config, layer_idx):
232
- super().__init__()
233
- self.layer_idx = layer_idx
234
- self.self_attn = SpiderPortalGQA(config)
235
- self.moe = SpiderPortalMoE(config)
236
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
237
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
238
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
239
- attn_input = self.input_layernorm(hidden_states)
240
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
241
- hidden_states = hidden_states + attn_output
242
- moe_input = self.post_attention_layernorm(hidden_states)
243
- moe_output, aux_loss = self.moe(moe_input)
244
- hidden_states = hidden_states + moe_output
245
- return hidden_states, aux_loss, past_kv
246
-
247
- class LTIInjection(nn.Module):
248
- def __init__(self, config):
249
- super().__init__()
250
- self.hidden_size = config.hidden_size
251
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
252
- self.delta_t = nn.Parameter(torch.tensor(1.0))
253
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
254
- with torch.no_grad():
255
- self.B.weight.data.normal_(mean=0.0, std=0.01)
256
- def get_A(self):
257
- return -torch.exp(self.log_A)
258
- def forward(self, h_t, e):
259
- A = self.get_A()
260
- return A * h_t + self.B(e)
261
-
262
- class ACTHalting(nn.Module):
263
- def __init__(self, config):
264
- super().__init__()
265
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
266
- self.threshold = config.act_threshold
267
- def forward(self, hidden_states):
268
- return torch.sigmoid(self.halt_predictor(hidden_states))
269
-
270
- class LoRAAdapter(nn.Module):
271
- def __init__(self, config):
272
- super().__init__()
273
- rank = config.lora_rank
274
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
275
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
276
- self.scale = nn.Embedding(config.max_loop_iters, rank)
277
- with torch.no_grad():
278
- self.scale.weight.data.zero_()
279
- self.down.weight.data.normal_(mean=0.0, std=0.001)
280
- def forward(self, x, loop_t):
281
- max_t = self.scale.num_embeddings - 1
282
- t_idx = min(loop_t, max_t)
283
- s = self.scale(torch.tensor(t_idx, device=x.device))
284
- down = self.down(x) * s
285
- return down @ self.B
286
-
287
- class SpiderPortalMoEModel(nn.Module):
288
- def __init__(self, config):
289
- super().__init__()
290
- self.config = config
291
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
292
- self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
293
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
294
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
295
- self.injection = LTIInjection(config)
296
- self.act_halting = ACTHalting(config)
297
- self.lora_adapter = LoRAAdapter(config)
298
- self.loop_embed_dim = config.loop_embed_dim
299
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
300
- n_loops = n_loops or self.config.max_loop_iters
301
- input_embedding = input_embedding if input_embedding is not None else hidden_states
302
- total_aux_loss = 0.0
303
- for layer in self.prelude_layers:
304
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
305
- e = hidden_states.clone()
306
- B, T_seq, D = hidden_states.shape
307
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
308
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
309
- h_out = torch.zeros_like(hidden_states)
310
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
311
- for t in range(n_loops):
312
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
313
- if t > 0:
314
- injection = self.injection(hidden_states, input_embedding)
315
- hidden_states = hidden_states + injection
316
- new_past_key_values = []
317
- for i, layer in enumerate(self.recurrent_layers):
318
- hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
319
- total_aux_loss = total_aux_loss + aux_loss
320
- new_past_key_values.append(past_kv)
321
- lora_delta = self.lora_adapter(hidden_states, t)
322
- hidden_states = hidden_states + lora_delta
323
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
324
- still_running = ~halted
325
- remainder = (1.0 - cumulative_p).clamp(min=0)
326
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
327
- weight = weight * still_running.to(hidden_states.dtype)
328
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
329
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
330
- halted = halted | (cumulative_p >= self.config.act_threshold)
331
- if halted.all() and not self.training:
332
- break
333
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
334
- hidden_states = h_out + never_halted * hidden_states
335
- for layer in self.coda_layers:
336
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
337
- hidden_states = self.norm(hidden_states)
338
- return hidden_states, total_aux_loss, new_past_key_values
339
-
340
- class SpiderPortalForConditionalGeneration(nn.Module):
341
- def __init__(self, config):
342
- super().__init__()
343
- self.config = config
344
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
345
- self.model = SpiderPortalMoEModel(config)
346
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
347
- if config.tie_word_embeddings:
348
- self.lm_head.weight = self.embed_tokens.weight
349
- self.apply(self._init_weights)
350
- def _init_weights(self, module):
351
- if isinstance(module, nn.Linear):
352
- if hasattr(self, 'model') and module is self.model.injection.B:
353
- return
354
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
355
- if module.bias is not None:
356
- module.bias.data.zero_()
357
- elif isinstance(module, nn.Embedding):
358
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
359
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
360
- hidden_states = self.embed_tokens(input_ids)
361
- model_dtype = next(self.model.parameters()).dtype
362
- hidden_states = hidden_states.to(model_dtype)
363
- input_embedding = hidden_states.clone()
364
- if attention_mask is None:
365
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
366
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
367
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
368
- causal_mask = causal_mask.triu(1)
369
- hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
370
- logits = self.lm_head(hidden_states)
371
- loss = None
372
- if labels is not None:
373
- shift_logits = logits[..., :-1, :].contiguous()
374
- shift_labels = labels[..., 1:].contiguous()
375
- loss_fct = CrossEntropyLoss()
376
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
377
- loss = loss + self.config.router_aux_loss_coef * aux_loss
378
- return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
379
- def get_num_params(self):
380
- total = sum(p.numel() for p in self.parameters())
381
- return {"total": total, "trainable": total}
382
-
383
- def train_ddp(local_rank, world_size):
384
- dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
385
- torch.cuda.set_device(local_rank)
386
- device = torch.device(f"cuda:{local_rank}")
387
- is_master = local_rank == 0
388
-
389
- if is_master:
390
- print(f"Training on {world_size} GPUs (DDP)")
391
- for i in range(world_size):
392
- p = torch.cuda.get_device_properties(i)
393
- print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({p.total_memory / 1e9:.1f}GB)")
394
-
395
- config = SpiderPortalConfig(
396
- hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
397
- num_key_value_heads=2, intermediate_size=1024,
398
- num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
399
- router_aux_loss_coef=0.05, max_loop_iters=2,
400
- prelude_layers=2, coda_layers=2, lora_rank=32,
401
- rope_theta=10000000.0,
402
- rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
403
- max_position_embeddings=131072, sliding_window=4096,
404
- tie_word_embeddings=True,
405
- )
406
-
407
- model = SpiderPortalForConditionalGeneration(config)
408
- model = model.to(torch.bfloat16).to(device)
409
- model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
410
-
411
- if is_master:
412
- params = model.module.get_num_params()
413
- print(f"Model: {params['total']/1e6:.1f}M params")
414
- print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
415
-
416
- BASE_LR = 2e-5
417
- WARMUP_STEPS = 1000
418
- optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
419
-
420
- import pandas as pd
421
- data_dir = Path("/kaggle/input/datasets/cliworks/spiderp-custom-1")
422
- all_records = []
423
- if data_dir.exists():
424
- pkl_file = data_dir / "spiderportal_combined.pkl"
425
- if pkl_file.exists():
426
- df = pd.read_pickle(pkl_file)
427
- all_records = df.to_dict("records")
428
- else:
429
- for f in data_dir.glob("*.parquet"):
430
- try:
431
- df_pq = pd.read_parquet(f)
432
- if all(c in df_pq.columns for c in ["instruction", "input", "output"]):
433
- all_records.extend(df_pq[["instruction", "input", "output"]].to_dict("records"))
434
- except:
435
- pass
436
-
437
- if not all_records:
438
- if is_master:
439
- print("No data found, creating synthetic data...")
440
- all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
441
-
442
- if is_master:
443
- print(f"Loaded {len(all_records)} samples")
444
-
445
- from transformers import AutoTokenizer
446
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
447
- tokenizer.pad_token = tokenizer.eos_token
448
-
449
- BATCH_SIZE = 8
450
- GRAD_ACCUM = 2
451
- MAX_LEN = 256
452
- EPOCHS = 3
453
- N_LOOPS = 1
454
-
455
- if is_master:
456
- print(f"Batch size: {BATCH_SIZE}, Grad accum: {GRAD_ACCUM}")
457
- print(f"Effective batch: {BATCH_SIZE * GRAD_ACCUM * world_size}")
458
- print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup")
459
-
460
- def build_prompt(sample):
461
- instruction = str(sample.get("instruction", "")).strip()
462
- inp = str(sample.get("input", "")).strip()
463
- output = str(sample.get("output", "")).strip()
464
- if inp:
465
- return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
466
- return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
467
-
468
- def format_sample(sample, tokenizer, max_len):
469
- text = build_prompt(sample) + tokenizer.eos_token
470
- enc = tokenizer(text, truncation=True, max_length=max_len, padding="max_length", return_tensors="pt")
471
- input_ids = enc["input_ids"].squeeze(0)
472
- labels = input_ids.clone()
473
- prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
474
- mask_len = min(len(prefix_ids), labels.shape[0])
475
- labels[:mask_len] = -100
476
- return {"input_ids": input_ids, "labels": labels}
477
-
478
- shard_size = len(all_records) // world_size
479
- start_idx = local_rank * shard_size
480
- end_idx = start_idx + shard_size if local_rank < world_size - 1 else len(all_records)
481
- local_samples = all_records[start_idx:end_idx]
482
-
483
- if is_master:
484
- print(f"Data sharding: {len(all_records)} total -> {len(local_samples)} per GPU")
485
-
486
- global_step = 0
487
- best_loss = float('inf')
488
- start_time = time.time()
489
-
490
- for epoch in range(1, EPOCHS + 1):
491
- random.shuffle(local_samples)
492
- total_loss = 0
493
- num_batches = 0
494
- optimizer.zero_grad()
495
-
496
- for i in range(0, len(local_samples), BATCH_SIZE):
497
- batch_samples = local_samples[i:i+BATCH_SIZE]
498
- if len(batch_samples) < BATCH_SIZE:
499
- continue
500
-
501
- if global_step < WARMUP_STEPS:
502
- lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
503
- for param_group in optimizer.param_groups:
504
- param_group['lr'] = lr
505
-
506
- batch = [format_sample(s, tokenizer, MAX_LEN) for s in batch_samples]
507
- input_ids = torch.stack([b["input_ids"] for b in batch]).to(device)
508
- labels = torch.stack([b["labels"] for b in batch]).to(device)
509
-
510
- if global_step == 0 and is_master:
511
- print(" [First forward pass - CUDA graph building...]")
512
-
513
- outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
514
- loss = outputs["loss"] / GRAD_ACCUM
515
- loss.backward()
516
-
517
- if (i // BATCH_SIZE + 1) % GRAD_ACCUM == 0:
518
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
519
- optimizer.step()
520
- optimizer.zero_grad()
521
- global_step += 1
522
-
523
- total_loss += loss.item() * GRAD_ACCUM
524
- num_batches += 1
525
-
526
- if is_master and ((i // BATCH_SIZE) == 0 or (global_step < 20) or (global_step % 50 == 0)):
527
- avg_loss = total_loss / max(num_batches, 1)
528
- elapsed = time.time() - start_time
529
- steps_per_hour = (global_step + 1) / elapsed * 3600 if elapsed > 0 else 0
530
- current_lr = optimizer.param_groups[0]['lr']
531
- print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | avg_loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.1f} steps/hr")
532
-
533
- if is_master and global_step > 0 and global_step % 500 == 0:
534
- output_dir = Path(f"/kaggle/working/spiderportal-v5-ep{epoch}-step{global_step}")
535
- output_dir.mkdir(parents=True, exist_ok=True)
536
- state_dict = {k: v.cpu() for k, v in model.module.state_dict().items()}
537
- torch.save(state_dict, output_dir / "model.pt")
538
- print(f"Saved checkpoint at step {global_step}")
539
-
540
- if is_master:
541
- avg_loss = total_loss / max(num_batches, 1)
542
- print(f"Epoch {epoch}/{EPOCHS} | avg_loss={avg_loss:.4f} | Time: {(time.time()-start_time)/60:.1f}min")
543
- output_dir = Path(f"/kaggle/working/spiderportal-v5-ep{epoch}")
544
- output_dir.mkdir(parents=True, exist_ok=True)
545
- state_dict = {k: v.cpu() for k, v in model.module.state_dict().items()}
546
- torch.save(state_dict, output_dir / "model.pt")
547
- if avg_loss < best_loss:
548
- best_loss = avg_loss
549
-
550
- if is_master:
551
- print(f"Training complete! Best loss: {best_loss:.4f}")
552
- print(f"Total time: {(time.time() - start_time)/60:.1f} minutes")
553
-
554
- dist.destroy_process_group()
555
-
556
- if __name__ == "__main__":
557
- import sys
558
- print(f"Script started, CUDA devices: {torch.cuda.device_count()}")
559
- sys.stdout.flush()
560
-
561
- num_gpus = torch.cuda.device_count()
562
-
563
- if num_gpus <= 1:
564
- dist.init_process_group = lambda *a, **k: None
565
- dist.destroy_process_group = lambda: None
566
- train_ddp(0, 1)
567
- else:
568
- print(f"Launching DDP with {num_gpus} GPUs...")
569
- sys.stdout.flush()
570
- import os
571
- import signal
572
-
573
- # Set NCCL environment before forking
574
- os.environ["MASTER_ADDR"] = "127.0.0.1"
575
- os.environ["MASTER_PORT"] = "29500"
576
- os.environ["WORLD_SIZE"] = str(num_gpus)
577
- os.environ["LOCAL_WORLD_SIZE"] = str(num_gpus)
578
-
579
- procs = []
580
- for rank in range(num_gpus):
581
- print(f"Forking rank {rank}...")
582
- sys.stdout.flush()
583
- pid = os.fork()
584
- if pid == 0:
585
- # Child process
586
- os.environ["RANK"] = str(rank)
587
- os.environ["LOCAL_RANK"] = str(rank)
588
- print(f"Child {rank} starting training...")
589
- sys.stdout.flush()
590
- train_ddp(rank, num_gpus)
591
- os._exit(0)
592
- else:
593
- procs.append(pid)
594
- print(f"Parent: forked child {rank} with PID {pid}")
595
- sys.stdout.flush()
596
-
597
- # Parent process: wait for children
598
- print("Parent waiting for children...")
599
- sys.stdout.flush()
600
- for pid in procs:
601
- os.waitpid(pid, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
AI-Transfer/scripts/train_single_gpu.py DELETED
@@ -1,587 +0,0 @@
1
- #!/usr/bin/env python3
2
- """SpiderPortal v5 — Single-GPU Optimized Training.
3
-
4
- For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
5
- No DDP, maximal batch size, torch.compile, pre-tokenized data.
6
-
7
- Usage:
8
- python train_single_gpu.py
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- import math
15
- import os
16
- import json
17
- import gc
18
- import random
19
- import time
20
- from pathlib import Path
21
- from dataclasses import dataclass
22
- from typing import Optional, Tuple, Dict, List
23
- from torch.nn import CrossEntropyLoss
24
-
25
- @dataclass
26
- class SpiderPortalConfig:
27
- vocab_size: int = 50278
28
- hidden_size: int = 384
29
- num_hidden_layers: int = 8
30
- num_attention_heads: int = 8
31
- num_key_value_heads: int = 2
32
- intermediate_size: int = 1024
33
- hidden_act: str = "silu"
34
- num_experts: int = 64
35
- num_experts_per_tok: int = 1
36
- num_shared_experts: int = 1
37
- router_aux_loss_coef: float = 0.05
38
- max_loop_iters: int = 4
39
- act_threshold: float = 0.5
40
- max_position_embeddings: int = 131072
41
- rope_theta: float = 10000000.0
42
- rope_scaling: dict = None
43
- sliding_window: int = 4096
44
- attention_dropout: float = 0.0
45
- rms_norm_eps: float = 1e-6
46
- initializer_range: float = 0.02
47
- use_cache: bool = True
48
- tie_word_embeddings: bool = True
49
- prelude_layers: int = 2
50
- coda_layers: int = 2
51
- lora_rank: int = 32
52
- loop_embed_dim: int = 48
53
- vision_hidden_size: int = 384
54
- audio_hidden_size: int = 512
55
- vision_num_frames: int = 60
56
- vision_tokens_per_frame: int = 256
57
- vision_temporal_tokens: int = 64
58
- vision_temporal_layers: int = 2
59
- model_type: str = "spiderportal"
60
- torch_dtype: str = "bfloat16"
61
-
62
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
63
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
64
- angles = loop_t * freqs
65
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
66
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
67
- emb_full[:loop_dim] = emb
68
- return h + emb_full.unsqueeze(0).unsqueeze(0)
69
-
70
- class SpiderPortalRMSNorm(nn.Module):
71
- def __init__(self, hidden_size, eps=1e-6):
72
- super().__init__()
73
- self.weight = nn.Parameter(torch.ones(hidden_size))
74
- self.variance_epsilon = eps
75
- def forward(self, hidden_states):
76
- input_dtype = hidden_states.dtype
77
- hidden_states = hidden_states.to(torch.float32)
78
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
81
-
82
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
83
- dim = head_dim
84
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
85
- pos_freqs = torch.arange(0, dim, 2).float() / dim
86
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
87
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
88
- return orig_inv_freq * scale
89
-
90
- class SpiderPortalGQA(nn.Module):
91
- def __init__(self, config):
92
- super().__init__()
93
- self.config = config
94
- self.hidden_size = config.hidden_size
95
- self.num_heads = config.num_attention_heads
96
- self.num_kv_heads = config.num_key_value_heads
97
- self.head_dim = config.hidden_size // config.num_attention_heads
98
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
99
- self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
100
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
101
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
102
- self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
- self.attention_dropout = config.attention_dropout
104
- rope_scaling = getattr(config, 'rope_scaling', None)
105
- if rope_scaling and rope_scaling.get("type") == "yarn":
106
- factor = rope_scaling.get("factor", 1.0)
107
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
108
- inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
109
- else:
110
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
111
- self.register_buffer("inv_freq", inv_freq, persistent=False)
112
- def _rotate_half(self, x):
113
- x1 = x[..., :x.shape[-1] // 2]
114
- x2 = x[..., x.shape[-1] // 2:]
115
- return torch.cat((-x2, x1), dim=-1)
116
- def _apply_rotary(self, x, cos, sin):
117
- return (x * cos) + (self._rotate_half(x) * sin)
118
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
119
- bsz, q_len, _ = hidden_states.size()
120
- query_states = self.q_proj(hidden_states)
121
- key_states = self.k_proj(hidden_states)
122
- value_states = self.v_proj(hidden_states)
123
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
124
- key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
125
- value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
126
- if position_ids is None:
127
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
128
- max_pos = position_ids.max().item() + 1
129
- seq_len = max(max_pos, q_len)
130
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
131
- freqs = torch.outer(t, self.inv_freq)
132
- emb = torch.cat((freqs, freqs), dim=-1)
133
- cos, sin = emb.cos(), emb.sin()
134
- cos = cos[position_ids].unsqueeze(1)
135
- sin = sin[position_ids].unsqueeze(1)
136
- query_states = self._apply_rotary(query_states, cos, sin)
137
- key_states = self._apply_rotary(key_states, cos, sin)
138
- if past_key_value is not None:
139
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
140
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
141
- past_kv = (key_states, value_states) if use_cache else None
142
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
143
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
144
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
145
- if attention_mask is not None:
146
- attn_weights = attn_weights + attention_mask
147
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
148
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
149
- attn_output = torch.matmul(attn_weights, value_states)
150
- attn_output = attn_output.transpose(1, 2).contiguous()
151
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
152
- return self.o_proj(attn_output), past_kv
153
-
154
- class SpiderPortalExpert(nn.Module):
155
- def __init__(self, config, intermediate_size=None):
156
- super().__init__()
157
- inter_size = intermediate_size or config.intermediate_size
158
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
159
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
160
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
161
- self.act_fn = nn.SiLU()
162
- def forward(self, hidden_states):
163
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
164
-
165
- class SpiderPortalRouter(nn.Module):
166
- def __init__(self, config):
167
- super().__init__()
168
- self.num_experts = config.num_experts
169
- self.num_experts_per_tok = config.num_experts_per_tok
170
- self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
171
- self.register_buffer("router_bias", torch.zeros(config.num_experts))
172
- def forward(self, hidden_states):
173
- router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
174
- routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
175
- biased_logits = router_logits + self.router_bias
176
- biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
177
- top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
178
- top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
179
- top_weights = top_weights.to(hidden_states.dtype)
180
- mean_probs = routing_weights.mean(dim=0)
181
- aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
182
- return top_weights, top_indices, aux_loss
183
-
184
- class SpiderPortalMoE(nn.Module):
185
- def __init__(self, config):
186
- super().__init__()
187
- self.config = config
188
- self.num_experts = config.num_experts
189
- self.num_experts_per_tok = config.num_experts_per_tok
190
- self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
191
- self.shared_expert = SpiderPortalExpert(config)
192
- self.router = SpiderPortalRouter(config)
193
- def forward(self, hidden_states):
194
- batch_size, seq_len, hidden_dim = hidden_states.shape
195
- top_weights, top_indices, aux_loss = self.router(hidden_states)
196
- flat_hidden = hidden_states.view(-1, hidden_dim)
197
- final_output = torch.zeros_like(flat_hidden)
198
- for expert_idx in range(self.num_experts_per_tok):
199
- expert_ids = top_indices[:, expert_idx]
200
- expert_weights = top_weights[:, expert_idx:expert_idx+1]
201
- for e in range(self.num_experts):
202
- mask = expert_ids == e
203
- if mask.any():
204
- expert_output = self.experts[e](flat_hidden[mask])
205
- final_output[mask] += expert_output * expert_weights[mask]
206
- shared_output = self.shared_expert(flat_hidden)
207
- final_output = final_output + shared_output
208
- return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
209
-
210
- class SpiderPortalDenseLayer(nn.Module):
211
- def __init__(self, config):
212
- super().__init__()
213
- self.self_attn = SpiderPortalGQA(config)
214
- dense_intermediate = config.hidden_size * 4 // 3
215
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
216
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
217
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
218
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
219
- attn_input = self.input_layernorm(hidden_states)
220
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
221
- hidden_states = hidden_states + attn_output
222
- ffn_input = self.post_attention_layernorm(hidden_states)
223
- ffn_output = self.ffn(ffn_input)
224
- hidden_states = hidden_states + ffn_output
225
- return hidden_states, past_kv
226
-
227
- class SpiderPortalMoELayer(nn.Module):
228
- def __init__(self, config, layer_idx):
229
- super().__init__()
230
- self.layer_idx = layer_idx
231
- self.self_attn = SpiderPortalGQA(config)
232
- self.moe = SpiderPortalMoE(config)
233
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
234
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
235
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
236
- attn_input = self.input_layernorm(hidden_states)
237
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
238
- hidden_states = hidden_states + attn_output
239
- moe_input = self.post_attention_layernorm(hidden_states)
240
- moe_output, aux_loss = self.moe(moe_input)
241
- hidden_states = hidden_states + moe_output
242
- return hidden_states, aux_loss, past_kv
243
-
244
- class LTIInjection(nn.Module):
245
- def __init__(self, config):
246
- super().__init__()
247
- self.hidden_size = config.hidden_size
248
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
249
- self.delta_t = nn.Parameter(torch.tensor(1.0))
250
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
251
- with torch.no_grad():
252
- self.B.weight.data.normal_(mean=0.0, std=0.01)
253
- def get_A(self):
254
- return -torch.exp(self.log_A)
255
- def forward(self, h_t, e):
256
- A = self.get_A()
257
- return A * h_t + self.B(e)
258
-
259
- class ACTHalting(nn.Module):
260
- def __init__(self, config):
261
- super().__init__()
262
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
- self.threshold = config.act_threshold
264
- def forward(self, hidden_states):
265
- return torch.sigmoid(self.halt_predictor(hidden_states))
266
-
267
- class LoRAAdapter(nn.Module):
268
- def __init__(self, config):
269
- super().__init__()
270
- rank = config.lora_rank
271
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
272
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
273
- self.scale = nn.Embedding(config.max_loop_iters, rank)
274
- with torch.no_grad():
275
- self.scale.weight.data.zero_()
276
- self.down.weight.data.normal_(mean=0.0, std=0.001)
277
- def forward(self, x, loop_t):
278
- max_t = self.scale.num_embeddings - 1
279
- t_idx = min(loop_t, max_t)
280
- s = self.scale(torch.tensor(t_idx, device=x.device))
281
- down = self.down(x) * s
282
- return down @ self.B
283
-
284
- class SpiderPortalMoEModel(nn.Module):
285
- def __init__(self, config):
286
- super().__init__()
287
- self.config = config
288
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
289
- self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
290
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
291
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
292
- self.injection = LTIInjection(config)
293
- self.act_halting = ACTHalting(config)
294
- self.lora_adapter = LoRAAdapter(config)
295
- self.loop_embed_dim = config.loop_embed_dim
296
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
297
- n_loops = n_loops or self.config.max_loop_iters
298
- input_embedding = input_embedding if input_embedding is not None else hidden_states
299
- total_aux_loss = 0.0
300
- for layer in self.prelude_layers:
301
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
302
- e = hidden_states.clone()
303
- B, T_seq, D = hidden_states.shape
304
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
305
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
306
- h_out = torch.zeros_like(hidden_states)
307
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
308
- for t in range(n_loops):
309
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
310
- if t > 0:
311
- injection = self.injection(hidden_states, input_embedding)
312
- hidden_states = hidden_states + injection
313
- new_past_key_values = []
314
- for i, layer in enumerate(self.recurrent_layers):
315
- hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
316
- total_aux_loss = total_aux_loss + aux_loss
317
- new_past_key_values.append(past_kv)
318
- lora_delta = self.lora_adapter(hidden_states, t)
319
- hidden_states = hidden_states + lora_delta
320
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
321
- still_running = ~halted
322
- remainder = (1.0 - cumulative_p).clamp(min=0)
323
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
324
- weight = weight * still_running.to(hidden_states.dtype)
325
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
326
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
327
- halted = halted | (cumulative_p >= self.config.act_threshold)
328
- if halted.all() and not self.training:
329
- break
330
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
331
- hidden_states = h_out + never_halted * hidden_states
332
- for layer in self.coda_layers:
333
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
334
- hidden_states = self.norm(hidden_states)
335
- return hidden_states, total_aux_loss, new_past_key_values
336
-
337
- class SpiderPortalForConditionalGeneration(nn.Module):
338
- def __init__(self, config):
339
- super().__init__()
340
- self.config = config
341
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
342
- self.model = SpiderPortalMoEModel(config)
343
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
344
- if config.tie_word_embeddings:
345
- self.lm_head.weight = self.embed_tokens.weight
346
- self.apply(self._init_weights)
347
- def _init_weights(self, module):
348
- if isinstance(module, nn.Linear):
349
- if hasattr(self, 'model') and module is self.model.injection.B:
350
- return
351
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
352
- if module.bias is not None:
353
- module.bias.data.zero_()
354
- elif isinstance(module, nn.Embedding):
355
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
356
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
357
- hidden_states = self.embed_tokens(input_ids)
358
- model_dtype = next(self.model.parameters()).dtype
359
- hidden_states = hidden_states.to(model_dtype)
360
- input_embedding = hidden_states.clone()
361
- if attention_mask is None:
362
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
363
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
364
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
365
- causal_mask = causal_mask.triu(1)
366
- hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
367
- logits = self.lm_head(hidden_states)
368
- loss = None
369
- if labels is not None:
370
- shift_logits = logits[..., :-1, :].contiguous()
371
- shift_labels = labels[..., 1:].contiguous()
372
- loss_fct = CrossEntropyLoss()
373
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
374
- loss = loss + self.config.router_aux_loss_coef * aux_loss
375
- return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
376
- def get_num_params(self):
377
- total = sum(p.numel() for p in self.parameters())
378
- return {"total": total, "trainable": total}
379
-
380
- def train_single_gpu():
381
- device = torch.device("cuda")
382
- gpu_name = torch.cuda.get_device_name(0)
383
- gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
384
- print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
385
-
386
- config = SpiderPortalConfig(
387
- hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
388
- num_key_value_heads=2, intermediate_size=1024,
389
- num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
390
- router_aux_loss_coef=0.05, max_loop_iters=2,
391
- prelude_layers=2, coda_layers=2, lora_rank=32,
392
- rope_theta=10000000.0,
393
- rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
394
- max_position_embeddings=131072, sliding_window=4096,
395
- tie_word_embeddings=True,
396
- )
397
-
398
- print("Building model...")
399
- model = SpiderPortalForConditionalGeneration(config)
400
- model = model.to(torch.bfloat16).to(device)
401
-
402
- params = model.get_num_params()
403
- print(f"Model: {params['total']/1e6:.1f}M params")
404
- print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
405
-
406
- try:
407
- model = torch.compile(model, mode="reduce-overhead")
408
- print("torch.compile: enabled")
409
- except Exception:
410
- print("torch.compile: not available, using eager mode")
411
-
412
- BASE_LR = 2e-5
413
- WARMUP_STEPS = 1000
414
- optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
415
-
416
- import pandas as pd
417
- data_dir = Path(__file__).parent / "data"
418
- all_records = []
419
- pkl_file = data_dir / "spiderportal_combined.pkl"
420
- if pkl_file.exists():
421
- print(f"Loading dataset from {pkl_file}...")
422
- df = pd.read_pickle(pkl_file)
423
- all_records = df.to_dict("records")
424
- else:
425
- print(f"No dataset found at {pkl_file}, creating synthetic data...")
426
- all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
427
-
428
- print(f"Loaded {len(all_records):,} samples")
429
-
430
- from transformers import AutoTokenizer
431
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
432
- tokenizer.pad_token = tokenizer.eos_token
433
-
434
- BATCH_SIZE = 128
435
- MAX_LEN = 256
436
- EPOCHS = 3
437
- N_LOOPS = 2
438
-
439
- print(f"Batch size: {BATCH_SIZE} (no grad accum)")
440
- print(f"Effective batch: {BATCH_SIZE}")
441
- print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup")
442
- print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")
443
-
444
- def build_prompt(sample):
445
- instruction = str(sample.get("instruction", "")).strip()
446
- inp = str(sample.get("input", "")).strip()
447
- output = str(sample.get("output", "")).strip()
448
- if inp:
449
- return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
450
- return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
451
-
452
- print("Pre-tokenizing dataset...")
453
- prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
454
- mask_len = len(prefix_ids)
455
-
456
- pre_tokenized = []
457
- for i, sample in enumerate(all_records):
458
- instruction = str(sample.get("instruction", "")).strip()
459
- inp = str(sample.get("input", "")).strip()
460
- output = str(sample.get("output", "")).strip()
461
- if inp:
462
- text = f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n" + tokenizer.eos_token
463
- else:
464
- text = f"Question: Instruction: {instruction}\nAnswer: {output}\n" + tokenizer.eos_token
465
- enc = tokenizer(text, truncation=True, max_length=MAX_LEN, padding="max_length")
466
- input_ids = enc["input_ids"]
467
- labels = input_ids[:]
468
- for j in range(min(mask_len, len(labels))):
469
- labels[j] = -100
470
- pre_tokenized.append((input_ids, labels))
471
- if (i + 1) % 50000 == 0:
472
- print(f" Tokenized {i+1:,}/{len(all_records):,}")
473
-
474
- print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
475
- del all_records
476
- gc.collect()
477
-
478
- global_step = 0
479
- best_loss = float('inf')
480
- start_time = time.time()
481
- checkpoint_dir = Path("checkpoints")
482
- checkpoint_dir.mkdir(exist_ok=True)
483
- step_ckpt_files = []
484
-
485
- for epoch in range(1, EPOCHS + 1):
486
- if epoch > 1:
487
- for f in step_ckpt_files:
488
- if f.exists():
489
- f.unlink()
490
- print(f" Deleted old step checkpoint: {f.name}")
491
- step_ckpt_files.clear()
492
- gc.collect()
493
-
494
- indices = list(range(len(pre_tokenized)))
495
- random.shuffle(indices)
496
- total_loss = 0
497
- num_batches = 0
498
- optimizer.zero_grad()
499
-
500
- for batch_start in range(0, len(indices), BATCH_SIZE):
501
- batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
502
- if len(batch_indices) < BATCH_SIZE:
503
- continue
504
-
505
- if global_step < WARMUP_STEPS:
506
- lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
507
- for param_group in optimizer.param_groups:
508
- param_group['lr'] = lr
509
-
510
- batch_input_ids = []
511
- batch_labels = []
512
- for idx in batch_indices:
513
- input_ids, labels = pre_tokenized[idx]
514
- batch_input_ids.append(input_ids)
515
- batch_labels.append(labels)
516
-
517
- input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=device)
518
- labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
519
-
520
- if global_step == 0:
521
- print(" [First forward pass - compiling...]")
522
-
523
- outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
524
- loss = outputs["loss"]
525
- loss.backward()
526
-
527
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
528
- optimizer.step()
529
- optimizer.zero_grad()
530
- global_step += 1
531
-
532
- total_loss += loss.item()
533
- num_batches += 1
534
-
535
- if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
536
- avg_loss = total_loss / max(num_batches, 1)
537
- elapsed = time.time() - start_time
538
- steps_per_hour = (global_step + 1) / elapsed * 3600 if elapsed > 0 else 0
539
- current_lr = optimizer.param_groups[0]['lr']
540
- samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
541
- print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")
542
-
543
- if global_step > 0 and global_step % 500 == 0:
544
- ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
545
- state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
546
- torch.save(state_dict, ckpt_path)
547
- step_ckpt_files.append(ckpt_path)
548
- size_mb = ckpt_path.stat().st_size / (1024 * 1024)
549
- print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
550
-
551
- avg_loss = total_loss / max(num_batches, 1)
552
- epoch_time = (time.time() - start_time) / 60
553
- print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")
554
-
555
- ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
556
- torch.save({
557
- "step": global_step,
558
- "epoch": epoch,
559
- "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
560
- "optimizer_state_dict": optimizer.state_dict(),
561
- "config": config.__dict__,
562
- }, ckpt_path)
563
- size_mb = ckpt_path.stat().st_size / (1024 * 1024)
564
- print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
565
-
566
- if avg_loss < best_loss:
567
- best_loss = avg_loss
568
- best_path = checkpoint_dir / "spiderportal-v5-best.pt"
569
- torch.save({
570
- "step": global_step,
571
- "epoch": epoch,
572
- "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
573
- "optimizer_state_dict": optimizer.state_dict(),
574
- "config": config.__dict__,
575
- }, best_path)
576
- size_mb = best_path.stat().st_size / (1024 * 1024)
577
- print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")
578
-
579
- total_time = (time.time() - start_time) / 3600
580
- print(f"\nTraining complete!")
581
- print(f"Best loss: {best_loss:.4f}")
582
- print(f"Total time: {total_time:.2f} hours")
583
- print(f"Total steps: {global_step}")
584
- print(f"Checkpoints saved to: {checkpoint_dir}")
585
-
586
- if __name__ == "__main__":
587
- train_single_gpu()