CLIWorks committed on
Commit
9a53bd6
·
verified ·
1 Parent(s): f025f95

Upload 12 files

.gitattributes ADDED
@@ -0,0 +1,2 @@
+ AI-Transfer/data/spiderportal_combined.pkl filter=lfs diff=lfs merge=lfs -text
+ SpiderPortal-Base-v5/model.safetensors filter=lfs diff=lfs merge=lfs -text
AI-Transfer/README.md ADDED
@@ -0,0 +1,81 @@
+ # SpiderPortal v5: GPU Training Transfer Package
+
+ ## Contents
+ - `scripts/train_single_gpu.py`: Optimized single-GPU training script
+ - `scripts/train_ddp.py`: Original DDP script (for reference)
+ - `data/spiderportal_combined.pkl`: Training dataset (491K samples, 537 MiB)
+ - `notebooks/spiderportal_gpu.ipynb`: Original Kaggle notebook (for reference)
+
+ ## Quick Start
+
+ ### 1. Setup
+ ```bash
+ pip install torch transformers pandas
+ ```
+
+ ### 2. Run Training
+ ```bash
+ python scripts/train_single_gpu.py
+ ```
+
+ ## Configuration (edit in `train_single_gpu.py`)
+
+ | Parameter | Default | RTX PRO 6000 (96GB) |
+ |-----------|---------|---------------------|
+ | `BATCH_SIZE` | 64 | 64-128 |
+ | `MAX_LEN` | 256 | 256-512 |
+ | `EPOCHS` | 3 | 3 |
+ | `N_LOOPS` | 2 | 2 |
+ | `BASE_LR` | 2e-5 | 2e-5 |
+ | `WARMUP_STEPS` | 1000 | 1000 |
+
+ ## Expected Performance (RTX PRO 6000, 96GB)
+
+ | Metric | Value |
+ |--------|-------|
+ | Model size | 659M params |
+ | Active params (MoE) | ~59M |
+ | VRAM usage | ~15-25GB |
+ | Batch size | 64 |
+ | Steps per epoch | ~7,672 |
+ | Time per epoch | 40-100 min |
+ | **Total (3 epochs)** | **2-5 hours** |
+
+ ## Checkpoints
+
+ Saved to the `checkpoints/` directory:
+
+ | Checkpoint type | What's saved | Size | Purpose |
+ |----------------|--------------|------|---------|
+ | Every 500 steps | Model weights only | ~1.3GB | Testing, transfer, inference |
+ | End of epoch | Model + optimizer | ~6.6GB | Resume training |
+ | Best loss | Model + optimizer | ~6.6GB | Resume training |
+
+ Step checkpoints are auto-deleted at the start of each new epoch to free disk space.
+
+ ### Loading a weights-only checkpoint (inference)
+ ```python
+ state_dict = torch.load("checkpoints/spiderportal-v5-ep1-step500.pt", map_location="cpu")
+ model.load_state_dict(state_dict)
+ ```
+
+ ### Resuming training from an epoch checkpoint
+ ```python
+ ckpt = torch.load("checkpoints/spiderportal-v5-ep1.pt", map_location="cpu")
+ model.load_state_dict(ckpt["model_state_dict"])
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+ start_epoch = ckpt["epoch"]
+ ```
+
+ ### Peak disk usage during training: ~20GB
+
+ ## Model Architecture
+
+ - **Type**: MoE-RDT (Mixture of Experts + Recurrent Depth Transformer)
+ - **Total params**: 659M
+ - **Active params**: ~59M (1 routed expert + 1 shared expert per token)
+ - **Experts**: 64 routed + 1 shared
+ - **Layers**: 2 prelude → 8 recurrent (MoE) → 2 coda
+ - **Attention**: GQA (8 heads, 2 KV heads)
+ - **Context**: 131K tokens (YaRN RoPE scaling)
+ - **Loop**: ACT halting + LTI injection + LoRA adapters
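The step count and checkpoint sizes quoted in the README can be roughly reproduced from the other figures it gives. The sketch below is a back-of-the-envelope check, not part of the commit. It assumes bf16 weights (2 bytes per parameter) and an AdamW-style optimizer holding two fp32 moment tensors (8 bytes per parameter); the README states neither, so treat both constants as assumptions.

```python
import math

# Figures quoted in the README.
num_samples = 491_000        # "491K samples" (rounded in the README)
batch_size = 64              # default BATCH_SIZE
total_params = 659_000_000   # "659M params"

# Assumptions, not stated in the README: bf16 weights, fp32 AdamW moments.
BYTES_BF16 = 2
BYTES_ADAM_STATE = 8         # two fp32 moment tensors per parameter

steps_per_epoch = math.ceil(num_samples / batch_size)
weights_gb = total_params * BYTES_BF16 / 1e9
full_ckpt_gb = total_params * (BYTES_BF16 + BYTES_ADAM_STATE) / 1e9

print(f"steps per epoch: {steps_per_epoch:,}")
print(f"weights-only checkpoint: {weights_gb:.1f} GB")
print(f"model + optimizer checkpoint: {full_ckpt_gb:.1f} GB")
```

Under these assumptions the results line up with the README's ~7,672 steps, ~1.3GB and ~6.6GB. Doubling `batch_size` to 128 halves the step count.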
AI-Transfer/data/spiderportal_combined.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:706bb0870f785241df36c3dcc1925deb8a47ce6003dfe506b6c0751624c34e63
+ size 563067610
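The three lines above are a Git LFS pointer: the repository stores only the hash and byte size, and the real dataset is fetched separately. As a hedged sketch, a downloaded file could be checked against such a pointer as follows. The helper names `parse_lfs_pointer` and `matches_pointer` are hypothetical and not part of any library.

```python
import hashlib
import os

def parse_lfs_pointer(text):
    """Parse the 'key value' lines of a Git LFS pointer file (hypothetical helper)."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }

def matches_pointer(path, pointer, chunk_size=1 << 20):
    """Return True if the file at `path` has the size and hash the pointer records."""
    if os.path.getsize(path) != pointer["size"]:
        return False
    hasher = hashlib.new(pointer["algo"])
    with open(path, "rb") as f:
        # Read in chunks so a large dataset is never loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == pointer["digest"]

POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:706bb0870f785241df36c3dcc1925deb8a47ce6003dfe506b6c0751624c34e63
size 563067610
"""

pointer = parse_lfs_pointer(POINTER_TEXT)
print(pointer["algo"], pointer["size"])
```

A call such as `matches_pointer("data/spiderportal_combined.pkl", pointer)` would confirm the download before it is unpickled. The path is taken from the README and assumes the package has been unpacked in the current directory.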
AI-Transfer/notebooks/spiderportal_gpu.ipynb ADDED
@@ -0,0 +1,466 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# SpiderPortal MoE-RDT v5 \u2014 Multi-GPU Training (DDP)\n",
+ "\n",
+ "Optimized for 2\u00d7 T4 GPUs (32GB total VRAM).\n",
+ "- **Total params**: ~659M (64 experts)\n",
+ "- **Active params**: ~59M per token\n",
+ "- **Training**: bf16, DDP, gradient accumulation\n",
+ "- **Expected time**: ~1-1.5 hours for 3 epochs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q transformers pandas safetensors\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import math, os, json, gc, random, time\n",
+ "from pathlib import Path\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Optional, Tuple, Dict, List\n",
+ "from torch.nn import CrossEntropyLoss\n",
+ "\n",
+ "print(f\"PyTorch: {torch.__version__}\")\n",
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
+ "print(f\"GPU count: {torch.cuda.device_count()}\")\n",
+ "for i in range(torch.cuda.device_count()):\n",
+ "    p = torch.cuda.get_device_properties(i)\n",
+ "    print(f\" GPU {i}: {torch.cuda.get_device_name(i)} ({p.total_memory / 1e9:.1f}GB)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class SpiderPortalConfig:\n",
+ "    vocab_size: int = 50278\n",
+ "    hidden_size: int = 384\n",
+ "    num_hidden_layers: int = 8\n",
+ "    num_attention_heads: int = 8\n",
+ "    num_key_value_heads: int = 2\n",
+ "    intermediate_size: int = 1024\n",
+ "    hidden_act: str = \"silu\"\n",
+ "    num_experts: int = 64\n",
+ "    num_experts_per_tok: int = 1\n",
+ "    num_shared_experts: int = 1\n",
+ "    router_aux_loss_coef: float = 0.05\n",
+ "    max_loop_iters: int = 4\n",
+ "    act_threshold: float = 0.5\n",
+ "    max_position_embeddings: int = 131072\n",
+ "    rope_theta: float = 10000000.0\n",
+ "    rope_scaling: dict = None\n",
+ "    sliding_window: int = 4096\n",
+ "    attention_dropout: float = 0.0\n",
+ "    rms_norm_eps: float = 1e-6\n",
+ "    initializer_range: float = 0.02\n",
+ "    use_cache: bool = True\n",
+ "    tie_word_embeddings: bool = True\n",
+ "    prelude_layers: int = 2\n",
+ "    coda_layers: int = 2\n",
+ "    lora_rank: int = 32\n",
+ "    loop_embed_dim: int = 48\n",
+ "    vision_hidden_size: int = 384\n",
+ "    audio_hidden_size: int = 512\n",
+ "    vision_num_frames: int = 60\n",
+ "    vision_tokens_per_frame: int = 256\n",
+ "    vision_temporal_tokens: int = 64\n",
+ "    vision_temporal_layers: int = 2\n",
+ "    model_type: str = \"spiderportal\"\n",
+ "    torch_dtype: str = \"bfloat16\"\n",
+ "\n",
+ "def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):\n",
+ "    freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))\n",
+ "    angles = loop_t * freqs\n",
+ "    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]\n",
+ "    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)\n",
+ "    emb_full[:loop_dim] = emb\n",
+ "    return h + emb_full.unsqueeze(0).unsqueeze(0)\n",
+ "\n",
+ "class SpiderPortalRMSNorm(nn.Module):\n",
+ "    def __init__(self, hidden_size, eps=1e-6):\n",
+ "        super().__init__()\n",
+ "        self.weight = nn.Parameter(torch.ones(hidden_size))\n",
+ "        self.variance_epsilon = eps\n",
+ "    def forward(self, hidden_states):\n",
+ "        input_dtype = hidden_states.dtype\n",
+ "        hidden_states = hidden_states.to(torch.float32)\n",
+ "        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n",
+ "        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n",
+ "        return self.weight.to(input_dtype) * hidden_states.to(input_dtype)\n",
+ "\n",
+ "def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):\n",
+ "    dim = head_dim\n",
+ "    orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))\n",
+ "    pos_freqs = torch.arange(0, dim, 2).float() / dim\n",
+ "    beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))\n",
+ "    scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))\n",
+ "    return orig_inv_freq * scale\n",
+ "\n",
+ "class SpiderPortalGQA(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.config = config\n",
+ "        self.hidden_size = config.hidden_size\n",
+ "        self.num_heads = config.num_attention_heads\n",
+ "        self.num_kv_heads = config.num_key_value_heads\n",
+ "        self.head_dim = config.hidden_size // config.num_attention_heads\n",
+ "        self.num_key_value_groups = self.num_heads // self.num_kv_heads\n",
+ "        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
+ "        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n",
+ "        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n",
+ "        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
+ "        self.attention_dropout = config.attention_dropout\n",
+ "        rope_scaling = getattr(config, 'rope_scaling', None)\n",
+ "        if rope_scaling and rope_scaling.get(\"type\") == \"yarn\":\n",
+ "            factor = rope_scaling.get(\"factor\", 1.0)\n",
+ "            orig_max_pos = rope_scaling.get(\"original_max_position_embeddings\", config.max_position_embeddings)\n",
+ "            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)\n",
+ "        else:\n",
+ "            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))\n",
+ "        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n",
+ "    def _rotate_half(self, x):\n",
+ "        x1 = x[..., :x.shape[-1] // 2]\n",
+ "        x2 = x[..., x.shape[-1] // 2:]\n",
+ "        return torch.cat((-x2, x1), dim=-1)\n",
+ "    def _apply_rotary(self, x, cos, sin):\n",
+ "        return (x * cos) + (self._rotate_half(x) * sin)\n",
+ "    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
+ "        bsz, q_len, _ = hidden_states.size()\n",
+ "        query_states = self.q_proj(hidden_states)\n",
+ "        key_states = self.k_proj(hidden_states)\n",
+ "        value_states = self.v_proj(hidden_states)\n",
+ "        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
+ "        key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n",
+ "        value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n",
+ "        if position_ids is None:\n",
+ "            position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)\n",
+ "        max_pos = position_ids.max().item() + 1\n",
+ "        seq_len = max(max_pos, q_len)\n",
+ "        t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)\n",
+ "        freqs = torch.outer(t, self.inv_freq)\n",
+ "        emb = torch.cat((freqs, freqs), dim=-1)\n",
+ "        cos, sin = emb.cos(), emb.sin()\n",
+ "        cos = cos[position_ids].unsqueeze(1)\n",
+ "        sin = sin[position_ids].unsqueeze(1)\n",
+ "        query_states = self._apply_rotary(query_states, cos, sin)\n",
+ "        key_states = self._apply_rotary(key_states, cos, sin)\n",
+ "        if past_key_value is not None:\n",
+ "            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n",
+ "            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n",
+ "        past_kv = (key_states, value_states) if use_cache else None\n",
+ "        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
+ "        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
+ "        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
+ "        if attention_mask is not None:\n",
+ "            attn_weights = attn_weights + attention_mask\n",
+ "        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)\n",
+ "        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n",
+ "        attn_output = torch.matmul(attn_weights, value_states)\n",
+ "        attn_output = attn_output.transpose(1, 2).contiguous()\n",
+ "        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
+ "        return self.o_proj(attn_output), past_kv\n",
+ "\n",
+ "class SpiderPortalExpert(nn.Module):\n",
+ "    def __init__(self, config, intermediate_size=None):\n",
+ "        super().__init__()\n",
+ "        inter_size = intermediate_size or config.intermediate_size\n",
+ "        self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)\n",
+ "        self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)\n",
+ "        self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)\n",
+ "        self.act_fn = nn.SiLU()\n",
+ "    def forward(self, hidden_states):\n",
+ "        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))\n",
+ "\n",
+ "class SpiderPortalRouter(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.num_experts = config.num_experts\n",
+ "        self.num_experts_per_tok = config.num_experts_per_tok\n",
+ "        self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)\n",
+ "        self.register_buffer(\"router_bias\", torch.zeros(config.num_experts))\n",
+ "    def forward(self, hidden_states):\n",
+ "        router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight\n",
+ "        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)\n",
+ "        biased_logits = router_logits + self.router_bias\n",
+ "        biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)\n",
+ "        top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)\n",
+ "        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)\n",
+ "        top_weights = top_weights.to(hidden_states.dtype)\n",
+ "        mean_probs = routing_weights.mean(dim=0)\n",
+ "        aux_loss = self.num_experts * (mean_probs * mean_probs).sum()\n",
+ "        return top_weights, top_indices, aux_loss\n",
+ "\n",
+ "class SpiderPortalMoE(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.config = config\n",
+ "        self.num_experts = config.num_experts\n",
+ "        self.num_experts_per_tok = config.num_experts_per_tok\n",
+ "        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])\n",
+ "        self.shared_expert = SpiderPortalExpert(config)\n",
+ "        self.router = SpiderPortalRouter(config)\n",
+ "    def forward(self, hidden_states):\n",
+ "        batch_size, seq_len, hidden_dim = hidden_states.shape\n",
+ "        top_weights, top_indices, aux_loss = self.router(hidden_states)\n",
+ "        flat_hidden = hidden_states.view(-1, hidden_dim)\n",
+ "        final_output = torch.zeros_like(flat_hidden)\n",
+ "        for expert_idx in range(self.num_experts_per_tok):\n",
+ "            expert_ids = top_indices[:, expert_idx]\n",
+ "            expert_weights = top_weights[:, expert_idx:expert_idx+1]\n",
+ "            for e in range(self.num_experts):\n",
+ "                mask = expert_ids == e\n",
+ "                if mask.any():\n",
+ "                    expert_output = self.experts[e](flat_hidden[mask])\n",
+ "                    final_output[mask] += expert_output * expert_weights[mask]\n",
+ "        shared_output = self.shared_expert(flat_hidden)\n",
+ "        final_output = final_output + shared_output\n",
+ "        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss\n",
+ "\n",
+ "class SpiderPortalDenseLayer(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.self_attn = SpiderPortalGQA(config)\n",
+ "        dense_intermediate = config.hidden_size * 4 // 3\n",
+ "        self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)\n",
+ "        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
+ "        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
+ "    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
+ "        attn_input = self.input_layernorm(hidden_states)\n",
+ "        attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)\n",
+ "        hidden_states = hidden_states + attn_output\n",
+ "        ffn_input = self.post_attention_layernorm(hidden_states)\n",
+ "        ffn_output = self.ffn(ffn_input)\n",
+ "        hidden_states = hidden_states + ffn_output\n",
+ "        return hidden_states, past_kv\n",
+ "\n",
+ "class SpiderPortalMoELayer(nn.Module):\n",
+ "    def __init__(self, config, layer_idx):\n",
+ "        super().__init__()\n",
+ "        self.layer_idx = layer_idx\n",
+ "        self.self_attn = SpiderPortalGQA(config)\n",
+ "        self.moe = SpiderPortalMoE(config)\n",
+ "        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
+ "        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
+ "    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):\n",
+ "        attn_input = self.input_layernorm(hidden_states)\n",
+ "        attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)\n",
+ "        hidden_states = hidden_states + attn_output\n",
+ "        moe_input = self.post_attention_layernorm(hidden_states)\n",
+ "        moe_output, aux_loss = self.moe(moe_input)\n",
+ "        hidden_states = hidden_states + moe_output\n",
+ "        return hidden_states, aux_loss, past_kv\n",
+ "\n",
+ "class LTIInjection(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.hidden_size = config.hidden_size\n",
+ "        self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))\n",
+ "        self.delta_t = nn.Parameter(torch.tensor(1.0))\n",
+ "        self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
+ "        with torch.no_grad():\n",
+ "            self.B.weight.data.normal_(mean=0.0, std=0.01)\n",
+ "    def get_A(self):\n",
+ "        return -torch.exp(self.log_A)\n",
+ "    def forward(self, h_t, e):\n",
+ "        A = self.get_A()\n",
+ "        return A * h_t + self.B(e)\n",
+ "\n",
+ "class ACTHalting(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.halt_predictor = nn.Linear(config.hidden_size, 1)\n",
+ "        self.threshold = config.act_threshold\n",
+ "    def forward(self, hidden_states):\n",
+ "        return torch.sigmoid(self.halt_predictor(hidden_states))\n",
+ "\n",
+ "class LoRAAdapter(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        rank = config.lora_rank\n",
+ "        self.down = nn.Linear(config.hidden_size, rank, bias=False)\n",
+ "        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)\n",
+ "        self.scale = nn.Embedding(config.max_loop_iters, rank)\n",
+ "        with torch.no_grad():\n",
+ "            self.scale.weight.data.zero_()\n",
+ "            self.down.weight.data.normal_(mean=0.0, std=0.001)\n",
+ "    def forward(self, x, loop_t):\n",
+ "        max_t = self.scale.num_embeddings - 1\n",
+ "        t_idx = min(loop_t, max_t)\n",
+ "        s = self.scale(torch.tensor(t_idx, device=x.device))\n",
+ "        down = self.down(x) * s\n",
+ "        return down @ self.B\n",
+ "\n",
+ "class SpiderPortalMoEModel(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.config = config\n",
+ "        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])\n",
+ "        self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])\n",
+ "        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])\n",
+ "        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
+ "        self.injection = LTIInjection(config)\n",
+ "        self.act_halting = ACTHalting(config)\n",
+ "        self.lora_adapter = LoRAAdapter(config)\n",
+ "        self.loop_embed_dim = config.loop_embed_dim\n",
+ "    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):\n",
+ "        n_loops = n_loops or self.config.max_loop_iters\n",
+ "        input_embedding = input_embedding if input_embedding is not None else hidden_states\n",
+ "        total_aux_loss = 0.0\n",
+ "        for layer in self.prelude_layers:\n",
+ "            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)\n",
+ "        e = hidden_states.clone()\n",
+ "        B, T_seq, D = hidden_states.shape\n",
+ "        halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)\n",
+ "        cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)\n",
+ "        h_out = torch.zeros_like(hidden_states)\n",
+ "        past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)\n",
+ "        for t in range(n_loops):\n",
+ "            h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)\n",
+ "            if t > 0:\n",
+ "                injection = self.injection(hidden_states, input_embedding)\n",
+ "                hidden_states = hidden_states + injection\n",
+ "            new_past_key_values = []\n",
+ "            for i, layer in enumerate(self.recurrent_layers):\n",
+ "                hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)\n",
+ "                total_aux_loss = total_aux_loss + aux_loss\n",
+ "                new_past_key_values.append(past_kv)\n",
+ "            lora_delta = self.lora_adapter(hidden_states, t)\n",
+ "            hidden_states = hidden_states + lora_delta\n",
+ "            halt_prob = self.act_halting(hidden_states).squeeze(-1)\n",
+ "            still_running = ~halted\n",
+ "            remainder = (1.0 - cumulative_p).clamp(min=0)\n",
+ "            weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)\n",
+ "            weight = weight * still_running.to(hidden_states.dtype)\n",
+ "            h_out = h_out + weight.unsqueeze(-1) * hidden_states\n",
+ "            cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)\n",
+ "            halted = halted | (cumulative_p >= self.config.act_threshold)\n",
+ "            if halted.all() and not self.training:\n",
+ "                break\n",
+ "        never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)\n",
+ "        hidden_states = h_out + never_halted * hidden_states\n",
+ "        for layer in self.coda_layers:\n",
+ "            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)\n",
+ "        hidden_states = self.norm(hidden_states)\n",
+ "        return hidden_states, total_aux_loss, new_past_key_values\n",
+ "\n",
+ "class SpiderPortalForConditionalGeneration(nn.Module):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__()\n",
+ "        self.config = config\n",
+ "        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n",
+ "        self.model = SpiderPortalMoEModel(config)\n",
+ "        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
+ "        if config.tie_word_embeddings:\n",
+ "            self.lm_head.weight = self.embed_tokens.weight\n",
+ "        self.apply(self._init_weights)\n",
+ "    def _init_weights(self, module):\n",
+ "        if isinstance(module, nn.Linear):\n",
+ "            if hasattr(self, 'model') and module is self.model.injection.B:\n",
+ "                return\n",
+ "            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
+ "            if module.bias is not None:\n",
+ "                module.bias.data.zero_()\n",
+ "        elif isinstance(module, nn.Embedding):\n",
+ "            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
+ "    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):\n",
+ "        hidden_states = self.embed_tokens(input_ids)\n",
+ "        model_dtype = next(self.model.parameters()).dtype\n",
+ "        hidden_states = hidden_states.to(model_dtype)\n",
+ "        input_embedding = hidden_states.clone()\n",
+ "        if attention_mask is None:\n",
+ "            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)\n",
+ "        causal_mask = torch.full((attention_mask.size(1), attention_mask.size(1)), torch.finfo(hidden_states.dtype).min, dtype=hidden_states.dtype, device=hidden_states.device).triu(1)\n",
+ "        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(attention_mask.size(0), 1, -1, -1)\n",
+ "        causal_mask = causal_mask.masked_fill(~attention_mask.bool().unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)\n",
+ "        hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)\n",
+ "        logits = self.lm_head(hidden_states)\n",
+ "        loss = None\n",
+ "        if labels is not None:\n",
+ "            shift_logits = logits[..., :-1, :].contiguous()\n",
+ "            shift_labels = labels[..., 1:].contiguous()\n",
+ "            loss_fct = CrossEntropyLoss()\n",
+ "            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
+ "            loss = loss + self.config.router_aux_loss_coef * aux_loss\n",
+ "        return {\"loss\": loss, \"logits\": logits, \"aux_loss\": aux_loss, \"past_key_values\": past_kv}\n",
+ "    def get_num_params(self):\n",
+ "        total = sum(p.numel() for p in self.parameters())\n",
+ "        return {\"total\": total, \"trainable\": total}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Multi-GPU DDP Training\n",
+ "# DDP requires a standalone script (mp.spawn doesn't work in notebooks)\n",
+ "# The script is at scripts/train_ddp.py\n",
+ "!python scripts/train_ddp.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test generation (after training)\n",
+ "device = torch.device(\"cuda:0\")\n",
+ "config = SpiderPortalConfig(\n",
+ "    prelude_layers=2, coda_layers=2, lora_rank=32,\n",
+ "    rope_theta=10000000.0, tie_word_embeddings=True,\n",
+ ")\n",
+ "model = SpiderPortalForConditionalGeneration(config)\n",
+ "\n",
+ "checkpoint = torch.load(\"/kaggle/working/spiderportal-v5-ep1/model.pt\", map_location=\"cpu\")\n",
+ "model.load_state_dict(checkpoint)\n",
+ "model = model.to(torch.bfloat16).to(device)\n",
+ "model.eval()\n",
+ "\n",
+ "from transformers import AutoTokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    prompt = \"Question: Instruction: What is the capital of France?\\nAnswer:\"\n",
+ "    input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(device)\n",
+ "    generated = input_ids.clone()\n",
+ "    for _ in range(32):\n",
+ "        outputs = model(generated, n_loops=1)\n",
+ "        next_token = torch.argmax(outputs[\"logits\"][:, -1, :], dim=-1, keepdim=True)\n",
+ "        generated = torch.cat([generated, next_token], dim=1)\n",
+ "        if next_token.item() == tokenizer.eos_token_id:\n",
+ "            break\n",
+ "    print(f\"Generated: {tokenizer.decode(generated[0])}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ },
+ "accelerator": "GPU"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
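The ACT halting logic inside `SpiderPortalMoEModel.forward` is easier to follow for a single token. The sketch below mirrors that loop in plain Python as an illustration only: `act_weights` is a hypothetical helper that does not appear in the notebook, and the per-iteration halting probabilities are passed in by hand instead of coming from `ACTHalting`.

```python
def act_weights(halt_probs, threshold=0.5):
    """Per-iteration mixing weights for one token, following the notebook's ACT loop.

    halt_probs: halting probability predicted at each recurrent iteration.
    threshold:  cumulative probability at which the token halts
                (`act_threshold` in SpiderPortalConfig, default 0.5).
    """
    cumulative = 0.0
    halted = False
    weights = []
    for p in halt_probs:
        running = 0.0 if halted else 1.0          # `still_running` in the notebook
        remainder = max(1.0 - cumulative, 0.0)
        # On the iteration that crosses the threshold the token receives the
        # remaining probability mass; before that it receives p itself.
        w = remainder if cumulative + p >= threshold else p
        weights.append(w * running)
        cumulative += p * running
        halted = halted or cumulative >= threshold
    # A token that never halts additionally receives the final hidden state with
    # weight 1.0 (the `never_halted` term); that term is not modelled here.
    return weights, halted

weights, halted = act_weights([0.2, 0.4, 0.9])
print(weights, halted)
```

For the probabilities `[0.2, 0.4, 0.9]` the token halts on the second iteration, so the third iteration contributes nothing and the weights of a halted token sum to 1.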
AI-Transfer/scripts/train_ddp.py ADDED
@@ -0,0 +1,601 @@
+ #!/usr/bin/env python3
+ """SpiderPortal v5: Multi-GPU DDP Training.
+
+ Run from Kaggle notebook cell:
+     !python train_ddp.py
+
+ Or directly:
+     python -m torch.distributed.run --nproc_per_node=2 train_ddp.py
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ import math
+ import os
+ import json
+ import gc
+ import random
+ import time
+ import subprocess
+ from pathlib import Path
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Dict, List
+ from torch.nn import CrossEntropyLoss
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
28
+ @dataclass
29
+ class SpiderPortalConfig:
30
+ vocab_size: int = 50278
31
+ hidden_size: int = 384
32
+ num_hidden_layers: int = 8
33
+ num_attention_heads: int = 8
34
+ num_key_value_heads: int = 2
35
+ intermediate_size: int = 1024
36
+ hidden_act: str = "silu"
37
+ num_experts: int = 64
38
+ num_experts_per_tok: int = 1
39
+ num_shared_experts: int = 1
40
+ router_aux_loss_coef: float = 0.05
41
+ max_loop_iters: int = 4
42
+ act_threshold: float = 0.5
43
+ max_position_embeddings: int = 131072
44
+ rope_theta: float = 10000000.0
45
+ rope_scaling: Optional[dict] = None
46
+ sliding_window: int = 4096
47
+ attention_dropout: float = 0.0
48
+ rms_norm_eps: float = 1e-6
49
+ initializer_range: float = 0.02
50
+ use_cache: bool = True
51
+ tie_word_embeddings: bool = True
52
+ prelude_layers: int = 2
53
+ coda_layers: int = 2
54
+ lora_rank: int = 32
55
+ loop_embed_dim: int = 48
56
+ vision_hidden_size: int = 384
57
+ audio_hidden_size: int = 512
58
+ vision_num_frames: int = 60
59
+ vision_tokens_per_frame: int = 256
60
+ vision_temporal_tokens: int = 64
61
+ vision_temporal_layers: int = 2
62
+ model_type: str = "spiderportal"
63
+ torch_dtype: str = "bfloat16"
64
+
65
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
66
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
67
+ angles = loop_t * freqs
68
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
69
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
70
+ emb_full[:loop_dim] = emb
71
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
72
+
73
+ class SpiderPortalRMSNorm(nn.Module):
74
+ def __init__(self, hidden_size, eps=1e-6):
75
+ super().__init__()
76
+ self.weight = nn.Parameter(torch.ones(hidden_size))
77
+ self.variance_epsilon = eps
78
+ def forward(self, hidden_states):
79
+ input_dtype = hidden_states.dtype
80
+ hidden_states = hidden_states.to(torch.float32)
81
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
82
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
83
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
84
+
85
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
86
+ dim = head_dim
87
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
88
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
89
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
90
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
91
+ return orig_inv_freq * scale
92
+
93
+ class SpiderPortalGQA(nn.Module):
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.config = config
97
+ self.hidden_size = config.hidden_size
98
+ self.num_heads = config.num_attention_heads
99
+ self.num_kv_heads = config.num_key_value_heads
100
+ self.head_dim = config.hidden_size // config.num_attention_heads
101
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
102
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
104
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
105
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
106
+ self.attention_dropout = config.attention_dropout
107
+ rope_scaling = getattr(config, 'rope_scaling', None)
108
+ if rope_scaling and rope_scaling.get("type") == "yarn":
109
+ factor = rope_scaling.get("factor", 1.0)
110
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
111
+ inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
112
+ else:
113
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
114
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
115
+ def _rotate_half(self, x):
116
+ x1 = x[..., :x.shape[-1] // 2]
117
+ x2 = x[..., x.shape[-1] // 2:]
118
+ return torch.cat((-x2, x1), dim=-1)
119
+ def _apply_rotary(self, x, cos, sin):
120
+ return (x * cos) + (self._rotate_half(x) * sin)
121
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
122
+ bsz, q_len, _ = hidden_states.size()
123
+ query_states = self.q_proj(hidden_states)
124
+ key_states = self.k_proj(hidden_states)
125
+ value_states = self.v_proj(hidden_states)
126
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
127
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
128
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
129
+ if position_ids is None:
130
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
131
+ max_pos = position_ids.max().item() + 1
132
+ seq_len = max(max_pos, q_len)
133
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
134
+ freqs = torch.outer(t, self.inv_freq)
135
+ emb = torch.cat((freqs, freqs), dim=-1)
136
+ cos, sin = emb.cos(), emb.sin()
137
+ cos = cos[position_ids].unsqueeze(1)
138
+ sin = sin[position_ids].unsqueeze(1)
139
+ query_states = self._apply_rotary(query_states, cos, sin)
140
+ key_states = self._apply_rotary(key_states, cos, sin)
141
+ if past_key_value is not None:
142
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
143
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
144
+ past_kv = (key_states, value_states) if use_cache else None
145
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
146
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
147
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
148
+ if attention_mask is not None:
149
+ attn_weights = attn_weights + attention_mask
150
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
151
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
152
+ attn_output = torch.matmul(attn_weights, value_states)
153
+ attn_output = attn_output.transpose(1, 2).contiguous()
154
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
155
+ return self.o_proj(attn_output), past_kv
156
+
157
+ class SpiderPortalExpert(nn.Module):
158
+ def __init__(self, config, intermediate_size=None):
159
+ super().__init__()
160
+ inter_size = intermediate_size or config.intermediate_size
161
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
162
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
163
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
164
+ self.act_fn = nn.SiLU()
165
+ def forward(self, hidden_states):
166
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
167
+
168
+ class SpiderPortalRouter(nn.Module):
169
+ def __init__(self, config):
170
+ super().__init__()
171
+ self.num_experts = config.num_experts
172
+ self.num_experts_per_tok = config.num_experts_per_tok
173
+ self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
174
+ self.register_buffer("router_bias", torch.zeros(config.num_experts))
175
+ def forward(self, hidden_states):
176
+ router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
177
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
178
+ biased_logits = router_logits + self.router_bias
179
+ biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
180
+ top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
181
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
182
+ top_weights = top_weights.to(hidden_states.dtype)
183
+ mean_probs = routing_weights.mean(dim=0)
184
+ aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
185
+ return top_weights, top_indices, aux_loss
186
+
187
+ class SpiderPortalMoE(nn.Module):
188
+ def __init__(self, config):
189
+ super().__init__()
190
+ self.config = config
191
+ self.num_experts = config.num_experts
192
+ self.num_experts_per_tok = config.num_experts_per_tok
193
+ self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
194
+ self.shared_expert = SpiderPortalExpert(config)
195
+ self.router = SpiderPortalRouter(config)
196
+ def forward(self, hidden_states):
197
+ batch_size, seq_len, hidden_dim = hidden_states.shape
198
+ top_weights, top_indices, aux_loss = self.router(hidden_states)
199
+ flat_hidden = hidden_states.view(-1, hidden_dim)
200
+ final_output = torch.zeros_like(flat_hidden)
201
+ for expert_idx in range(self.num_experts_per_tok):
202
+ expert_ids = top_indices[:, expert_idx]
203
+ expert_weights = top_weights[:, expert_idx:expert_idx+1]
204
+ for e in range(self.num_experts):
205
+ mask = expert_ids == e
206
+ if mask.any():
207
+ expert_output = self.experts[e](flat_hidden[mask])
208
+ final_output[mask] += expert_output * expert_weights[mask]
209
+ shared_output = self.shared_expert(flat_hidden)
210
+ final_output = final_output + shared_output
211
+ return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
212
+
213
+ class SpiderPortalDenseLayer(nn.Module):
214
+ def __init__(self, config):
215
+ super().__init__()
216
+ self.self_attn = SpiderPortalGQA(config)
217
+ dense_intermediate = config.hidden_size * 4 // 3
218
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
219
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
220
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
221
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
222
+ attn_input = self.input_layernorm(hidden_states)
223
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
224
+ hidden_states = hidden_states + attn_output
225
+ ffn_input = self.post_attention_layernorm(hidden_states)
226
+ ffn_output = self.ffn(ffn_input)
227
+ hidden_states = hidden_states + ffn_output
228
+ return hidden_states, past_kv
229
+
230
+ class SpiderPortalMoELayer(nn.Module):
231
+ def __init__(self, config, layer_idx):
232
+ super().__init__()
233
+ self.layer_idx = layer_idx
234
+ self.self_attn = SpiderPortalGQA(config)
235
+ self.moe = SpiderPortalMoE(config)
236
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
237
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
238
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
239
+ attn_input = self.input_layernorm(hidden_states)
240
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
241
+ hidden_states = hidden_states + attn_output
242
+ moe_input = self.post_attention_layernorm(hidden_states)
243
+ moe_output, aux_loss = self.moe(moe_input)
244
+ hidden_states = hidden_states + moe_output
245
+ return hidden_states, aux_loss, past_kv
246
+
247
+ class LTIInjection(nn.Module):
248
+ def __init__(self, config):
249
+ super().__init__()
250
+ self.hidden_size = config.hidden_size
251
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
252
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
253
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
254
+ with torch.no_grad():
255
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
256
+ def get_A(self):
257
+ return -torch.exp(self.log_A)
258
+ def forward(self, h_t, e):
259
+ A = self.get_A()
260
+ return A * h_t + self.B(e)
261
+
262
+ class ACTHalting(nn.Module):
263
+ def __init__(self, config):
264
+ super().__init__()
265
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
266
+ self.threshold = config.act_threshold
267
+ def forward(self, hidden_states):
268
+ return torch.sigmoid(self.halt_predictor(hidden_states))
269
+
270
+ class LoRAAdapter(nn.Module):
271
+ def __init__(self, config):
272
+ super().__init__()
273
+ rank = config.lora_rank
274
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
275
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
276
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
277
+ with torch.no_grad():
278
+ self.scale.weight.data.zero_()
279
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
280
+ def forward(self, x, loop_t):
281
+ max_t = self.scale.num_embeddings - 1
282
+ t_idx = min(loop_t, max_t)
283
+ s = self.scale(torch.tensor(t_idx, device=x.device))
284
+ down = self.down(x) * s
285
+ return down @ self.B
286
+
287
+ class SpiderPortalMoEModel(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ self.config = config
291
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
292
+ self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
293
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
294
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
295
+ self.injection = LTIInjection(config)
296
+ self.act_halting = ACTHalting(config)
297
+ self.lora_adapter = LoRAAdapter(config)
298
+ self.loop_embed_dim = config.loop_embed_dim
299
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
300
+ n_loops = n_loops or self.config.max_loop_iters
301
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
302
+ total_aux_loss = 0.0
303
+ for layer in self.prelude_layers:
304
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
305
+ e = hidden_states.clone()
306
+ B, T_seq, D = hidden_states.shape
307
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
308
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
309
+ h_out = torch.zeros_like(hidden_states)
310
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
311
+ for t in range(n_loops):
312
+ h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
313
+ if t > 0:
314
+ injection = self.injection(hidden_states, input_embedding)
315
+ hidden_states = hidden_states + injection
316
+ new_past_key_values = []
317
+ for i, layer in enumerate(self.recurrent_layers):
318
+ hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
319
+ total_aux_loss = total_aux_loss + aux_loss
320
+ new_past_key_values.append(past_kv)
321
+ lora_delta = self.lora_adapter(hidden_states, t)
322
+ hidden_states = hidden_states + lora_delta
323
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
324
+ still_running = ~halted
325
+ remainder = (1.0 - cumulative_p).clamp(min=0)
326
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
327
+ weight = weight * still_running.to(hidden_states.dtype)
328
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
329
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
330
+ halted = halted | (cumulative_p >= self.config.act_threshold)
331
+ if halted.all() and not self.training:
332
+ break
333
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
334
+ hidden_states = h_out + never_halted * hidden_states
335
+ for layer in self.coda_layers:
336
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
337
+ hidden_states = self.norm(hidden_states)
338
+ return hidden_states, total_aux_loss, new_past_key_values
339
+
340
+ class SpiderPortalForConditionalGeneration(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ self.config = config
344
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
345
+ self.model = SpiderPortalMoEModel(config)
346
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
347
+ if config.tie_word_embeddings:
348
+ self.lm_head.weight = self.embed_tokens.weight
349
+ self.apply(self._init_weights)
350
+ def _init_weights(self, module):
351
+ if isinstance(module, nn.Linear):
352
+ if hasattr(self, 'model') and module is self.model.injection.B:
353
+ return
354
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
355
+ if module.bias is not None:
356
+ module.bias.data.zero_()
357
+ elif isinstance(module, nn.Embedding):
358
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
359
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
360
+ hidden_states = self.embed_tokens(input_ids)
361
+ model_dtype = next(self.model.parameters()).dtype
362
+ hidden_states = hidden_states.to(model_dtype)
363
+ input_embedding = hidden_states.clone()
364
+ if attention_mask is None:
365
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
366
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
367
+ causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
368
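+ causal_mask = causal_mask.masked_fill(torch.ones(attention_mask.size(1), attention_mask.size(1), dtype=torch.bool, device=hidden_states.device).triu(1), torch.finfo(hidden_states.dtype).min)  # block attention to future positions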
369
+ hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
370
+ logits = self.lm_head(hidden_states)
371
+ loss = None
372
+ if labels is not None:
373
+ shift_logits = logits[..., :-1, :].contiguous()
374
+ shift_labels = labels[..., 1:].contiguous()
375
+ loss_fct = CrossEntropyLoss()
376
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
377
+ loss = loss + self.config.router_aux_loss_coef * aux_loss
378
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
379
+ def get_num_params(self):
380
+ total = sum(p.numel() for p in self.parameters())
381
+ return {"total": total, "trainable": total}
382
+
383
+ def train_ddp(local_rank, world_size):
384
+ dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
385
+ torch.cuda.set_device(local_rank)
386
+ device = torch.device(f"cuda:{local_rank}")
387
+ is_master = local_rank == 0
388
+
389
+ if is_master:
390
+ print(f"Training on {world_size} GPUs (DDP)")
391
+ for i in range(world_size):
392
+ p = torch.cuda.get_device_properties(i)
393
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({p.total_memory / 1e9:.1f}GB)")
394
+
395
+ config = SpiderPortalConfig(
396
+ hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
397
+ num_key_value_heads=2, intermediate_size=1024,
398
+ num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
399
+ router_aux_loss_coef=0.05, max_loop_iters=2,
400
+ prelude_layers=2, coda_layers=2, lora_rank=32,
401
+ rope_theta=10000000.0,
402
+ rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
403
+ max_position_embeddings=131072, sliding_window=4096,
404
+ tie_word_embeddings=True,
405
+ )
406
+
407
+ model = SpiderPortalForConditionalGeneration(config)
408
+ model = model.to(torch.bfloat16).to(device)
409
+ model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
410
+
411
+ if is_master:
412
+ params = model.module.get_num_params()
413
+ print(f"Model: {params['total']/1e6:.1f}M params")
414
+ print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
415
+
416
+ BASE_LR = 2e-5
417
+ WARMUP_STEPS = 1000
418
+ optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
419
+
420
+ import pandas as pd
421
+ data_dir = Path("/kaggle/input/datasets/cliworks/spiderp-custom-1")
422
+ all_records = []
423
+ if data_dir.exists():
424
+ pkl_file = data_dir / "spiderportal_combined.pkl"
425
+ if pkl_file.exists():
426
+ df = pd.read_pickle(pkl_file)
427
+ all_records = df.to_dict("records")
428
+ else:
429
+ for f in data_dir.glob("*.parquet"):
430
+ try:
431
+ df_pq = pd.read_parquet(f)
432
+ if all(c in df_pq.columns for c in ["instruction", "input", "output"]):
433
+ all_records.extend(df_pq[["instruction", "input", "output"]].to_dict("records"))
434
+ except Exception:  # skip unreadable parquet files, but let KeyboardInterrupt/SystemExit propagate
435
+ pass
436
+
437
+ if not all_records:
438
+ if is_master:
439
+ print("No data found, creating synthetic data...")
440
+ all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
441
+
442
+ if is_master:
443
+ print(f"Loaded {len(all_records)} samples")
444
+
445
+ from transformers import AutoTokenizer
446
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
447
+ tokenizer.pad_token = tokenizer.eos_token
448
+
449
+ BATCH_SIZE = 8
450
+ GRAD_ACCUM = 2
451
+ MAX_LEN = 256
452
+ EPOCHS = 3
453
+ N_LOOPS = 1
454
+
455
+ if is_master:
456
+ print(f"Batch size: {BATCH_SIZE}, Grad accum: {GRAD_ACCUM}")
457
+ print(f"Effective batch: {BATCH_SIZE * GRAD_ACCUM * world_size}")
458
+ print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup")
459
+
460
+ def build_prompt(sample):
461
+ instruction = str(sample.get("instruction", "")).strip()
462
+ inp = str(sample.get("input", "")).strip()
463
+ output = str(sample.get("output", "")).strip()
464
+ if inp:
465
+ return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
466
+ return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
467
+
468
+ def format_sample(sample, tokenizer, max_len):
469
+ text = build_prompt(sample) + tokenizer.eos_token
470
+ enc = tokenizer(text, truncation=True, max_length=max_len, padding="max_length", return_tensors="pt")
471
+ input_ids = enc["input_ids"].squeeze(0)
472
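+ labels = input_ids.masked_fill(enc["attention_mask"].squeeze(0) == 0, -100)  # copy of input_ids with padding excluded from the loss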
473
+ prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
474
+ mask_len = min(len(prefix_ids), labels.shape[0])
475
+ labels[:mask_len] = -100
476
+ return {"input_ids": input_ids, "labels": labels}
477
+
478
+ shard_size = len(all_records) // world_size
479
+ start_idx = local_rank * shard_size
480
+ end_idx = start_idx + shard_size if local_rank < world_size - 1 else len(all_records)
481
+ local_samples = all_records[start_idx:end_idx]
482
+
483
+ if is_master:
484
+ print(f"Data sharding: {len(all_records)} total -> {len(local_samples)} per GPU")
485
+
486
+ global_step = 0
487
+ best_loss = float('inf')
488
+ start_time = time.time()
489
+
490
+ for epoch in range(1, EPOCHS + 1):
491
+ random.shuffle(local_samples)
492
+ total_loss = 0
493
+ num_batches = 0
494
+ optimizer.zero_grad()
495
+
496
+ for i in range(0, len(local_samples), BATCH_SIZE):
497
+ batch_samples = local_samples[i:i+BATCH_SIZE]
498
+ if len(batch_samples) < BATCH_SIZE:
499
+ continue
500
+
501
+ if global_step < WARMUP_STEPS:
502
+ lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
503
+ for param_group in optimizer.param_groups:
504
+ param_group['lr'] = lr
505
+
506
+ batch = [format_sample(s, tokenizer, MAX_LEN) for s in batch_samples]
507
+ input_ids = torch.stack([b["input_ids"] for b in batch]).to(device)
508
+ labels = torch.stack([b["labels"] for b in batch]).to(device)
509
+
510
+ if global_step == 0 and is_master:
511
+ print(" [First forward pass - CUDA warm-up, may be slow...]")
512
+
513
+ outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
514
+ loss = outputs["loss"] / GRAD_ACCUM
515
+ loss.backward()
516
+
517
+ if (i // BATCH_SIZE + 1) % GRAD_ACCUM == 0:
518
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
519
+ optimizer.step()
520
+ optimizer.zero_grad()
521
+ global_step += 1
522
+
523
+ total_loss += loss.item() * GRAD_ACCUM
524
+ num_batches += 1
525
+
526
+ if is_master and ((i // BATCH_SIZE) == 0 or (global_step < 20) or (global_step % 50 == 0)):
527
+ avg_loss = total_loss / max(num_batches, 1)
528
+ elapsed = time.time() - start_time
529
+ steps_per_hour = (global_step + 1) / elapsed * 3600 if elapsed > 0 else 0
530
+ current_lr = optimizer.param_groups[0]['lr']
531
+ print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | avg_loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.1f} steps/hr")
532
+
533
+ if is_master and global_step > 0 and global_step % 500 == 0:
534
+ output_dir = Path(f"/kaggle/working/spiderportal-v5-ep{epoch}-step{global_step}")
535
+ output_dir.mkdir(parents=True, exist_ok=True)
536
+ state_dict = {k: v.cpu() for k, v in model.module.state_dict().items()}
537
+ torch.save(state_dict, output_dir / "model.pt")
538
+ print(f"Saved checkpoint at step {global_step}")
539
+
540
+ if is_master:
541
+ avg_loss = total_loss / max(num_batches, 1)
542
+ print(f"Epoch {epoch}/{EPOCHS} | avg_loss={avg_loss:.4f} | Time: {(time.time()-start_time)/60:.1f}min")
543
+ output_dir = Path(f"/kaggle/working/spiderportal-v5-ep{epoch}")
544
+ output_dir.mkdir(parents=True, exist_ok=True)
545
+ state_dict = {k: v.cpu() for k, v in model.module.state_dict().items()}
546
+ torch.save(state_dict, output_dir / "model.pt")
547
+ if avg_loss < best_loss:
548
+ best_loss = avg_loss
549
+
550
+ if is_master:
551
+ print(f"Training complete! Best loss: {best_loss:.4f}")
552
+ print(f"Total time: {(time.time() - start_time)/60:.1f} minutes")
553
+
554
+ dist.destroy_process_group()
555
+
556
+ if __name__ == "__main__":
557
+ import sys
558
+ print(f"Script started, CUDA devices: {torch.cuda.device_count()}")
559
+ sys.stdout.flush()
560
+
561
+ num_gpus = torch.cuda.device_count()
562
+
563
+ if num_gpus <= 1:
564
+ os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # DDP still needs a (1-rank) process group
565
+ os.environ.setdefault("MASTER_PORT", "29500")
566
+ train_ddp(0, 1)
567
+ else:
568
+ print(f"Launching DDP with {num_gpus} GPUs...")
569
+ sys.stdout.flush()
570
+ import os
571
+ import signal
572
+
573
+ # Set NCCL environment before forking
574
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
575
+ os.environ["MASTER_PORT"] = "29500"
576
+ os.environ["WORLD_SIZE"] = str(num_gpus)
577
+ os.environ["LOCAL_WORLD_SIZE"] = str(num_gpus)
578
+
579
+ procs = []
580
+ for rank in range(num_gpus):
581
+ print(f"Forking rank {rank}...")
582
+ sys.stdout.flush()
583
+ pid = os.fork()
584
+ if pid == 0:
585
+ # Child process
586
+ os.environ["RANK"] = str(rank)
587
+ os.environ["LOCAL_RANK"] = str(rank)
588
+ print(f"Child {rank} starting training...")
589
+ sys.stdout.flush()
590
+ train_ddp(rank, num_gpus)
591
+ os._exit(0)
592
+ else:
593
+ procs.append(pid)
594
+ print(f"Parent: forked child {rank} with PID {pid}")
595
+ sys.stdout.flush()
596
+
597
+ # Parent process: wait for children
598
+ print("Parent waiting for children...")
599
+ sys.stdout.flush()
600
+ for pid in procs:
601
+ os.waitpid(pid, 0)
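Note on `train_ddp.py`: the per-rank data sharding and the linear LR warmup in the training loop can be checked in isolation. The sketch below restates both as standalone helpers; the names `shard_bounds` and `warmup_lr` are hypothetical and do not appear in the script, and the defaults simply mirror `BASE_LR` and `WARMUP_STEPS` above.

```python
def shard_bounds(num_records, world_size, rank):
    """Return the [start, end) slice of records for one rank.

    Mirrors the script's rule: equal-sized shards, with the last rank
    also taking whatever remainder is left over.
    """
    shard_size = num_records // world_size
    start = rank * shard_size
    end = start + shard_size if rank < world_size - 1 else num_records
    return start, end


def warmup_lr(step, base_lr=2e-5, warmup_steps=1000):
    """Linear warmup to base_lr, then constant (no decay, as in the script)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


if __name__ == "__main__":
    # Shard sizes for 10 records split across 3 ranks.
    sizes = []
    for rank in range(3):
        start, end = shard_bounds(10, 3, rank)
        sizes.append(end - start)
    print("shard sizes:", sizes)
    print("lr at step 0:", warmup_lr(0))
    print("lr at step 5000:", warmup_lr(5000))
```

Because the last rank absorbs the remainder, ranks can end up with different numbers of batches per epoch. Under DDP that may leave one rank waiting on gradient synchronization that the others never start, so truncating every shard to the same length is probably the safer choice.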
AI-Transfer/scripts/train_single_gpu.py ADDED
@@ -0,0 +1,587 @@
1
+ #!/usr/bin/env python3
2
+ """SpiderPortal v5 — Single-GPU Optimized Training.
3
+
4
+ For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
5
+ No DDP, maximal batch size, torch.compile, pre-tokenized data.
6
+
7
+ Usage:
8
+ python train_single_gpu.py
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ import os
16
+ import json
17
+ import gc
18
+ import random
19
+ import time
20
+ from pathlib import Path
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple, Dict, List
23
+ from torch.nn import CrossEntropyLoss
24
+
25
+ @dataclass
26
+ class SpiderPortalConfig:
27
+ vocab_size: int = 50278
28
+ hidden_size: int = 384
29
+ num_hidden_layers: int = 8
30
+ num_attention_heads: int = 8
31
+ num_key_value_heads: int = 2
32
+ intermediate_size: int = 1024
33
+ hidden_act: str = "silu"
34
+ num_experts: int = 64
35
+ num_experts_per_tok: int = 1
36
+ num_shared_experts: int = 1
37
+ router_aux_loss_coef: float = 0.05
38
+ max_loop_iters: int = 4
39
+ act_threshold: float = 0.5
40
+ max_position_embeddings: int = 131072
41
+ rope_theta: float = 10000000.0
42
+ rope_scaling: Optional[dict] = None
43
+ sliding_window: int = 4096
44
+ attention_dropout: float = 0.0
45
+ rms_norm_eps: float = 1e-6
46
+ initializer_range: float = 0.02
47
+ use_cache: bool = True
48
+ tie_word_embeddings: bool = True
49
+ prelude_layers: int = 2
50
+ coda_layers: int = 2
51
+ lora_rank: int = 32
52
+ loop_embed_dim: int = 48
53
+ vision_hidden_size: int = 384
54
+ audio_hidden_size: int = 512
55
+ vision_num_frames: int = 60
56
+ vision_tokens_per_frame: int = 256
57
+ vision_temporal_tokens: int = 64
58
+ vision_temporal_layers: int = 2
59
+ model_type: str = "spiderportal"
60
+ torch_dtype: str = "bfloat16"
61
+
62
+ def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
63
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
64
+ angles = loop_t * freqs
65
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
66
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
67
+ emb_full[:loop_dim] = emb
68
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
69
+
70
+ class SpiderPortalRMSNorm(nn.Module):
71
+ def __init__(self, hidden_size, eps=1e-6):
72
+ super().__init__()
73
+ self.weight = nn.Parameter(torch.ones(hidden_size))
74
+ self.variance_epsilon = eps
75
+ def forward(self, hidden_states):
76
+ input_dtype = hidden_states.dtype
77
+ hidden_states = hidden_states.to(torch.float32)
78
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
+ return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
81
+
82
+ def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
83
+ dim = head_dim
84
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
85
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
86
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
87
+ scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
88
+ return orig_inv_freq * scale
89
+
90
+ class SpiderPortalGQA(nn.Module):
91
+ def __init__(self, config):
92
+ super().__init__()
93
+ self.config = config
94
+ self.hidden_size = config.hidden_size
95
+ self.num_heads = config.num_attention_heads
96
+ self.num_kv_heads = config.num_key_value_heads
97
+ self.head_dim = config.hidden_size // config.num_attention_heads
98
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
99
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
100
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
101
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
102
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
103
+ self.attention_dropout = config.attention_dropout
104
+ rope_scaling = getattr(config, 'rope_scaling', None)
105
+ if rope_scaling and rope_scaling.get("type") == "yarn":
106
+ factor = rope_scaling.get("factor", 1.0)
107
+ orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
108
+ inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
109
+ else:
110
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
111
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
112
+ def _rotate_half(self, x):
113
+ x1 = x[..., :x.shape[-1] // 2]
114
+ x2 = x[..., x.shape[-1] // 2:]
115
+ return torch.cat((-x2, x1), dim=-1)
116
+ def _apply_rotary(self, x, cos, sin):
117
+ return (x * cos) + (self._rotate_half(x) * sin)
118
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
119
+ bsz, q_len, _ = hidden_states.size()
120
+ query_states = self.q_proj(hidden_states)
121
+ key_states = self.k_proj(hidden_states)
122
+ value_states = self.v_proj(hidden_states)
123
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
124
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
125
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
126
+ if position_ids is None:
127
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
128
+ max_pos = position_ids.max().item() + 1
129
+ seq_len = max(max_pos, q_len)
130
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ emb = torch.cat((freqs, freqs), dim=-1)
133
+ cos, sin = emb.cos(), emb.sin()
134
+ cos = cos[position_ids].unsqueeze(1)
135
+ sin = sin[position_ids].unsqueeze(1)
136
+ query_states = self._apply_rotary(query_states, cos, sin)
137
+ key_states = self._apply_rotary(key_states, cos, sin)
138
+ if past_key_value is not None:
139
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
140
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
141
+ past_kv = (key_states, value_states) if use_cache else None
142
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
143
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
144
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
145
+ if attention_mask is not None:
146
+ attn_weights = attn_weights + attention_mask
147
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
148
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
149
+ attn_output = torch.matmul(attn_weights, value_states)
150
+ attn_output = attn_output.transpose(1, 2).contiguous()
151
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
152
+ return self.o_proj(attn_output), past_kv
153
+
154
+ class SpiderPortalExpert(nn.Module):
155
+ def __init__(self, config, intermediate_size=None):
156
+ super().__init__()
157
+ inter_size = intermediate_size or config.intermediate_size
158
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
159
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
160
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
161
+ self.act_fn = nn.SiLU()
162
+ def forward(self, hidden_states):
163
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
164
+
165
+ class SpiderPortalRouter(nn.Module):
166
+ def __init__(self, config):
167
+ super().__init__()
168
+ self.num_experts = config.num_experts
169
+ self.num_experts_per_tok = config.num_experts_per_tok
170
+ self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
171
+ self.register_buffer("router_bias", torch.zeros(config.num_experts))
172
+ def forward(self, hidden_states):
173
+ router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
174
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
175
+ biased_logits = router_logits + self.router_bias
176
+ biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
177
+ top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
178
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
179
+ top_weights = top_weights.to(hidden_states.dtype)
180
+ mean_probs = routing_weights.mean(dim=0)
181
+ aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
182
+ return top_weights, top_indices, aux_loss
183
+
184
+ class SpiderPortalMoE(nn.Module):
185
+ def __init__(self, config):
186
+ super().__init__()
187
+ self.config = config
188
+ self.num_experts = config.num_experts
189
+ self.num_experts_per_tok = config.num_experts_per_tok
190
+ self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
191
+ self.shared_expert = SpiderPortalExpert(config)
192
+ self.router = SpiderPortalRouter(config)
193
+ def forward(self, hidden_states):
194
+ batch_size, seq_len, hidden_dim = hidden_states.shape
195
+ top_weights, top_indices, aux_loss = self.router(hidden_states)
196
+ flat_hidden = hidden_states.view(-1, hidden_dim)
197
+ final_output = torch.zeros_like(flat_hidden)
198
+ for expert_idx in range(self.num_experts_per_tok):
199
+ expert_ids = top_indices[:, expert_idx]
200
+ expert_weights = top_weights[:, expert_idx:expert_idx+1]
201
+ for e in range(self.num_experts):
202
+ mask = expert_ids == e
203
+ if mask.any():
204
+ expert_output = self.experts[e](flat_hidden[mask])
205
+ final_output[mask] += expert_output * expert_weights[mask]
206
+ shared_output = self.shared_expert(flat_hidden)
207
+ final_output = final_output + shared_output
208
+ return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
209
+
210
+ class SpiderPortalDenseLayer(nn.Module):
211
+ def __init__(self, config):
212
+ super().__init__()
213
+ self.self_attn = SpiderPortalGQA(config)
214
+ dense_intermediate = config.hidden_size * 4 // 3
215
+ self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
216
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
217
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
218
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
219
+ attn_input = self.input_layernorm(hidden_states)
220
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
221
+ hidden_states = hidden_states + attn_output
222
+ ffn_input = self.post_attention_layernorm(hidden_states)
223
+ ffn_output = self.ffn(ffn_input)
224
+ hidden_states = hidden_states + ffn_output
225
+ return hidden_states, past_kv
226
+
227
+ class SpiderPortalMoELayer(nn.Module):
228
+ def __init__(self, config, layer_idx):
229
+ super().__init__()
230
+ self.layer_idx = layer_idx
231
+ self.self_attn = SpiderPortalGQA(config)
232
+ self.moe = SpiderPortalMoE(config)
233
+ self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
234
+ self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
235
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
236
+ attn_input = self.input_layernorm(hidden_states)
237
+ attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
238
+ hidden_states = hidden_states + attn_output
239
+ moe_input = self.post_attention_layernorm(hidden_states)
240
+ moe_output, aux_loss = self.moe(moe_input)
241
+ hidden_states = hidden_states + moe_output
242
+ return hidden_states, aux_loss, past_kv
243
+
244
+ class LTIInjection(nn.Module):
245
+ def __init__(self, config):
246
+ super().__init__()
247
+ self.hidden_size = config.hidden_size
248
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
249
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
250
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
251
+ with torch.no_grad():
252
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
253
+ def get_A(self):
254
+ return -torch.exp(self.log_A)
255
+ def forward(self, h_t, e):
256
+ A = self.get_A()
257
+ return A * h_t + self.B(e)
258
+
259
+ class ACTHalting(nn.Module):
260
+ def __init__(self, config):
261
+ super().__init__()
262
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
263
+ self.threshold = config.act_threshold
264
+ def forward(self, hidden_states):
265
+ return torch.sigmoid(self.halt_predictor(hidden_states))
266
+
267
+ class LoRAAdapter(nn.Module):
268
+ def __init__(self, config):
269
+ super().__init__()
270
+ rank = config.lora_rank
271
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
272
+ self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
273
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
274
+ with torch.no_grad():
275
+ self.scale.weight.data.zero_()
276
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
277
+ def forward(self, x, loop_t):
278
+ max_t = self.scale.num_embeddings - 1
279
+ t_idx = min(loop_t, max_t)
280
+ s = self.scale(torch.tensor(t_idx, device=x.device))
281
+ down = self.down(x) * s
282
+ return down @ self.B
283
+
284
+ class SpiderPortalMoEModel(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.config = config
288
+ self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
289
+ self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
290
+ self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
291
+ self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
292
+ self.injection = LTIInjection(config)
293
+ self.act_halting = ACTHalting(config)
294
+ self.lora_adapter = LoRAAdapter(config)
295
+ self.loop_embed_dim = config.loop_embed_dim
296
+ def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
297
+ n_loops = n_loops or self.config.max_loop_iters
298
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
299
+ total_aux_loss = 0.0
300
+ for layer in self.prelude_layers:
301
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
302
+ e = hidden_states.clone()
303
+ B, T_seq, D = hidden_states.shape
304
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
305
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
306
+ h_out = torch.zeros_like(hidden_states)
307
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
308
+ for t in range(n_loops):
309
+ hidden_states = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
310
+ if t > 0:
311
+ injection = self.injection(hidden_states, input_embedding)
312
+ hidden_states = hidden_states + injection
313
+ new_past_key_values = []
314
+ for i, layer in enumerate(self.recurrent_layers):
315
+ hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
316
+ total_aux_loss = total_aux_loss + aux_loss
317
+ new_past_key_values.append(past_kv)
318
+ lora_delta = self.lora_adapter(hidden_states, t)
319
+ hidden_states = hidden_states + lora_delta
320
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
321
+ still_running = ~halted
322
+ remainder = (1.0 - cumulative_p).clamp(min=0)
323
+ weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
324
+ weight = weight * still_running.to(hidden_states.dtype)
325
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
326
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
327
+ halted = halted | (cumulative_p >= self.config.act_threshold)
328
+ if halted.all() and not self.training:
329
+ break
330
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
331
+ hidden_states = h_out + never_halted * hidden_states
332
+ for layer in self.coda_layers:
333
+ hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
334
+ hidden_states = self.norm(hidden_states)
335
+ return hidden_states, total_aux_loss, new_past_key_values
336
+
337
+ class SpiderPortalForConditionalGeneration(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.config = config
341
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
342
+ self.model = SpiderPortalMoEModel(config)
343
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
344
+ if config.tie_word_embeddings:
345
+ self.lm_head.weight = self.embed_tokens.weight
346
+ self.apply(self._init_weights)
347
+ def _init_weights(self, module):
348
+ if isinstance(module, nn.Linear):
349
+ if hasattr(self, 'model') and module is self.model.injection.B:
350
+ return
351
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
352
+ if module.bias is not None:
353
+ module.bias.data.zero_()
354
+ elif isinstance(module, nn.Embedding):
355
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
356
+ def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
357
+ hidden_states = self.embed_tokens(input_ids)
358
+ model_dtype = next(self.model.parameters()).dtype
359
+ hidden_states = hidden_states.to(model_dtype)
360
+ input_embedding = hidden_states.clone()
361
+ if attention_mask is None:
362
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
363
+ causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), torch.finfo(hidden_states.dtype).min, dtype=hidden_states.dtype, device=hidden_states.device)
364
+ causal_mask = causal_mask.triu(1)  # keep the large negative value only above the diagonal (future positions)
365
+ causal_mask = causal_mask.masked_fill(~attention_mask.bool().unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)  # also block padded keys
366
+ hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
367
+ logits = self.lm_head(hidden_states)
368
+ loss = None
369
+ if labels is not None:
370
+ shift_logits = logits[..., :-1, :].contiguous()
371
+ shift_labels = labels[..., 1:].contiguous()
372
+ loss_fct = CrossEntropyLoss()
373
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
374
+ loss = loss + self.config.router_aux_loss_coef * aux_loss
375
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
376
+ def get_num_params(self):
377
+ total = sum(p.numel() for p in self.parameters())
378
+ return {"total": total, "trainable": total}
379
+
380
+ def train_single_gpu():
381
+ device = torch.device("cuda")
382
+ gpu_name = torch.cuda.get_device_name(0)
383
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
384
+ print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
385
+
386
+ config = SpiderPortalConfig(
387
+ hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
388
+ num_key_value_heads=2, intermediate_size=1024,
389
+ num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
390
+ router_aux_loss_coef=0.05, max_loop_iters=2,
391
+ prelude_layers=2, coda_layers=2, lora_rank=32,
392
+ rope_theta=10000000.0,
393
+ rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
394
+ max_position_embeddings=131072, sliding_window=4096,
395
+ tie_word_embeddings=True,
396
+ )
397
+
398
+ print("Building model...")
399
+ model = SpiderPortalForConditionalGeneration(config)
400
+ model = model.to(torch.bfloat16).to(device)
401
+
402
+ params = model.get_num_params()
403
+ print(f"Model: {params['total']/1e6:.1f}M params")
404
+ print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")
405
+
406
+ try:
407
+ model = torch.compile(model, mode="reduce-overhead")
408
+ print("torch.compile: enabled")
409
+ except Exception:
410
+ print("torch.compile: not available, using eager mode")
411
+
412
+ BASE_LR = 2e-5
413
+ WARMUP_STEPS = 1000
414
+ optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
415
+
416
+ import pandas as pd
417
+ data_dir = Path(__file__).parent / "data"
418
+ all_records = []
419
+ pkl_file = data_dir / "spiderportal_combined.pkl"
420
+ if pkl_file.exists():
421
+ print(f"Loading dataset from {pkl_file}...")
422
+ df = pd.read_pickle(pkl_file)
423
+ all_records = df.to_dict("records")
424
+ else:
425
+ print(f"No dataset found at {pkl_file}, creating synthetic data...")
426
+ all_records = [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]
427
+
428
+ print(f"Loaded {len(all_records):,} samples")
429
+
430
+ from transformers import AutoTokenizer
431
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
432
+ tokenizer.pad_token = tokenizer.eos_token
433
+
434
+ BATCH_SIZE = 128
435
+ MAX_LEN = 256
436
+ EPOCHS = 3
437
+ N_LOOPS = 2
438
+
439
+ print(f"Batch size: {BATCH_SIZE} (no grad accum)")
440
+ print(f"Effective batch: {BATCH_SIZE}")
441
+ print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup")
442
+ print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")
443
+
444
+ def build_prompt(sample):
445
+ instruction = str(sample.get("instruction", "")).strip()
446
+ inp = str(sample.get("input", "")).strip()
447
+ output = str(sample.get("output", "")).strip()
448
+ if inp:
449
+ return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
450
+ return f"Question: Instruction: {instruction}\nAnswer: {output}\n"
451
+
452
+ print("Pre-tokenizing dataset...")
453
+ prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
454
+ mask_len = len(prefix_ids)
455
+
456
+ pre_tokenized = []
457
+ for i, sample in enumerate(all_records):
458
+ instruction = str(sample.get("instruction", "")).strip()
459
+ inp = str(sample.get("input", "")).strip()
460
+ output = str(sample.get("output", "")).strip()
461
+ if inp:
462
+ text = f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n" + tokenizer.eos_token
463
+ else:
464
+ text = f"Question: Instruction: {instruction}\nAnswer: {output}\n" + tokenizer.eos_token
465
+ enc = tokenizer(text, truncation=True, max_length=MAX_LEN, padding="max_length")
466
+ input_ids = enc["input_ids"]
467
+ labels = [tok if keep else -100 for tok, keep in zip(input_ids, enc["attention_mask"])]  # exclude padding from the loss
468
+ for j in range(min(mask_len, len(labels))):
469
+ labels[j] = -100
470
+ pre_tokenized.append((input_ids, labels))
471
+ if (i + 1) % 50000 == 0:
472
+ print(f" Tokenized {i+1:,}/{len(all_records):,}")
473
+
474
+ print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
475
+ del all_records
476
+ gc.collect()
477
+
478
+ global_step = 0
479
+ best_loss = float('inf')
480
+ start_time = time.time()
481
+ checkpoint_dir = Path("checkpoints")
482
+ checkpoint_dir.mkdir(exist_ok=True)
483
+ step_ckpt_files = []
484
+
485
+ for epoch in range(1, EPOCHS + 1):
486
+ if epoch > 1:
487
+ for f in step_ckpt_files:
488
+ if f.exists():
489
+ f.unlink()
490
+ print(f" Deleted old step checkpoint: {f.name}")
491
+ step_ckpt_files.clear()
492
+ gc.collect()
493
+
494
+ indices = list(range(len(pre_tokenized)))
495
+ random.shuffle(indices)
496
+ total_loss = 0
497
+ num_batches = 0
498
+ optimizer.zero_grad()
499
+
500
+ for batch_start in range(0, len(indices), BATCH_SIZE):
501
+ batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
502
+ if len(batch_indices) < BATCH_SIZE:
503
+ continue
504
+
505
+ if global_step < WARMUP_STEPS:
506
+ lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
507
+ for param_group in optimizer.param_groups:
508
+ param_group['lr'] = lr
509
+
510
+ batch_input_ids = []
511
+ batch_labels = []
512
+ for idx in batch_indices:
513
+ input_ids, labels = pre_tokenized[idx]
514
+ batch_input_ids.append(input_ids)
515
+ batch_labels.append(labels)
516
+
517
+ input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=device)
518
+ labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
519
+
520
+ if global_step == 0:
521
+ print(" [First forward pass - compiling...]")
522
+
523
+ outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
524
+ loss = outputs["loss"]
525
+ loss.backward()
526
+
527
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
528
+ optimizer.step()
529
+ optimizer.zero_grad()
530
+ global_step += 1
531
+
532
+ total_loss += loss.item()
533
+ num_batches += 1
534
+
535
+ if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
536
+ avg_loss = total_loss / max(num_batches, 1)
537
+ elapsed = time.time() - start_time
538
+ steps_per_hour = global_step / elapsed * 3600 if elapsed > 0 else 0
539
+ current_lr = optimizer.param_groups[0]['lr']
540
+ samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
541
+ print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")
542
+
543
+ if global_step > 0 and global_step % 500 == 0:
544
+ ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
545
+ state_dict = {k.removeprefix("_orig_mod."): v.cpu() for k, v in model.state_dict().items()}  # drop the torch.compile wrapper prefix
546
+ torch.save(state_dict, ckpt_path)
547
+ step_ckpt_files.append(ckpt_path)
548
+ size_mb = ckpt_path.stat().st_size / (1024 * 1024)
549
+ print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
550
+
551
+ avg_loss = total_loss / max(num_batches, 1)
552
+ epoch_time = (time.time() - start_time) / 60
553
+ print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Elapsed since start: {epoch_time:.1f}min")
554
+
555
+ ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
556
+ torch.save({
557
+ "step": global_step,
558
+ "epoch": epoch,
559
+ "model_state_dict": {k.removeprefix("_orig_mod."): v.cpu() for k, v in model.state_dict().items()},
560
+ "optimizer_state_dict": optimizer.state_dict(),
561
+ "config": config.__dict__,
562
+ }, ckpt_path)
563
+ size_mb = ckpt_path.stat().st_size / (1024 * 1024)
564
+ print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")
565
+
566
+ if avg_loss < best_loss:
567
+ best_loss = avg_loss
568
+ best_path = checkpoint_dir / "spiderportal-v5-best.pt"
569
+ torch.save({
570
+ "step": global_step,
571
+ "epoch": epoch,
572
+ "model_state_dict": {k.removeprefix("_orig_mod."): v.cpu() for k, v in model.state_dict().items()},
573
+ "optimizer_state_dict": optimizer.state_dict(),
574
+ "config": config.__dict__,
575
+ }, best_path)
576
+ size_mb = best_path.stat().st_size / (1024 * 1024)
577
+ print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")
578
+
579
+ total_time = (time.time() - start_time) / 3600
580
+ print("\nTraining complete!")
581
+ print(f"Best loss: {best_loss:.4f}")
582
+ print(f"Total time: {total_time:.2f} hours")
583
+ print(f"Total steps: {global_step}")
584
+ print(f"Checkpoints saved to: {checkpoint_dir}")
585
+
586
+ if __name__ == "__main__":
587
+ train_single_gpu()
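The training loop above ramps the learning rate linearly for the first `WARMUP_STEPS` optimizer steps and then leaves it at `BASE_LR`; there is no decay phase. The sketch below isolates that schedule as a pure function so it can be inspected without a GPU. The helper name `warmup_lr` is hypothetical and does not appear in the script; the defaults are copied from `BASE_LR` and `WARMUP_STEPS`.

```python
def warmup_lr(step: int, base_lr: float = 2e-5, warmup_steps: int = 1000) -> float:
    """Learning rate applied at a given global step.

    Mirrors the in-loop logic: linear ramp while step < warmup_steps,
    constant base_lr afterwards (no decay).
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


if __name__ == "__main__":
    # Print a few points along the schedule.
    for step in (0, 499, 999, 5000):
        print(f"step {step:>5}: lr = {warmup_lr(step):.3e}")
```

With roughly 3,800 steps per epoch at batch size 128 (491K samples), the warmup covers about a quarter of the first epoch; a decay phase would have to be added separately if one is wanted.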
SpiderPortal-Base-v5/config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "architectures": [
3
+ "SpiderPortalForConditionalGeneration"
4
+ ],
5
+ "model_type": "spiderportal",
6
+ "vocab_size": 50278,
7
+ "hidden_size": 384,
8
+ "num_hidden_layers": 8,
9
+ "num_attention_heads": 8,
10
+ "num_key_value_heads": 2,
11
+ "intermediate_size": 1024,
12
+ "hidden_act": "silu",
13
+ "max_position_embeddings": 131072,
14
+ "rope_theta": 10000000.0,
15
+ "rope_scaling": {
16
+ "type": "yarn",
17
+ "factor": 4.0,
18
+ "original_max_position_embeddings": 32768
19
+ },
20
+ "sliding_window": 4096,
21
+ "rms_norm_eps": 1e-06,
22
+ "initializer_range": 0.02,
23
+ "tie_word_embeddings": true,
24
+ "torch_dtype": "bfloat16",
25
+ "num_experts": 64,
26
+ "num_experts_per_tok": 1,
27
+ "num_shared_experts": 1,
28
+ "router_aux_loss_coef": 0.05,
29
+ "max_loop_iters": 4,
30
+ "act_threshold": 0.5,
31
+ "prelude_layers": 2,
32
+ "coda_layers": 2,
33
+ "lora_rank": 32,
34
+ "enable_talker": true,
35
+ "vision_encoder_type": "eupe_vit_s",
36
+ "vision_encoder_path": "facebook/EUPE-ViT-S",
37
+ "vision_hidden_size": 384,
38
+ "audio_encoder_type": "moonshine_base",
39
+ "audio_encoder_path": "UsefulSensors/moonshine-base",
40
+ "audio_hidden_size": 512,
41
+ "vision_num_frames": 60,
42
+ "vision_tokens_per_frame": 256,
43
+ "vision_temporal_tokens": 64,
44
+ "vision_temporal_layers": 2
45
+ }
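As a rough check on how the MoE settings in this config translate into stored versus active expert weights, the sketch below does the arithmetic for the routed and shared experts only. It assumes each expert is the three bias-free projections defined in `SpiderPortalExpert` (`gate_proj`, `up_proj`, `down_proj`) and deliberately ignores attention, the router, norms, embeddings, and the dense prelude/coda layers, so the totals are lower bounds, not the full model size. The helper `expert_params` is hypothetical.

```python
# Values copied from the config.json diff above.
config = {
    "hidden_size": 384,
    "intermediate_size": 1024,
    "num_hidden_layers": 8,
    "num_experts": 64,
    "num_experts_per_tok": 1,
    "num_shared_experts": 1,
}


def expert_params(hidden_size: int, intermediate_size: int) -> int:
    """Weights in one expert: gate_proj + up_proj + down_proj, no biases."""
    return 3 * hidden_size * intermediate_size


per_expert = expert_params(config["hidden_size"], config["intermediate_size"])

# Experts stored in every MoE layer (routed + shared).
stored_per_layer = per_expert * (config["num_experts"] + config["num_shared_experts"])

# Experts that actually run for a single token in one layer.
active_per_layer = per_expert * (config["num_experts_per_tok"] + config["num_shared_experts"])

stored_total = stored_per_layer * config["num_hidden_layers"]
active_total = active_per_layer * config["num_hidden_layers"]

if __name__ == "__main__":
    print(f"per expert:             {per_expert:,}")
    print(f"stored expert weights:  {stored_total:,}")
    print(f"active expert weights:  {active_total:,} per token, per loop iteration")
```

Because the recurrent stack is reused on every loop iteration, the active figure is per pass; running `n_loops` iterations multiplies the expert compute but not the stored weights.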
SpiderPortal-Base-v5/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
SpiderPortal-Base-v5/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:256d05574c2c2ffe38f5c3b5eb7badfc614e35ac7898452a4e43e922cc41d19c
3
+ size 1288358204
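The LFS pointer above records only the object hash and byte size, but the size gives a quick estimate of the parameter count. The sketch below parses the pointer text and divides by two bytes per weight, assuming every tensor is stored as `bfloat16` (the `torch_dtype` in `config.json`) and ignoring the small safetensors header. Both are assumptions, so treat the result as approximate. The helper `parse_lfs_pointer` is hypothetical.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split each 'key value' line of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


# Pointer contents copied from the diff above.
POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:256d05574c2c2ffe38f5c3b5eb7badfc614e35ac7898452a4e43e922cc41d19c
size 1288358204
"""

fields = parse_lfs_pointer(POINTER)
size_bytes = int(fields["size"])

# Assumption: 2 bytes per parameter (bfloat16), header overhead ignored.
approx_params = size_bytes // 2

if __name__ == "__main__":
    print(f"{size_bytes / 1e9:.2f} GB on disk, roughly {approx_params / 1e6:.0f}M parameters")
```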
SpiderPortal-Base-v5/special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "pad_token": "<|endoftext|>",
3
+ "eos_token": "<|im_end|>"
4
+ }
SpiderPortal-Base-v5/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
SpiderPortal-Base-v5/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"model_max_length": 1024}
SpiderPortal-Base-v5/vocab.json ADDED
The diff for this file is too large to render. See raw diff