CLIWorks committed on
Commit 546fd8f · verified · 1 Parent(s): edb6a10

Upload transfer_weights.py with huggingface_hub

Files changed (1)
  1. transfer_weights.py +1883 -0
transfer_weights.py ADDED
@@ -0,0 +1,1883 @@
+ #!/usr/bin/env python3
+ """Weight transfer from Qwen3.5-2B donor to Spider-FLEXITOKENS architecture.
+
+ Implements the weight transfer pipeline per D-09 and D-10:
+ - Loads Qwen3.5-2B via HF transformers
+ - Filters to full_attention layers only (discards linear_attention)
+ - SVD decomposition converts standard GQA attention to MLA format
+ - Direct copies where shapes match (o_proj, layer norms)
+ - Reinitializes incompatible weights (embeddings, boundary predictor, FFN)
+ - Reports transfer coverage as percentage
+
+ Usage:
+     python scripts/transfer_weights.py --donor Qwen/Qwen3.5-2B --output models/Spider-FLEXITOKENS-init/ --config spider_flexitokens_997m
+ """
+
+ import argparse
+ import hashlib
+ import json
+ import math
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Import canonical Spider architecture components from spider.py
+ # (replaces previously duplicated code, per VERIFICATION gap #2, #4, #5)
+ from spider import (
+     SENTINEL_TOKENS,
+     is_sentinel_token,
+     create_modality_mask,
+     BoundaryPredictor,
+     downsample,
+     upsample,
+     SpiderConfig as _CanonicalSpiderConfig,
+     spider_flexitokens_997m as _canonical_config_fn,
+ )
+
+ # Reverse mapping for sentinel token IDs to names (IN-01 fix: computed once)
+ _TOKEN_NAMES_BY_ID = {v: k for k, v in SENTINEL_TOKENS.items()}
+
+
+ # ============================================================================
+ # Sentinel Token Vocabulary - imported from spider.py (D-06, D-11)
+ # ============================================================================
+ # SENTINEL_TOKENS, is_sentinel_token, create_modality_mask are now imported
+ # from spider.py. _SENTINEL_PAIRS and _MODALITY_SENTINEL_IDS are used
+ # locally for transfer logic.
+ _SENTINEL_PAIRS = [
+     (SENTINEL_TOKENS['IMG_START'], SENTINEL_TOKENS['IMG_END']),  # (259, 260)
+     (SENTINEL_TOKENS['AUD_START'], SENTINEL_TOKENS['AUD_END']),  # (261, 262)
+     (SENTINEL_TOKENS['VID_START'], SENTINEL_TOKENS['VID_END']),  # (263, 264)
+ ]
+ _MODALITY_SENTINEL_IDS = {259, 260, 261, 262, 263, 264}
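The sentinel bookkeeping above can be sketched in isolation. The block below is only an illustration: `sentinel_tokens` is a stand-in dictionary whose six IDs are taken from the inline comments (259 to 264), the real `SENTINEL_TOKENS` in spider.py may hold more entries, and `modality_of` is a hypothetical helper that does not appear in the file.

```python
# Stand-in for spider.SENTINEL_TOKENS (assumption: only the six modality
# sentinels named in the comments above are reproduced here).
sentinel_tokens = {
    "IMG_START": 259, "IMG_END": 260,
    "AUD_START": 261, "AUD_END": 262,
    "VID_START": 263, "VID_END": 264,
}

# Reverse lookup (ID -> name), built once, as with _TOKEN_NAMES_BY_ID.
token_names_by_id = {v: k for k, v in sentinel_tokens.items()}

# (start, end) pairs per modality, derived from the names instead of literals.
sentinel_pairs = [
    (sentinel_tokens[f"{m}_START"], sentinel_tokens[f"{m}_END"])
    for m in ("IMG", "AUD", "VID")
]

# Flat set of all modality sentinel IDs, derived from the pairs.
modality_sentinel_ids = {tid for pair in sentinel_pairs for tid in pair}


def modality_of(token_id):
    """Hypothetical helper: 'IMG', 'AUD' or 'VID' for a sentinel ID, else None."""
    name = token_names_by_id.get(token_id)
    return name.split("_")[0] if name else None
```

Deriving the ID set from the pairs, rather than repeating the literals, would keep the two structures in step if the IDs ever changed; whether that is desirable here depends on how spider.py defines the vocabulary.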
+
+
+ # ============================================================================
+ # BoundaryPredictor - imported from spider.py (D-04, D-11)
+ # ============================================================================
+ # BoundaryPredictor is now imported from spider.py.
+
+
+ # ============================================================================
+ # Downsample / Upsample - imported from spider.py (D-05, D-08, D-11)
+ # ============================================================================
+ # downsample and upsample are now imported from spider.py; the helpers
+ # _downsample_common and _downsample_final are not imported in this file.
+
+
+ # ============================================================================
+ # Spider Configuration
+ # ============================================================================
+
+ @dataclass
+ class SpiderConfig:
+     """Spider-FLEXITOKENS model configuration (hidden_size=2048).
+
+     Based on mythos-fineweb-moe.py SpiderPortalConfig with byte-level
+     tokenization and MLA attention. Mirrors canonical spider.py config.
+     """
+     # Core architecture
+     vocab_size: int = 272  # 256 bytes + 16 specials (D-06)
+     hidden_size: int = 2048
+     num_hidden_layers: int = 6  # recurrent layers
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 4  # not used directly in MLA but kept for compat
+     intermediate_size: int = 1024
+     hidden_act: str = "silu"
+
+     # MoE configuration (D-20, D-21: shared-projection MoE)
+     num_experts: int = 32
+     num_experts_per_tok: int = 2
+     num_shared_experts: int = 1
+     router_aux_loss_coef: float = 0.05
+     shared_intermediate_size: int = 6144
+     expert_core_rank: int = 256
+     shared_expert_intermediate_size: int = 7424
+     prelude_coda_intermediate_size: int = 4096
+
+     # RDT configuration
+     max_loop_iters: int = 16
+     act_threshold: float = 0.5
+     prelude_layers: int = 2
+     coda_layers: int = 2
+     lora_rank: int = 128
+     loop_embed_dim: int = 128
+
+     # MLA parameters (DeepSeek-V2 style)
+     kv_lora_rank: int = 128
+     q_lora_rank: int = 256
+     qk_rope_head_dim: int = 64
+     qk_nope_head_dim: int = 64
+     v_head_dim: int = 64
+
+     # Attention / RoPE
+     max_position_embeddings: int = 262144  # 256k context
+     rope_theta: float = 10000000.0
+     rope_scaling: Optional[Dict] = field(default_factory=lambda: {
+         "type": "yarn",
+         "factor": 8.0,
+         "original_max_position_embeddings": 32768,
+     })
+     sliding_window: int = 8192  # local attention window
+     attention_dropout: float = 0.0
+     rms_norm_eps: float = 1e-6
+     initializer_range: float = 0.02
+
+     # Embeddings / head
+     tie_word_embeddings: bool = True  # Tied per D-06 (byte-level vocab)
+
+     # Metadata
+     model_type: str = "spider"
+     torch_dtype: str = "bfloat16"
+
+     # BoundaryPredictor
+     bp_d_inner: int = 8192
+
+     # Engram (N-gram memory, D-20 revision)
+     engram_layers: Optional[List[int]] = None  # set in __post_init__
+     engram_table_size: int = 8191
+     engram_heads: int = 4
+     engram_dim: int = 128
+     engram_offload: bool = True
+
+     # Multimodal
+     vision_hidden_size: int = 2048
+     audio_hidden_size: int = 512
+     vision_num_frames: int = 60
+     vision_tokens_per_frame: int = 256
+     vision_temporal_tokens: int = 64
+     vision_temporal_layers: int = 2
+
+     @property
+     def head_dim(self):
+         return self.qk_nope_head_dim + self.qk_rope_head_dim  # 128
+
+     def __post_init__(self):
+         if self.engram_layers is None:
+             self.engram_layers = [1, 4]
+
+
+ def spider_flexitokens_997m() -> SpiderConfig:
+     """Spider-FLEXITOKENS 997M config."""
+     return SpiderConfig()
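Several quantities follow arithmetically from the defaults above. As a small sketch, the block below copies only the MLA-related fields into a hypothetical `MlaDims` dataclass (not part of the file) and derives the projection widths and the YaRN-extended context length from them.

```python
from dataclasses import dataclass


@dataclass
class MlaDims:
    """Hypothetical subset of the config fields above, with the same defaults."""
    hidden_size: int = 2048
    num_attention_heads: int = 16
    kv_lora_rank: int = 128
    q_lora_rank: int = 256
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 64
    v_head_dim: int = 64

    @property
    def head_dim(self) -> int:
        # Query/key head width = non-rotary part + rotary part.
        return self.qk_nope_head_dim + self.qk_rope_head_dim


dims = MlaDims()

# Output width of the query up-projection: one head_dim slice per head.
q_b_out = dims.num_attention_heads * dims.head_dim
# KV down-projection emits the compressed latent plus one shared rotary key.
kv_a_out = dims.kv_lora_rank + dims.qk_rope_head_dim
# KV up-projection emits non-rotary keys and values for every head.
kv_b_out = dims.num_attention_heads * (dims.qk_nope_head_dim + dims.v_head_dim)
# Input width of the output projection: concatenated value heads.
o_in = dims.num_attention_heads * dims.v_head_dim
# YaRN: original context length times the scaling factor.
extended_context = int(32768 * 8.0)

print(q_b_out, kv_a_out, kv_b_out, o_in, extended_context)
```

With these defaults the product of the YaRN factor and the original context length equals the configured `max_position_embeddings` of 262144.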
+
+
+ # ============================================================================
+ # Dummy Donor (for testing without downloading 6GB model)
+ # ============================================================================
+
+ def create_dummy_donor(num_layers: int = 4, full_attention_layers: Optional[List[int]] = None, mini: bool = False):
+     """Create a dummy Qwen3.5-2B-like donor state dict and config.
+
+     Mimics the structure of Qwen3.5-2B with:
+     - hidden_size=2048, num_heads=8, num_kv_heads=2, head_dim=256
+     - full_attention and linear_attention layer identification
+     - intermediate_size=6144
+     - vocab_size=248320
+
+     Args:
+         num_layers: Number of layers to create
+         full_attention_layers: Indices of full_attention layers (default: all)
+         mini: If True, use smaller tensors for fast testing
+
+     Returns:
+         Dict with "state_dict", "config" keys
+     """
+     hidden_size = 2048
+     num_heads = 8
+     num_kv_heads = 2
+     head_dim = 256  # Qwen3.5-2B: 2048 / 8 = 256
+     intermediate_size = 6144
+     vocab_size = 248320
+
+     if full_attention_layers is None:
+         # Default: all layers are full_attention for testing
+         full_attention_layers = list(range(num_layers))
+
+     # Scale factor for mini mode (reduces tensor sizes for fast testing)
+     scale = 8 if mini else 1
+     hs = hidden_size // scale
+     n_h = max(num_heads // scale, 1)
+     n_kv_h = max(num_kv_heads // scale, 1)
+     hd = head_dim  # Keep head_dim the same for shape correctness
+     inter = intermediate_size // scale
+     vs = min(vocab_size, 1024) if mini else vocab_size
+
+     state_dict = {}
+
+     # Embeddings
+     state_dict["model.embed_tokens.weight"] = torch.randn(vs, hs) * 0.02
+
+     # Per-layer weights
+     for i in range(num_layers):
+         prefix = f"model.layers.{i}"
+         # Attention projections (Qwen3.5-2B layout)
+         state_dict[f"{prefix}.self_attn.q_proj.weight"] = torch.randn(n_h * hd, hs) * 0.02
+         state_dict[f"{prefix}.self_attn.k_proj.weight"] = torch.randn(n_kv_h * hd, hs) * 0.02
+         state_dict[f"{prefix}.self_attn.v_proj.weight"] = torch.randn(n_kv_h * hd, hs) * 0.02
+         state_dict[f"{prefix}.self_attn.o_proj.weight"] = torch.randn(hs, hs) * 0.02
+         # Layer norms
+         state_dict[f"{prefix}.input_layernorm.weight"] = torch.ones(hs, dtype=torch.float32)
+         state_dict[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hs, dtype=torch.float32)
+         # FFN (SwiGLU: gate + up + down)
+         state_dict[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(inter, hs) * 0.02
+         state_dict[f"{prefix}.mlp.up_proj.weight"] = torch.randn(inter, hs) * 0.02
+         state_dict[f"{prefix}.mlp.down_proj.weight"] = torch.randn(hs, inter) * 0.02
+
+     # Final norm
+     state_dict["model.norm.weight"] = torch.ones(hs, dtype=torch.float32)
+     # LM head
+     state_dict["lm_head.weight"] = torch.randn(vs, hs) * 0.02
+
+     config = {
+         "hidden_size": hs,
+         "num_attention_heads": n_h,
+         "num_key_value_heads": n_kv_h,
+         "head_dim": hd,
+         "intermediate_size": inter,
+         "vocab_size": vs,
+         "num_hidden_layers": num_layers,
+         "full_attention_layers": full_attention_layers,
+         "model_type": "qwen3",
+         "mini": mini,
+     }
+
+     return {"state_dict": state_dict, "config": config}
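The effect of `mini=True` on the donor dimensions can be checked without allocating any tensors. The sketch below repeats only the integer arithmetic of `create_dummy_donor`; `dummy_donor_dims` is a hypothetical helper introduced for illustration.

```python
def dummy_donor_dims(mini: bool = False) -> dict:
    """Hypothetical helper: the dimension arithmetic of create_dummy_donor only."""
    hidden_size, num_heads, num_kv_heads = 2048, 8, 2
    head_dim, intermediate_size, vocab_size = 256, 6144, 248320
    scale = 8 if mini else 1
    return {
        "hidden_size": hidden_size // scale,
        # Head counts never drop below 1, even when scale exceeds them.
        "num_attention_heads": max(num_heads // scale, 1),
        "num_key_value_heads": max(num_kv_heads // scale, 1),
        # head_dim is deliberately left unscaled.
        "head_dim": head_dim,
        "intermediate_size": intermediate_size // scale,
        "vocab_size": min(vocab_size, 1024) if mini else vocab_size,
    }


def proj_shapes(d: dict) -> dict:
    """Shapes of q_proj and k_proj as [out_features, in_features]."""
    return {
        "q_proj": (d["num_attention_heads"] * d["head_dim"], d["hidden_size"]),
        "k_proj": (d["num_key_value_heads"] * d["head_dim"], d["hidden_size"]),
    }


full = dummy_donor_dims(mini=False)
mini = dummy_donor_dims(mini=True)
print(proj_shapes(full))
print(proj_shapes(mini))
```

Because `head_dim` stays at 256 while the head counts clamp to 1, the mini donor's `q_proj` and `k_proj` both come out square, which differs from the 4:1 query-to-key ratio of the full-size dummy.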
+
+
+ # ============================================================================
+ # SVD Decomposition for MLA Conversion
+ # ============================================================================
+
+ def decompose_attention_svd(
+     weight: torch.Tensor,
+     lora_rank: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """SVD decompose a weight matrix into low-rank a_proj and b_proj.
+
+     Per D-10: Decompression (b_proj) matrices initialized from SVD;
+     compression (a_proj) matrices are reinitialized with Kaiming init.
+
+     Args:
+         weight: 2D weight matrix. The return shapes below assume it is given
+             as [in_features, out_features]; note that for
+             Linear(in, out, bias=False), PyTorch stores weight as
+             [out_features, in_features].
+         lora_rank: Target rank for the low-rank decomposition.
+
+     Returns:
+         Tuple of (a_proj, b_proj) where:
+         - a_proj: [in_features, lora_rank], compression (REINITIALIZED by caller)
+         - b_proj: [lora_rank, out_features], decompression (from SVD)
+     """
+     # Ensure weight is 2D (IN-03 fix: proper ValueError instead of assert)
+     if weight.dim() != 2:
+         raise ValueError(f"Expected 2D weight, got {weight.dim()}D")
+
+     # Orientation: the SVD below is applied to `weight` exactly as passed
+     # (no transpose is taken here). For an input of shape [m, n]:
+     #   weight = U @ diag(S) @ Vh
+     #   a = U[:, :rank] @ diag(S[:rank])   shape [m, rank]
+     #   b = Vh[:rank, :]                   shape [rank, n]
+     # PyTorch Linear stores its weight as [out_features, in_features], so a
+     # caller that wants a: [in_features, rank] and b: [rank, out_features]
+     # must pass the transposed weight (W.T).
+
+     # Work in float32 for SVD stability
+     weight_f32 = weight.float()
+
+     # SVD decomposition
+     U, S, Vh = torch.linalg.svd(weight_f32, full_matrices=False)
+
+     # Truncate to target rank
+     a_proj = U[:, :lora_rank] @ torch.diag(S[:lora_rank])  # [m, rank]
+     b_proj = Vh[:lora_rank, :]  # [rank, n]
+
+     return a_proj, b_proj
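To make the factor shapes concrete without depending on torch, the sketch below handles only the rank-1 case in plain Python. It uses power iteration as a stand-in for `torch.linalg.svd`, so it is not the method used in the file; `rank1_factor` and `matmul` are hypothetical helpers. For an `m x n` input it returns `a` of shape `m x 1` (the leading left singular vector scaled by the singular value) and `b` of shape `1 x n` (the leading right singular vector).

```python
import math


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def rank1_factor(w, iters=50):
    """Hypothetical helper: best rank-1 factors (a, b) of w via power iteration."""
    m, n = len(w), len(w[0])
    v = [1.0] * n  # starting guess for the leading right singular vector
    for _ in range(iters):
        # One power-iteration step on W^T W: v <- normalize(W^T (W v)).
        wv = [sum(w[i][j] * v[j] for j in range(n)) for i in range(m)]
        wtwv = [sum(w[i][j] * wv[i] for i in range(m)) for j in range(n)]
        norm = math.sqrt(sum(x * x for x in wtwv))
        v = [x / norm for x in wtwv]
    wv = [sum(w[i][j] * v[j] for j in range(n)) for i in range(m)]
    sigma = math.sqrt(sum(x * x for x in wv))  # leading singular value
    u = [x / sigma for x in wv]                # leading left singular vector
    a = [[u[i] * sigma] for i in range(m)]     # plays the role of U[:, :1] @ diag(S[:1])
    b = [v]                                    # plays the role of Vh[:1, :]
    return a, b


w = [[2.0, 4.0, 6.0],
     [1.0, 2.0, 3.0]]  # exactly rank 1, so the rank-1 factors reproduce it
a, b = rank1_factor(w)
approx = matmul(a, b)
max_err = max(abs(approx[i][j] - w[i][j]) for i in range(2) for j in range(3))
```

Note that the factor shapes follow the orientation of the input: a `2 x 3` matrix yields `a` with 2 rows and `b` with 3 columns, regardless of which axis the caller thinks of as the input features.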
+
+
+ # ============================================================================
+ # MoE Expert Splitting
+ # ============================================================================
+
+ def split_dense_to_moe(
+     spider_state_dict: Dict[str, torch.Tensor],
+     config: SpiderConfig,
+     noise_scale: float = 0.02,
+ ) -> Dict[str, torch.Tensor]:
+     """Initialize SharedProjectionMoE expert cores and router per D-20/D-21.
+
+     Per D-21: W_gate and W_transform are randomly initialized with small
+     normal noise (std=0.02) to break symmetry. shared_up, shared_down,
+     and shared_expert are already populated by transfer_qwen_to_spider.
+
+     Args:
+         spider_state_dict: Spider model state dict (mutated in-place)
+         config: Spider model config
+         noise_scale: Noise std for expert core initialization
+
+     Returns:
+         Updated state dict with SharedProjectionMoE weights
+     """
+     for layer_idx in range(config.num_hidden_layers):
+         rec_prefix = f"model.recurrent_layers.{layer_idx}.moe"
+
+         # W_gate: [num_experts, hidden_size, expert_core_rank]
+         w_gate_key = f"{rec_prefix}.W_gate"
+         if w_gate_key not in spider_state_dict:
+             spider_state_dict[w_gate_key] = (
+                 torch.randn(config.num_experts, config.hidden_size, config.expert_core_rank)
+                 * noise_scale
+             )
+
+         # W_transform: [num_experts, expert_core_rank, shared_intermediate_size]
+         w_transform_key = f"{rec_prefix}.W_transform"
+         if w_transform_key not in spider_state_dict:
+             spider_state_dict[w_transform_key] = (
+                 torch.randn(config.num_experts, config.expert_core_rank, config.shared_intermediate_size)
+                 * noise_scale
+             )
+
+         # Router weight: [num_experts, hidden_size]
+         router_key = f"{rec_prefix}.router.weight"
+         if router_key not in spider_state_dict:
+             spider_state_dict[router_key] = (
+                 torch.randn(config.num_experts, config.hidden_size)
+                 * config.initializer_range
+             )
+
+         # Router bias: [num_experts]
+         router_bias_key = f"{rec_prefix}.router.bias"
+         if router_bias_key not in spider_state_dict:
+             spider_state_dict[router_bias_key] = torch.zeros(config.num_experts, dtype=torch.float32)
+
+     return spider_state_dict
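For a sense of scale, the number of parameters that `split_dense_to_moe` would create when none of the keys exist yet can be worked out from the shapes alone. The sketch below assumes the default config values shown earlier (32 experts, hidden size 2048, core rank 256, shared intermediate size 6144, 6 recurrent layers); the variable names are illustrative.

```python
from math import prod

# Default values copied from the SpiderConfig fields shown earlier.
num_experts, hidden_size = 32, 2048
expert_core_rank, shared_intermediate_size = 256, 6144
num_recurrent_layers = 6

# Shapes of the four tensors the function adds per recurrent layer.
per_layer_shapes = {
    "W_gate": (num_experts, hidden_size, expert_core_rank),
    "W_transform": (num_experts, expert_core_rank, shared_intermediate_size),
    "router.weight": (num_experts, hidden_size),
    "router.bias": (num_experts,),
}

# Element count per tensor, per layer, and across all recurrent layers.
per_layer_counts = {name: prod(shape) for name, shape in per_layer_shapes.items()}
per_layer_total = sum(per_layer_counts.values())
all_layers_total = per_layer_total * num_recurrent_layers

for name, count in per_layer_counts.items():
    print(f"{name:>14}: {count:,}")
print(f"per layer: {per_layer_total:,}  all layers: {all_layers_total:,}")
```

Under these assumptions the randomly initialized expert cores and routers come to roughly 403 million parameters, all of which count as reinitialized rather than transferred.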
+
+
+ # ============================================================================
+ # Get Spider Parameter Shapes
+ # ============================================================================
+
+ def get_spider_param_shapes(config: SpiderConfig) -> Dict[str, Tuple[int, ...]]:
+     """Return expected parameter shapes for the Spider model.
+
+     Used for validation that all shapes match after weight transfer.
+     """
+     shapes = {}
+
+     # Embeddings
+     shapes["embed_tokens.weight"] = (config.vocab_size, config.hidden_size)
+     shapes["lm_head.weight"] = (config.vocab_size, config.hidden_size)
+
+     # BoundaryPredictor: nn.Sequential(Linear(2048, 8192), GELU(), Linear(8192, 1))
+     shapes["boundary_predictor.0.weight"] = (config.bp_d_inner, config.hidden_size)
+     shapes["boundary_predictor.0.bias"] = (config.bp_d_inner,)
+     shapes["boundary_predictor.2.weight"] = (1, config.bp_d_inner)
+     shapes["boundary_predictor.2.bias"] = (1,)
+
+     # null_group for downsample
+     shapes["null_group.weight"] = (config.hidden_size,)
+
+     # down_ln for downsample
+     shapes["down_ln.weight"] = (config.hidden_size,)
+     shapes["down_ln.bias"] = (config.hidden_size,)
+
+     head_dim = config.head_dim  # 128
+
+     for section, num_layers in [
+         ("prelude_layers", config.prelude_layers),
+         ("coda_layers", config.coda_layers),
+     ]:
+         for i in range(num_layers):
+             prefix = f"model.{section}.{i}"
+
+             # MLA attention projections
+             shapes[f"{prefix}.self_attn.q_a_proj.weight"] = (config.q_lora_rank, config.hidden_size)
+             shapes[f"{prefix}.self_attn.q_a_layernorm.weight"] = (config.q_lora_rank,)
+             shapes[f"{prefix}.self_attn.q_b_proj.weight"] = (config.num_attention_heads * head_dim, config.q_lora_rank)
+             shapes[f"{prefix}.self_attn.kv_a_proj_with_mqa.weight"] = (config.kv_lora_rank + config.qk_rope_head_dim, config.hidden_size)
+             shapes[f"{prefix}.self_attn.kv_a_layernorm.weight"] = (config.kv_lora_rank,)
+             shapes[f"{prefix}.self_attn.kv_b_proj.weight"] = (config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim), config.kv_lora_rank)
+             shapes[f"{prefix}.self_attn.o_proj.weight"] = (config.hidden_size, config.num_attention_heads * config.v_head_dim)
+
+             # Layer norms
+             shapes[f"{prefix}.input_layernorm.weight"] = (config.hidden_size,)
+             shapes[f"{prefix}.post_attention_layernorm.weight"] = (config.hidden_size,)
+
+             # FFN (dense for prelude/coda, uses SpiderExpert SwiGLU with prelude_coda_intermediate_size)
+             dense_inter = config.prelude_coda_intermediate_size
+             shapes[f"{prefix}.ffn.gate_proj.weight"] = (dense_inter, config.hidden_size)
+             shapes[f"{prefix}.ffn.up_proj.weight"] = (dense_inter, config.hidden_size)
+             shapes[f"{prefix}.ffn.down_proj.weight"] = (config.hidden_size, dense_inter)
+
+     # Recurrent (MoE) layers
+     for i in range(config.num_hidden_layers):
+         prefix = f"model.recurrent_layers.{i}"
+
+         # MLA attention
+         shapes[f"{prefix}.self_attn.q_a_proj.weight"] = (config.q_lora_rank, config.hidden_size)
+         shapes[f"{prefix}.self_attn.q_a_layernorm.weight"] = (config.q_lora_rank,)
+         shapes[f"{prefix}.self_attn.q_b_proj.weight"] = (config.num_attention_heads * head_dim, config.q_lora_rank)
+         shapes[f"{prefix}.self_attn.kv_a_proj_with_mqa.weight"] = (config.kv_lora_rank + config.qk_rope_head_dim, config.hidden_size)
+         shapes[f"{prefix}.self_attn.kv_a_layernorm.weight"] = (config.kv_lora_rank,)
+         shapes[f"{prefix}.self_attn.kv_b_proj.weight"] = (config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim), config.kv_lora_rank)
+         shapes[f"{prefix}.self_attn.o_proj.weight"] = (config.hidden_size, config.num_attention_heads * config.v_head_dim)
+
+         # Layer norms
+         shapes[f"{prefix}.input_layernorm.weight"] = (config.hidden_size,)
+         shapes[f"{prefix}.post_attention_layernorm.weight"] = (config.hidden_size,)
+
+         # MoE: SharedProjectionMoE (D-20, D-21)
+         # shared_up: Linear(hidden, shared_inter=6144)
+         shapes[f"{prefix}.moe.shared_up.weight"] = (config.shared_intermediate_size, config.hidden_size)
+         # shared_down: Linear(shared_inter=6144, hidden)
+         shapes[f"{prefix}.moe.shared_down.weight"] = (config.hidden_size, config.shared_intermediate_size)
+         # W_gate: Parameter [num_experts, hidden, expert_core_rank]
+         shapes[f"{prefix}.moe.W_gate"] = (config.num_experts, config.hidden_size, config.expert_core_rank)
+         # W_transform: Parameter [num_experts, expert_core_rank, shared_inter]
+         shapes[f"{prefix}.moe.W_transform"] = (config.num_experts, config.expert_core_rank, config.shared_intermediate_size)
+         # shared_expert: SpiderExpert with inter=shared_expert_intermediate_size
+         shapes[f"{prefix}.moe.shared_expert.gate_proj.weight"] = (config.shared_expert_intermediate_size, config.hidden_size)
+         shapes[f"{prefix}.moe.shared_expert.up_proj.weight"] = (config.shared_expert_intermediate_size, config.hidden_size)
+         shapes[f"{prefix}.moe.shared_expert.down_proj.weight"] = (config.hidden_size, config.shared_expert_intermediate_size)
+         # Router
+         shapes[f"{prefix}.moe.router.weight"] = (config.num_experts, config.hidden_size)
+         shapes[f"{prefix}.moe.router.bias"] = (config.num_experts,)
+
+         # LoRA adapter
+         shapes[f"{prefix}.lora_adapter.down.weight"] = (config.lora_rank, config.hidden_size)
+         shapes[f"{prefix}.lora_adapter.B"] = (config.lora_rank, config.hidden_size)
+         shapes[f"{prefix}.lora_adapter.scale.weight"] = (config.max_loop_iters, config.lora_rank)
+
+         # ACT halting
+         shapes[f"{prefix}.act_halting.halt_predictor.weight"] = (1, config.hidden_size)
+         shapes[f"{prefix}.act_halting.halt_predictor.bias"] = (1,)
+
+         # Engram (layers 1 and 4 only)
+         if i in config.engram_layers:
+             engram_mem_dim = config.engram_heads * config.engram_dim
+             shapes[f"{prefix}.engram.W_k.weight"] = (config.hidden_size, engram_mem_dim * 2)
+             shapes[f"{prefix}.engram.W_v.weight"] = (config.hidden_size, engram_mem_dim * 2)
+             shapes[f"{prefix}.engram.conv.weight"] = (config.hidden_size, 1, 4)
+             shapes[f"{prefix}.engram.conv.bias"] = (config.hidden_size,)
+             shapes[f"{prefix}.engram.q_norm.weight"] = (config.hidden_size,)
+             shapes[f"{prefix}.engram.k_norm.weight"] = (config.hidden_size,)
+             shapes[f"{prefix}.engram.embed"] = (2, config.engram_heads, config.engram_table_size, config.engram_dim)
+             shapes[f"{prefix}.engram.hash_seeds"] = (config.engram_heads * 2,)
+             shapes[f"{prefix}.post_engram_layernorm.weight"] = (config.hidden_size,)
+
+     # LTI injection
+     shapes["model.injection.log_A"] = (config.hidden_size,)
+     shapes["model.injection.delta_t"] = ()
+     shapes["model.injection.B.weight"] = (config.hidden_size, config.hidden_size)
+
+     # Final norm
+     shapes["model.norm.weight"] = (config.hidden_size,)
+
+     # Loop embedding dimension (config attribute, not a parameter)
+     # shapes["model.loop_embed_dim"] = ()
+
+     # ACT halting for model level
+     shapes["model.act_halting.halt_predictor.weight"] = (1, config.hidden_size)
+     shapes["model.act_halting.halt_predictor.bias"] = (1,)
+
+     return shapes
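A shape table like the one returned above is typically compared against the transferred state dict. The sketch below shows one possible form of that check on plain `{name: shape}` mappings; `find_shape_mismatches` is a hypothetical helper, and the sample entries are illustrative rather than taken from a real run.

```python
def find_shape_mismatches(expected, actual):
    """Hypothetical helper: compare two {name: shape} mappings.

    Returns the keys missing from `actual`, the keys not present in
    `expected`, and the keys whose shapes disagree.
    """
    missing = sorted(k for k in expected if k not in actual)
    unexpected = sorted(k for k in actual if k not in expected)
    wrong_shape = {
        k: (tuple(expected[k]), tuple(actual[k]))
        for k in expected
        if k in actual and tuple(actual[k]) != tuple(expected[k])
    }
    return {"missing": missing, "unexpected": unexpected, "wrong_shape": wrong_shape}


# Illustrative inputs: one matching key, one wrong shape, one missing, one extra.
expected = {
    "embed_tokens.weight": (272, 2048),
    "model.norm.weight": (2048,),
    "null_group.weight": (2048,),
}
actual = {
    "embed_tokens.weight": (272, 2048),
    "model.norm.weight": (1024,),
    "extra.weight": (1,),
}
report = find_shape_mismatches(expected, actual)
print(report)
```

With real tensors, the `actual` mapping could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}` before calling the helper.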
+
+
+ # ============================================================================
+ # Weight Adaptation Helper
+ # ============================================================================
+
+ def _adapt_weight(weight, target_out, target_in):
+     """Adapt a donor weight matrix to Spider dimensions via padding/cropping.
+
+     When donor hidden_size differs from Spider's (e.g., in mini test mode),
+     we pad or crop the weight matrix to match target dimensions.
+
+     Args:
+         weight: [out_features, in_features] weight tensor from donor
+         target_out: Target output dimension
+         target_in: Target input dimension
+
+     Returns:
+         Adapted weight tensor of shape [target_out, target_in]
+     """
+     out_dim, in_dim = weight.shape
+
+     # Create target-sized tensor with Kaiming init
+     adapted = torch.empty(target_out, target_in)
+     nn.init.kaiming_uniform_(adapted, a=math.sqrt(5))
+
+     # Copy what fits from donor
+     copy_out = min(out_dim, target_out)
+     copy_in = min(in_dim, target_in)
+     adapted[:copy_out, :copy_in] = weight[:copy_out, :copy_in]
+
+     return adapted
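The pad-or-crop behaviour of `_adapt_weight` can be illustrated on nested lists. In the sketch below the hypothetical `adapt_matrix` fills the uncovered region with a constant instead of Kaiming-uniform values, purely so the result is deterministic; the copy logic for the overlapping top-left block is the same idea.

```python
def adapt_matrix(weight, target_out, target_in, fill=0.0):
    """Hypothetical helper: pad or crop a nested-list matrix to a target shape.

    The top-left block that fits in both shapes is copied from `weight`;
    everything else keeps the `fill` value (a stand-in for random init).
    """
    adapted = [[fill] * target_in for _ in range(target_out)]
    copy_out = min(len(weight), target_out)
    copy_in = min(len(weight[0]), target_in)
    for r in range(copy_out):
        # Overwrite the leading copy_in entries of this row with donor values.
        adapted[r][:copy_in] = weight[r][:copy_in]
    return adapted


donor = [[1, 2, 3],
         [4, 5, 6]]
padded = adapt_matrix(donor, target_out=3, target_in=4)   # grow both axes
cropped = adapt_matrix(donor, target_out=1, target_in=2)  # shrink both axes
print(padded)
print(cropped)
```

When a matrix is padded, the rows and columns outside the donor block carry no donor information, so a weight adapted this way is only partially transferred.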
+
+
+ # ============================================================================
+ # Main Transfer Function
+ # ============================================================================
+
+ def transfer_qwen_to_spider(
+     donor_state_dict: Dict[str, torch.Tensor],
+     donor_config: Dict,
+     spider_config: SpiderConfig,
+     noise_scale: float = 0.02,
+ ) -> Dict:
+     """Transfer weights from Qwen3.5-2B donor to Spider-FLEXITOKENS architecture.
+
+     Per D-09: Qwen3.5-2B is the weight donor. Per D-10: SVD decomposition
+     converts standard GQA attention to MLA format.
+
+     Transfer rules:
+     - o_proj [2048, 2048]: direct copy from donor
+     - q_proj → SVD → q_b_proj (q_a_proj reinitialized with Kaiming)
+     - k_proj + v_proj → SVD → kv_b_proj (kv_a_proj reinitialized with Kaiming)
+     - Layer norms [2048]: direct copy
+     - Embeddings: REINIT [272, 2048] (byte-level)
+     - BoundaryPredictor: REINIT (no pre-trained source)
+     - FFN: REINIT (intermediate_size mismatch 6144 vs 1024)
+     - LoRA, ACT, LTI: REINIT (Spider-specific modules)
+
+     Args:
+         donor_state_dict: Qwen3.5-2B state dict
+         donor_config: Donor model config dict
+         spider_config: Spider model config
+         noise_scale: Noise scale for MoE expert perturbation
+
+     Returns:
+         Dict with "spider_state_dict", "transfer_coverage", "layer_mapping"
+     """
+     hidden_size = spider_config.hidden_size
+     q_lora_rank = spider_config.q_lora_rank
+     kv_lora_rank = spider_config.kv_lora_rank
+     num_heads = spider_config.num_attention_heads
+     head_dim = spider_config.head_dim
+     qk_nope_head_dim = spider_config.qk_nope_head_dim
+     qk_rope_head_dim = spider_config.qk_rope_head_dim
+     v_head_dim = spider_config.v_head_dim
+
+     # Donor dimensions (may differ from Spider in mini/test mode)
+     donor_hidden_size = donor_config.get("hidden_size", hidden_size)
+     donor_num_heads = donor_config.get("num_attention_heads", 8)
+     donor_num_kv_heads = donor_config.get("num_key_value_heads", 2)
+     donor_head_dim = donor_config.get("head_dim", 256)
+     donor_intermediate_size = donor_config.get("intermediate_size", 6144)
+
+     # Track parameter counts for coverage report
+     donor_param_count = 0
+     reinit_param_count = 0
+     donor_params = set()  # keys that came from donor
+     reinit_params = set()  # keys that were reinitialized
+
+     spider_sd = {}
+
+     # Determine layer mapping from donor to Spider
+     full_attention_layers = donor_config.get("full_attention_layers", [])
+     num_donor_layers = donor_config.get("num_hidden_layers", 24)
+
+     # Map donor layers to Spider sections:
+     # prelude: 2 layers, recurrent: 6 layers, coda: 2 layers = 10 total
+     # Use full_attention layers preferentially
+     available_fa = list(full_attention_layers)
+
+     # Build layer mapping: spider_layer_idx → donor_layer_idx
+     layer_mapping = {}
+     required_layers = (
+         spider_config.prelude_layers
+         + spider_config.num_hidden_layers
+         + spider_config.coda_layers
+     )
+
+     # Fill from full_attention layers first, then fallback to any layer
+     donor_pool = list(available_fa)
+     if len(donor_pool) < required_layers:
+         # Add remaining layers (including linear_attention) for norms
+         all_layers = list(range(num_donor_layers))
+         for donor_idx in all_layers:
+             if donor_idx not in donor_pool:
+                 donor_pool.append(donor_idx)
+
+     for i in range(required_layers):
+         if i < len(donor_pool):
+             layer_mapping[i] = donor_pool[i]
+         else:
+             layer_mapping[i] = None  # No donor layer available
+
+     def _kaiming_init(shape):
+         """Kaiming uniform initialization for new parameters."""
+         tensor = torch.empty(shape)
+         nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))
+         return tensor
+
+     def _zeros_init(shape):
+         """Zero initialization."""
+         return torch.zeros(shape, dtype=torch.float32)  # IN-02: explicit dtype
+
+     def _ones_init(shape):
+         """Ones initialization for layer norm weights."""
+         return torch.ones(shape, dtype=torch.float32)
+
+     # ---- 1. Embeddings: REINIT for byte-level vocab ----
+     embed_weight = _kaiming_init((spider_config.vocab_size, hidden_size))
+     spider_sd["embed_tokens.weight"] = embed_weight
+     reinit_param_count += embed_weight.numel()
+     reinit_params.add("embed_tokens.weight")
+
+     lm_head_weight = _kaiming_init((spider_config.vocab_size, hidden_size))
+     spider_sd["lm_head.weight"] = lm_head_weight
+     reinit_param_count += lm_head_weight.numel()
+     reinit_params.add("lm_head.weight")
+
+     # ---- 2. BoundaryPredictor: REINIT (no pre-trained source) ----
+     bp_0_weight = _kaiming_init((spider_config.bp_d_inner, hidden_size))
+     bp_0_bias = _zeros_init((spider_config.bp_d_inner,))
+     bp_2_weight = _kaiming_init((1, spider_config.bp_d_inner))
+     bp_2_bias = _zeros_init((1,))
+     spider_sd["boundary_predictor.0.weight"] = bp_0_weight
+     spider_sd["boundary_predictor.0.bias"] = bp_0_bias
+     spider_sd["boundary_predictor.2.weight"] = bp_2_weight
+     spider_sd["boundary_predictor.2.bias"] = bp_2_bias
+     reinit_param_count += bp_0_weight.numel() + bp_0_bias.numel()
+     reinit_param_count += bp_2_weight.numel() + bp_2_bias.numel()
+     reinit_params.add("boundary_predictor.0.weight")
+     reinit_params.add("boundary_predictor.2.weight")
+
+     # ---- 3. null_group and down_ln for downsample/upsample ----
+     null_group = _zeros_init((hidden_size,))
+     spider_sd["null_group.weight"] = null_group
+     reinit_param_count += null_group.numel()
+     reinit_params.add("null_group.weight")
+
+     down_ln_w = torch.ones(hidden_size, dtype=torch.float32)
+     down_ln_b = _zeros_init((hidden_size,))
+     spider_sd["down_ln.weight"] = down_ln_w
+     spider_sd["down_ln.bias"] = down_ln_b
+     reinit_param_count += down_ln_w.numel() + down_ln_b.numel()
+     reinit_params.add("down_ln.weight")
663
+
664
+ # ---- 4. Layer-by-layer weight transfer ----
665
+ for section_name, num_layers in [
666
+ ("prelude_layers", spider_config.prelude_layers),
667
+ ("recurrent_layers", spider_config.num_hidden_layers),
668
+ ("coda_layers", spider_config.coda_layers),
669
+ ]:
670
+ is_recurrent = section_name == "recurrent_layers"
671
+
672
+ for layer_idx in range(num_layers):
673
+ # WR-02 fix: accumulate spider_layer_idx across sections so
674
+ # coda layers map to distinct donor layers instead of reusing
675
+ # prelude donor layers
676
+ spider_layer_idx = ({
677
+ "prelude_layers": 0,
678
+ "recurrent_layers": spider_config.prelude_layers,
679
+ "coda_layers": spider_config.prelude_layers + spider_config.num_hidden_layers,
680
+ }[section_name] + layer_idx)
681
+ donor_layer_idx = layer_mapping.get(spider_layer_idx)
682
+
683
+ prefix = f"model.{section_name}.{layer_idx}"
684
+
685
+ if donor_layer_idx is not None:
686
+ donor_prefix = f"model.layers.{donor_layer_idx}"
687
+ else:
688
+ donor_prefix = None
689
+
690
+ # ---- Attention: MLA via SVD ----
691
+ # q_proj: [num_heads_donor * head_dim_donor, hidden_size_donor]
692
+ if donor_prefix is not None:
693
+ donor_q_key = f"{donor_prefix}.self_attn.q_proj.weight"
694
+ donor_q = donor_state_dict.get(donor_q_key)
695
+ else:
696
+ donor_q = None
697
+
698
+ if donor_q is not None and donor_q.shape[0] == donor_num_heads * donor_head_dim and donor_q.shape[1] == donor_hidden_size:
699
+ # SVD decompose q_proj → q_b_proj
700
+ # donor_q shape: [out, in] — PyTorch Linear stores [out, in]
701
+ # We want: q_a_proj weight [q_lora_rank, hidden_size] and
702
+ # q_b_proj weight [num_heads * head_dim, q_lora_rank]
703
+ # SVD decompose: donor_q.T = [in, out] → a=[in,rank], b=[rank,out]
704
+ # a_svd: [in, rank], b_svd: [rank, out]
705
+ # q_a_proj.weight = a_svd.T = [rank, in] → matches nn.Linear(hidden, q_lora_rank)
706
+ # q_b_proj.weight = b_svd.T = [out, rank] → matches nn.Linear(q_lora_rank, num_heads*head_dim)
707
+ # When donor_hidden_size != hidden_size, we adapt the SVD decomposition
708
+ effective_q = donor_q
709
+ if donor_hidden_size != hidden_size:
710
+ # Pad/crop donor_q to match Spider dimensions
711
+ effective_q = _adapt_weight(donor_q, donor_num_heads * donor_head_dim, hidden_size)
712
+
713
+ # Transpose to [in, out] first, as the k/v path below does, so that q_b_svd is [rank, out]
+ q_a_svd, q_b_svd = decompose_attention_svd(effective_q.T, q_lora_rank)
714
+ # Per D-10: q_a_proj is REINITIALIZED, q_b_proj from SVD
715
+ q_a_proj = _kaiming_init((q_lora_rank, hidden_size))
716
+ q_b_proj = q_b_svd.T # [out, rank] — transposed for PyTorch Linear [out, in]
717
+ donor_param_count += q_b_proj.numel()
718
+ reinit_param_count += q_a_proj.numel()
719
+ donor_params.add(f"{prefix}.self_attn.q_b_proj.weight")
720
+ reinit_params.add(f"{prefix}.self_attn.q_a_proj.weight")
721
+ else:
722
+ q_a_proj = _kaiming_init((q_lora_rank, hidden_size))
723
+ q_b_proj = _kaiming_init((num_heads * head_dim, q_lora_rank))
724
+ reinit_param_count += q_a_proj.numel() + q_b_proj.numel()
725
+ reinit_params.add(f"{prefix}.self_attn.q_a_proj.weight")
726
+ reinit_params.add(f"{prefix}.self_attn.q_b_proj.weight")
727
+
728
+ spider_sd[f"{prefix}.self_attn.q_a_proj.weight"] = q_a_proj
729
+ spider_sd[f"{prefix}.self_attn.q_b_proj.weight"] = q_b_proj
730
+
731
+ # q_a_layernorm
732
+ q_a_ln = torch.ones(q_lora_rank, dtype=torch.float32)
733
+ spider_sd[f"{prefix}.self_attn.q_a_layernorm.weight"] = q_a_ln
734
+ reinit_param_count += q_a_ln.numel()
735
+ reinit_params.add(f"{prefix}.self_attn.q_a_layernorm.weight")
736
+
737
+ # k_proj + v_proj → SVD → kv_a_proj_with_mqa, kv_b_proj
738
+ if donor_prefix is not None:
739
+ donor_k_key = f"{donor_prefix}.self_attn.k_proj.weight"
740
+ donor_v_key = f"{donor_prefix}.self_attn.v_proj.weight"
741
+ donor_k = donor_state_dict.get(donor_k_key)
742
+ donor_v = donor_state_dict.get(donor_v_key)
743
+ else:
744
+ donor_k = None
745
+ donor_v = None
746
+
747
+ if donor_k is not None and donor_v is not None:
748
+ # Concatenate k_proj and v_proj along output dim
749
+ # donor_k: [num_kv_heads * head_dim_donor, hidden_size_donor]
750
+ # donor_v: [num_kv_heads * head_dim_donor, hidden_size_donor]
751
+ # Combined: [num_kv_heads * head_dim_donor * 2, hidden_size_donor]
752
+ combined_kv = torch.cat([donor_k, donor_v], dim=0)
753
+
754
+ # Adapt dimensions if donor hidden_size differs from Spider's
755
+ if donor_hidden_size != hidden_size:
756
+ combined_kv_out = donor_num_kv_heads * donor_head_dim * 2
757
+ combined_kv = _adapt_weight(combined_kv, combined_kv_out, hidden_size)
758
+
759
+ # Transpose for SVD: we want [hidden_size, combined_kv_out]
760
+ kv_a_svd, kv_b_svd = decompose_attention_svd(combined_kv.T, kv_lora_rank)
761
+ # kv_a_svd: [hidden_size, rank], kv_b_svd: [rank, combined_kv_out]
762
+ # Per D-10: kv_a_proj (compression) REINITIALIZED
763
+ # kv_b_proj (decompression) from SVD
764
+
765
+ # kv_a_proj_with_mqa.weight: [kv_lora_rank + qk_rope_head_dim, hidden_size]
766
+ # = [128 + 64, 2048] = [192, 2048]
767
+ kv_a_with_mqa = _kaiming_init(
768
+ (kv_lora_rank + qk_rope_head_dim, hidden_size)
769
+ )
770
+
771
+ # kv_b_proj.weight: [num_heads*(qk_nope+v_head), kv_lora_rank]
772
+ # = [16*(64+64), 128] = [2048, 128]
773
+ # SVD gives kv_b_svd: [128, 1024] → transpose: [1024, 128]
774
+ # This is smaller than [2048, 128], so pad with Kaiming init
775
+ kv_b_proj_weight = _kaiming_init(
776
+ (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
777
+ ) # [2048, 128]
778
+ svd_contribution = kv_b_svd.T # [1024, 128]
779
+ # Copy SVD result into the beginning of kv_b_proj_weight
780
+ rows_to_copy = min(svd_contribution.shape[0], kv_b_proj_weight.shape[0])
781
+ kv_b_proj_weight[:rows_to_copy, :] = svd_contribution[:rows_to_copy]
782
+
783
+ # Count: SVD-initialized rows count as donor, padding as reinit
784
+ donor_rows = rows_to_copy
785
+ reinit_rows = kv_b_proj_weight.shape[0] - donor_rows
786
+ donor_param_count += donor_rows * kv_b_proj_weight.shape[1]
787
+ reinit_param_count += reinit_rows * kv_b_proj_weight.shape[1]
788
+
789
+ reinit_param_count += kv_a_with_mqa.numel()
790
+ donor_params.add(f"{prefix}.self_attn.kv_b_proj.weight")
791
+ reinit_params.add(f"{prefix}.self_attn.kv_a_proj_with_mqa.weight")
792
+ else:
793
+ kv_a_with_mqa = _kaiming_init(
794
+ (kv_lora_rank + qk_rope_head_dim, hidden_size)
795
+ )
796
+ kv_b_proj_weight = _kaiming_init(
797
+ (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
798
+ )
799
+ reinit_param_count += kv_a_with_mqa.numel() + kv_b_proj_weight.numel()
800
+ reinit_params.add(f"{prefix}.self_attn.kv_a_proj_with_mqa.weight")
801
+ reinit_params.add(f"{prefix}.self_attn.kv_b_proj.weight")
802
+
803
+ spider_sd[f"{prefix}.self_attn.kv_a_proj_with_mqa.weight"] = kv_a_with_mqa
804
+ spider_sd[f"{prefix}.self_attn.kv_b_proj.weight"] = kv_b_proj_weight
805
+
806
+ # kv_a_layernorm
807
+ kv_a_ln = torch.ones(kv_lora_rank, dtype=torch.float32)
808
+ spider_sd[f"{prefix}.self_attn.kv_a_layernorm.weight"] = kv_a_ln
809
+ reinit_param_count += kv_a_ln.numel()
810
+ reinit_params.add(f"{prefix}.self_attn.kv_a_layernorm.weight")
811
+
812
+ # o_proj: copy from donor where possible
813
+ # Donor o_proj: [donor_hidden_size, donor_hidden_size]
814
+ # Spider o_proj: [hidden_size, num_heads * v_head_dim]
815
+ if donor_prefix is not None:
816
+ donor_o_key = f"{donor_prefix}.self_attn.o_proj.weight"
817
+ donor_o = donor_state_dict.get(donor_o_key)
818
+ else:
819
+ donor_o = None
820
+
821
+ o_proj_shape = (hidden_size, num_heads * v_head_dim) # [2048, 1024]
822
+ o_proj = _kaiming_init(o_proj_shape)
823
+ if donor_o is not None:
824
+ # Copy what fits from donor's o_proj
825
+ rows_to_copy = min(donor_o.shape[0], o_proj.shape[0])
826
+ cols_to_copy = min(donor_o.shape[1], o_proj.shape[1])
827
+ o_proj[:rows_to_copy, :cols_to_copy] = donor_o[:rows_to_copy, :cols_to_copy]
828
+ donor_param_count += rows_to_copy * cols_to_copy
829
+ remaining = o_proj.numel() - rows_to_copy * cols_to_copy
830
+ if remaining > 0:
831
+ reinit_param_count += remaining
832
+ donor_params.add(f"{prefix}.self_attn.o_proj.weight")
833
+ else:
834
+ reinit_param_count += o_proj.numel()
835
+ reinit_params.add(f"{prefix}.self_attn.o_proj.weight")
836
+ spider_sd[f"{prefix}.self_attn.o_proj.weight"] = o_proj
837
+
838
+ # Layer norms: direct copy where shapes match, adapt otherwise
839
+ for norm_name in ["input_layernorm.weight", "post_attention_layernorm.weight"]:
840
+ if donor_prefix is not None:
841
+ donor_norm_key = f"{donor_prefix}.{norm_name}"
842
+ donor_norm = donor_state_dict.get(donor_norm_key)
843
+ else:
844
+ donor_norm = None
845
+
846
+ if donor_norm is not None and donor_norm.shape == (hidden_size,):
847
+ spider_sd[f"{prefix}.{norm_name}"] = donor_norm.clone().float() # float32, like the other donor copies
848
+ donor_param_count += donor_norm.numel()
849
+ donor_params.add(f"{prefix}.{norm_name}")
850
+ elif donor_norm is not None and donor_norm.shape[0] != hidden_size:
851
+ # Adapt: pad/crop layer norm to match Spider hidden_size
852
+ adapted_norm = torch.ones(hidden_size, dtype=torch.float32)
853
+ copy_size = min(donor_norm.shape[0], hidden_size)
854
+ adapted_norm[:copy_size] = donor_norm[:copy_size]
855
+ spider_sd[f"{prefix}.{norm_name}"] = adapted_norm
856
+ donor_param_count += copy_size
857
+ reinit_param_count += hidden_size - copy_size
858
+ donor_params.add(f"{prefix}.{norm_name}")
859
+ else:
860
+ ln = torch.ones(hidden_size, dtype=torch.float32)
861
+ spider_sd[f"{prefix}.{norm_name}"] = ln
862
+ reinit_param_count += ln.numel()
863
+ reinit_params.add(f"{prefix}.{norm_name}")
864
+
865
+ # ---- FFN / MoE ----
866
+ if is_recurrent:
867
+ # SharedProjectionMoE (D-20, D-21):
868
+ # shared_up: Linear(hidden, shared_inter=6144) — DIRECT copy from donor up_proj
869
+ # shared_down: Linear(shared_inter=6144, hidden) — DIRECT copy from donor down_proj
870
+ # shared_expert: SpiderExpert with inter=7424 — partial copy from donor FFN
871
+ # W_gate, W_transform: random init — created by split_dense_to_moe
872
+
873
+ # The FFN donor is the layer that layer_mapping assigned to this Spider layer,
+ # i.e. the same donor layer that supplies the attention weights above.
875
+ if donor_layer_idx is not None:
876
+ qwen_layer_for_ffn = donor_layer_idx
877
+ else:
878
+ qwen_layer_for_ffn = None
879
+
880
+ # ---- shared_up: direct copy from donor up_proj ----
881
+ # Spider shared_up.weight: [shared_inter=6144, hidden=2048]
882
+ # Qwen up_proj.weight: [inter=6144, hidden=2048] — EXACT MATCH (D-32)
883
+ shared_up_key = f"{prefix}.moe.shared_up.weight"
884
+ shared_up_shape = (spider_config.shared_intermediate_size, hidden_size)
885
+ if qwen_layer_for_ffn is not None:
886
+ donor_up_key = f"model.layers.{qwen_layer_for_ffn}.mlp.up_proj.weight"
887
+ donor_up = donor_state_dict.get(donor_up_key)
888
+ else:
889
+ donor_up = None
890
+
891
+ if donor_up is not None and donor_up.shape == shared_up_shape:
892
+ spider_sd[shared_up_key] = donor_up.clone().float()
893
+ donor_param_count += donor_up.numel()
894
+ donor_params.add(shared_up_key)
895
+ elif donor_up is not None:
896
+ shared_up_w = _kaiming_init(shared_up_shape)
897
+ rows_copy = min(donor_up.shape[0], shared_up_shape[0])
898
+ cols_copy = min(donor_up.shape[1], shared_up_shape[1])
899
+ shared_up_w[:rows_copy, :cols_copy] = donor_up[:rows_copy, :cols_copy].float()
900
+ spider_sd[shared_up_key] = shared_up_w
901
+ donor_param_count += rows_copy * cols_copy
902
+ reinit_param_count += shared_up_w.numel() - rows_copy * cols_copy
903
+ donor_params.add(shared_up_key)
904
+ else:
905
+ spider_sd[shared_up_key] = _kaiming_init(shared_up_shape)
906
+ reinit_param_count += shared_up_shape[0] * shared_up_shape[1]
907
+ reinit_params.add(shared_up_key)
908
+
909
+ # ---- shared_down: direct copy from donor down_proj ----
910
+ # Spider shared_down.weight: [hidden=2048, shared_inter=6144]
911
+ # Qwen down_proj.weight: [hidden=2048, inter=6144] — EXACT MATCH (D-32)
912
+ shared_down_key = f"{prefix}.moe.shared_down.weight"
913
+ shared_down_shape = (hidden_size, spider_config.shared_intermediate_size)
914
+ if qwen_layer_for_ffn is not None:
915
+ donor_down_key = f"model.layers.{qwen_layer_for_ffn}.mlp.down_proj.weight"
916
+ donor_down = donor_state_dict.get(donor_down_key)
917
+ else:
918
+ donor_down = None
919
+
920
+ if donor_down is not None and donor_down.shape == shared_down_shape:
921
+ spider_sd[shared_down_key] = donor_down.clone().float()
922
+ donor_param_count += donor_down.numel()
923
+ donor_params.add(shared_down_key)
924
+ elif donor_down is not None:
925
+ shared_down_w = _kaiming_init(shared_down_shape)
926
+ rows_copy = min(donor_down.shape[0], shared_down_shape[0])
927
+ cols_copy = min(donor_down.shape[1], shared_down_shape[1])
928
+ shared_down_w[:rows_copy, :cols_copy] = donor_down[:rows_copy, :cols_copy].float()
929
+ spider_sd[shared_down_key] = shared_down_w
930
+ donor_param_count += rows_copy * cols_copy
931
+ reinit_param_count += shared_down_w.numel() - rows_copy * cols_copy
932
+ donor_params.add(shared_down_key)
933
+ else:
934
+ spider_sd[shared_down_key] = _kaiming_init(shared_down_shape)
935
+ reinit_param_count += shared_down_shape[0] * shared_down_shape[1]
936
+ reinit_params.add(shared_down_key)
937
+
938
+ # ---- shared_expert: partial transfer from donor FFN (6144→7424) ----
939
+ # Spider shared_expert has inter=7424 (D-21: larger than donor's 6144)
940
+ # First 6144 rows/cols from donor, remaining 1280 randomly initialized
941
+ shared_expert_inter = spider_config.shared_expert_intermediate_size
942
+ if qwen_layer_for_ffn is not None:
943
+ donor_gate_key = f"model.layers.{qwen_layer_for_ffn}.mlp.gate_proj.weight"
944
+ donor_se_up_key = f"model.layers.{qwen_layer_for_ffn}.mlp.up_proj.weight"
945
+ donor_se_down_key = f"model.layers.{qwen_layer_for_ffn}.mlp.down_proj.weight"
946
+ donor_se_gate = donor_state_dict.get(donor_gate_key)
947
+ donor_se_up = donor_state_dict.get(donor_se_up_key)
948
+ donor_se_down = donor_state_dict.get(donor_se_down_key)
949
+ else:
950
+ donor_se_gate = donor_se_up = donor_se_down = None
951
+
952
+ for proj_name, spider_shape in [
953
+ ("gate_proj", (shared_expert_inter, hidden_size)),
954
+ ("up_proj", (shared_expert_inter, hidden_size)),
955
+ ("down_proj", (hidden_size, shared_expert_inter)),
956
+ ]:
957
+ key = f"{prefix}.moe.shared_expert.{proj_name}.weight"
958
+ w = _kaiming_init(spider_shape)
959
+
960
+ # All three projections use the same overlap copy; only the donor source differs
+ donor_src = {
+ "gate_proj": donor_se_gate,
+ "up_proj": donor_se_up,
+ "down_proj": donor_se_down,
+ }[proj_name]
+ if donor_src is not None:
+ rows_copy = min(donor_src.shape[0], spider_shape[0])
+ cols_copy = min(donor_src.shape[1], spider_shape[1])
+ w[:rows_copy, :cols_copy] = donor_src[:rows_copy, :cols_copy].float()
+ donor_param_count += rows_copy * cols_copy
+ reinit_param_count += w.numel() - rows_copy * cols_copy
+ donor_params.add(key)
+ else:
+ reinit_param_count += w.numel()
+ reinit_params.add(key)
983
+
984
+ spider_sd[key] = w
985
+
986
+ # W_gate and W_transform will be created by split_dense_to_moe
987
+
988
+ # LoRA adapter
989
+ lora_down = _kaiming_init((spider_config.lora_rank, hidden_size))
990
+ lora_B = torch.zeros(spider_config.lora_rank, hidden_size, dtype=torch.float32)
991
+ lora_scale = torch.zeros(spider_config.max_loop_iters, spider_config.lora_rank, dtype=torch.float32)
992
+ spider_sd[f"{prefix}.lora_adapter.down.weight"] = lora_down
993
+ spider_sd[f"{prefix}.lora_adapter.B"] = lora_B
994
+ spider_sd[f"{prefix}.lora_adapter.scale.weight"] = lora_scale
995
+ reinit_param_count += lora_down.numel() + lora_B.numel() + lora_scale.numel()
996
+ reinit_params.add(f"{prefix}.lora_adapter.down.weight")
997
+
998
+ # ACT halting
999
+ halt_w = _kaiming_init((1, hidden_size))
1000
+ halt_b = _zeros_init((1,))
1001
+ spider_sd[f"{prefix}.act_halting.halt_predictor.weight"] = halt_w
1002
+ spider_sd[f"{prefix}.act_halting.halt_predictor.bias"] = halt_b
1003
+ reinit_param_count += halt_w.numel() + halt_b.numel()
1004
+ reinit_params.add(f"{prefix}.act_halting.halt_predictor.weight")
1005
+
1006
+ # Engram (layers 1 and 4 only — D-20 revision)
1007
+ if layer_idx in spider_config.engram_layers:
1008
+ engram_mem_dim = spider_config.engram_heads * spider_config.engram_dim
1009
+ engram_W_k = _kaiming_init((hidden_size, engram_mem_dim * 2))
1010
+ engram_W_v = _kaiming_init((hidden_size, engram_mem_dim * 2))
1011
+ engram_conv_w = _kaiming_init((hidden_size, 1, 4))
1012
+ engram_conv_b = _zeros_init((hidden_size,))
1013
+ engram_q_norm = _ones_init((hidden_size,))
1014
+ engram_k_norm = _ones_init((hidden_size,))
1015
+ engram_embed = torch.zeros(
1016
+ 2, spider_config.engram_heads, spider_config.engram_table_size, spider_config.engram_dim
1017
+ )
1018
+ engram_hash = torch.arange(spider_config.engram_heads * 2, dtype=torch.float32)
1019
+ post_engram_norm = _ones_init((hidden_size,))
1020
+
1021
+ spider_sd[f"{prefix}.engram.W_k.weight"] = engram_W_k
1022
+ spider_sd[f"{prefix}.engram.W_v.weight"] = engram_W_v
1023
+ spider_sd[f"{prefix}.engram.conv.weight"] = engram_conv_w
1024
+ spider_sd[f"{prefix}.engram.conv.bias"] = engram_conv_b
1025
+ spider_sd[f"{prefix}.engram.q_norm.weight"] = engram_q_norm
1026
+ spider_sd[f"{prefix}.engram.k_norm.weight"] = engram_k_norm
1027
+ spider_sd[f"{prefix}.engram.embed"] = engram_embed
1028
+ spider_sd[f"{prefix}.engram.hash_seeds"] = engram_hash
1029
+ spider_sd[f"{prefix}.post_engram_layernorm.weight"] = post_engram_norm
1030
+
1031
+ engram_params = (engram_W_k.numel() + engram_W_v.numel() + engram_conv_w.numel() +
1032
+ engram_conv_b.numel() + engram_q_norm.numel() + engram_k_norm.numel() +
1033
+ engram_embed.numel() + engram_hash.numel() + post_engram_norm.numel())
1034
+ reinit_param_count += engram_params
1035
+ else:
1036
+ # Dense FFN for prelude/coda: partial transfer from donor FFN
1037
+ # Spider uses prelude_coda_intermediate_size=4096 (D-21)
1038
+ # Donor has intermediate_size=6144 → copy first 4096 rows/cols
1039
+ dense_inter = spider_config.prelude_coda_intermediate_size
1040
+ if donor_layer_idx is not None:
1041
+ donor_gate_key = f"model.layers.{donor_layer_idx}.mlp.gate_proj.weight"
1042
+ donor_up_key = f"model.layers.{donor_layer_idx}.mlp.up_proj.weight"
1043
+ donor_down_key = f"model.layers.{donor_layer_idx}.mlp.down_proj.weight"
1044
+ donor_d_gate = donor_state_dict.get(donor_gate_key)
1045
+ donor_d_up = donor_state_dict.get(donor_up_key)
1046
+ donor_d_down = donor_state_dict.get(donor_down_key)
1047
+ else:
1048
+ donor_d_gate = donor_d_up = donor_d_down = None
1049
+
1050
+ for proj_name, shape, donor_src in [
1051
+ ("gate_proj", (dense_inter, hidden_size), donor_d_gate),
1052
+ ("up_proj", (dense_inter, hidden_size), donor_d_up),
1053
+ ("down_proj", (hidden_size, dense_inter), donor_d_down),
1054
+ ]:
1055
+ w = _kaiming_init(shape)
1056
+ key = f"{prefix}.ffn.{proj_name}.weight"
1057
+
1058
+ if donor_src is not None:
1059
+ # Same overlap copy for gate/up/down: take the top-left block that fits
+ rows_copy = min(donor_src.shape[0], shape[0])
+ cols_copy = min(donor_src.shape[1], shape[1])
+ w[:rows_copy, :cols_copy] = donor_src[:rows_copy, :cols_copy].float()
1067
+ donor_param_count += rows_copy * cols_copy
1068
+ reinit_param_count += w.numel() - rows_copy * cols_copy
1069
+ donor_params.add(key)
1070
+ else:
1071
+ reinit_param_count += w.numel()
1072
+ reinit_params.add(key)
1073
+
1074
+ spider_sd[key] = w
1075
+
1076
+ # ---- 5. LTI Injection: REINIT (Spider-specific) ----
1077
+ log_A = torch.full((hidden_size,), -2.0)
1078
+ delta_t = torch.tensor(1.0)
1079
+ B_weight = torch.randn(hidden_size, hidden_size) * 0.01
1080
+ spider_sd["model.injection.log_A"] = log_A
1081
+ spider_sd["model.injection.delta_t"] = delta_t
1082
+ spider_sd["model.injection.B.weight"] = B_weight
1083
+ reinit_param_count += log_A.numel() + delta_t.numel() + B_weight.numel()
1084
+ reinit_params.add("model.injection.B.weight")
1085
+
1086
+ # ---- 6. Final norm: try to copy from donor, adapt dimensions ----
1087
+ donor_final_norm = donor_state_dict.get("model.norm.weight")
1088
+ if donor_final_norm is not None and donor_final_norm.shape == (hidden_size,):
1089
+ spider_sd["model.norm.weight"] = donor_final_norm.clone().float() # float32, like the other donor copies
1090
+ donor_param_count += donor_final_norm.numel()
1091
+ donor_params.add("model.norm.weight")
1092
+ elif donor_final_norm is not None:
1093
+ # Adapt: pad/crop to match Spider hidden_size
1094
+ adapted_norm = torch.ones(hidden_size, dtype=torch.float32)
1095
+ copy_size = min(donor_final_norm.shape[0], hidden_size)
1096
+ adapted_norm[:copy_size] = donor_final_norm[:copy_size]
1097
+ spider_sd["model.norm.weight"] = adapted_norm
1098
+ donor_param_count += copy_size
1099
+ reinit_param_count += hidden_size - copy_size
1100
+ donor_params.add("model.norm.weight")
1101
+ else:
1102
+ spider_sd["model.norm.weight"] = torch.ones(hidden_size, dtype=torch.float32)
1103
+ reinit_param_count += hidden_size
1104
+ reinit_params.add("model.norm.weight")
1105
+
1106
+ # ---- 7. Model-level ACT halting: REINIT ----
1107
+ halt_w = _kaiming_init((1, hidden_size))
1108
+ halt_b = _zeros_init((1,))
1109
+ spider_sd["model.act_halting.halt_predictor.weight"] = halt_w
1110
+ spider_sd["model.act_halting.halt_predictor.bias"] = halt_b
1111
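+ reinit_param_count += halt_w.numel() + halt_b.numel()
+ reinit_params.add("model.act_halting.halt_predictor.weight")
+ reinit_params.add("model.act_halting.halt_predictor.bias")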
1112
+
1113
+ # ---- 8. Apply MoE expert splitting ----
1114
+ spider_sd = split_dense_to_moe(spider_sd, spider_config, noise_scale=noise_scale)
1115
+
1116
+ # Count SharedProjectionMoE params created by split_dense_to_moe
1117
+ for layer_idx in range(spider_config.num_hidden_layers):
1118
+ rec_prefix = f"model.recurrent_layers.{layer_idx}.moe"
1119
+ # W_gate and W_transform are random init
1120
+ for core_key in [f"{rec_prefix}.W_gate", f"{rec_prefix}.W_transform"]:
1121
+ if core_key in spider_sd and core_key not in reinit_params and core_key not in donor_params:
1122
+ reinit_param_count += spider_sd[core_key].numel()
1123
+ reinit_params.add(core_key)
1124
+ # Router
1125
+ for router_key in [f"{rec_prefix}.router.weight", f"{rec_prefix}.router.bias"]:
1126
+ if router_key in spider_sd and router_key not in reinit_params and router_key not in donor_params:
1127
+ reinit_param_count += spider_sd[router_key].numel()
1128
+ reinit_params.add(router_key)
1129
+
1130
+ # ---- 9. Compute transfer coverage ----
1131
+ total_params = donor_param_count + reinit_param_count
1132
+ if total_params > 0:
1133
+ donor_pct = (donor_param_count / total_params) * 100.0
1134
+ reinit_pct = (reinit_param_count / total_params) * 100.0
1135
+ else:
1136
+ donor_pct = 0.0
1137
+ reinit_pct = 0.0
1138
+
1139
+ transfer_coverage = {
1140
+ "donor_params": donor_param_count,
1141
+ "reinit_params": reinit_param_count,
1142
+ "total_params": total_params,
1143
+ "donor_pct": round(donor_pct, 2),
1144
+ "reinit_pct": round(reinit_pct, 2),
1145
+ "donor_keys": sorted(donor_params),
1146
+ "reinit_keys": sorted(reinit_params),
1147
+ }
1148
+
1149
+ # Print report
1150
+ print("=" * 60)
1151
+ print("Weight Transfer Report")
1152
+ print("=" * 60)
1153
+ print(f" Donor: Qwen3.5-2B ({donor_config.get('num_hidden_layers', '?')} layers)")
1154
+ print(f" Target: Spider-FLEXITOKENS ({spider_config.prelude_layers}+{spider_config.num_hidden_layers}+{spider_config.coda_layers} layers)")
1155
+ print(f" Full attention layers used: {len(full_attention_layers)}")
1156
+ print(f" Layer mapping: {layer_mapping}")
1157
+ print()
1158
+ print(f" Total params: {total_params:>12,} ({total_params/1e6:.1f}M)")
1159
+ print(f" Donor-originated: {donor_param_count:>12,} ({donor_param_count/1e6:.1f}M) = {donor_pct:.1f}%")
1160
+ print(f" Reinitialized: {reinit_param_count:>12,} ({reinit_param_count/1e6:.1f}M) = {reinit_pct:.1f}%")
1161
+ print()
1162
+ print(f" Transfer coverage: {donor_pct:.1f}% from donor, {reinit_pct:.1f}% reinitialized")
1163
+ print("=" * 60)
1164
+
1165
+ return {
1166
+ "spider_state_dict": spider_sd,
1167
+ "transfer_coverage": transfer_coverage,
1168
+ "layer_mapping": layer_mapping,
1169
+ }
1170
+
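The transfer function above repeatedly copies the top-left block of a donor matrix into a freshly initialized target and counts the copied entries as donor-originated and the rest as reinitialized. A minimal standalone sketch of that pattern follows. The helper name `copy_overlap` and the toy shapes are assumptions made for illustration and do not appear in transfer_weights.py.

```python
import math

import torch
import torch.nn as nn


def copy_overlap(donor, target_shape):
    """Kaiming-init a target matrix, then overwrite the block that overlaps the donor.

    Hypothetical helper: returns (weight, donor_count, reinit_count).
    """
    w = torch.empty(target_shape, dtype=torch.float32)
    nn.init.kaiming_uniform_(w, a=math.sqrt(5))
    rows = min(donor.shape[0], target_shape[0])
    cols = min(donor.shape[1], target_shape[1])
    # Copy only the region present in both matrices; the rest keeps its random init.
    w[:rows, :cols] = donor[:rows, :cols].float()
    copied = rows * cols
    return w, copied, w.numel() - copied


# Toy example: a 4x6 donor copied into a 6x4 target overlaps in a 4x4 block.
donor = torch.arange(24, dtype=torch.float32).reshape(4, 6)
w, copied, fresh = copy_overlap(donor, (6, 4))
print(copied, fresh)
```

Counting per entry rather than per tensor is what lets the coverage report attribute a partially copied matrix to both the donor and the reinitialized totals.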
1171
+
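The attention transfer above relies on `decompose_attention_svd`, which is defined elsewhere in the file and not shown here. The sketch below is one plausible way to split a matrix into two low-rank factors with a truncated SVD. The function name `low_rank_factors` and the choice to put the square root of the singular values on each factor are assumptions, not the file's actual implementation.

```python
import torch


def low_rank_factors(weight, rank):
    """Return a [rows, rank] and b [rank, cols] such that a @ b approximates weight.

    Hypothetical stand-in for a truncated-SVD decomposition.
    """
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    root = S[:rank].sqrt()
    a = U[:, :rank] * root              # scale the columns of U
    b = root.unsqueeze(1) * Vh[:rank]   # scale the rows of Vh
    return a, b


# Toy example: an exactly rank-3 matrix is recovered by a rank-3 factorization.
torch.manual_seed(0)
W = torch.randn(8, 3) @ torch.randn(3, 5)
a, b = low_rank_factors(W, 3)
print(a.shape, b.shape)
```

With an [in, out] input, as in the k/v path, `b.T` has shape [out, rank], which is the layout `nn.Linear(rank, out)` expects for its weight.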
1172
+ # ============================================================================
1173
+ # SpiderMoEModel — Multimodal Forward Pass (D-11, 02-03)
1174
+ # ============================================================================
1175
+
1176
+ class SpiderMoEModel(nn.Module):
1177
+ """Spider-FLEXITOKENS model with multimodal forward pass.
1178
+
1179
+ Implements the full forward pass wiring per D-11:
1180
+ - Text bytes → embed → prelude layers → BoundaryPredictor → downsample →
1181
+ recurrent core → upsample → coda layers → lm_head → logits
1182
+ - Modality tokens (vision/audio/video) are injected at sentinel-marked
1183
+ positions and bypass the BoundaryPredictor entirely.
1184
+ - Sentinel-gated passthrough: modality_mask forces boundary=1.0 at
1185
+ sentinel+modality positions, preventing cross-modality merges.
1186
+
1187
+ This is a simplified model that implements the forward pass logic
1188
+ without the full SpiderPortalMLA attention (which requires position
1189
+ IDs, KV cache, etc.). It uses simple linear projections to demonstrate
1190
+ the multimodal wiring and parameter budget.
1191
+ """
1192
+
1193
+ def __init__(self, config: SpiderConfig):
1194
+ super().__init__()
1195
+ self.config = config
1196
+
1197
+ # Embeddings: 272 vocab (256 bytes + 16 specials)
1198
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
1199
+ # LM head (not tied per D-06)
1200
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1201
+
1202
+ # BoundaryPredictor
1203
+ self.boundary_predictor = BoundaryPredictor(config)
1204
+
1205
+ # null_group for downsample
1206
+ self.null_group = nn.Parameter(torch.zeros(config.hidden_size, dtype=torch.float32)) # IN-02
1207
+
1208
+ # Downsample layer norm
1209
+ self.down_ln = nn.LayerNorm(config.hidden_size)
1210
+
1211
+ # Prelude layers (2 dense layers with simplified attention + FFN)
1212
+ self.prelude_layers = nn.ModuleList([
1213
+ self._make_dense_layer(config) for _ in range(config.prelude_layers)
1214
+ ])
1215
+
1216
+ # Recurrent layers (6 MoE layers with simplified attention + MoE)
1217
+ self.recurrent_layers = nn.ModuleList([
1218
+ self._make_moe_layer(config, i) for i in range(config.num_hidden_layers)
1219
+ ])
1220
+
1221
+ # Coda layers (2 dense layers with simplified attention + FFN)
1222
+ self.coda_layers = nn.ModuleList([
1223
+ self._make_dense_layer(config) for _ in range(config.coda_layers)
1224
+ ])
1225
+
1226
+ # Final norm
1227
+ self.norm = nn.LayerNorm(config.hidden_size)
1228
+
1229
+ # LTI injection
1230
+ self.injection = _LTIInjection(config)
1231
+
1232
+ # ACT halting
1233
+ self.act_halting = _ACTHalting(config)
1234
+
1235
+ # LoRA adapter (per recurrent layer)
1236
+ self.lora_adapters = nn.ModuleList([
1237
+ _LoRAAdapter(config) for _ in range(config.num_hidden_layers)
1238
+ ])
1239
+
1240
+ self.loop_embed_dim = config.loop_embed_dim
1241
+ self.max_loop_iters = config.max_loop_iters
1242
+
1243
+ def _make_dense_layer(self, config):
1244
+ """Create a simplified dense layer (prelude/coda)."""
1245
+ return _DenseLayer(config)
1246
+
1247
+ def _make_moe_layer(self, config, layer_idx):
1248
+ """Create a simplified MoE layer (recurrent)."""
1249
+ return _MoELayer(config, layer_idx)
1250
+
1251
+ def _inject_modality_features(
1252
+ self,
1253
+ hidden_states: torch.Tensor,
1254
+ input_ids: torch.Tensor,
1255
+ features: list,
1256
+ modality: str = 'IMG',
1257
+ ) -> torch.Tensor:
1258
+ """Replace placeholder embeddings with actual encoder features at modality regions.
1259
+
1260
+ Per D-11: Modality tokens (vision, audio, video) are injected at
1261
+ sentinel-marked positions in the hidden_states sequence. The caller
1262
+ constructs input_ids with sentinel tokens (e.g., IMG_START, IMG_END)
1263
+ marking modality regions. Between sentinel pairs, the initial
1264
+ embeddings are placeholders — this method replaces them with the
1265
+ actual encoder features.
1266
+
1267
+ T-02-06 mitigation: Validates feature shape and sentinel pair count.
1268
+
1269
+ Args:
1270
+ hidden_states: [B, L, D] hidden states after embedding.
1271
+ input_ids: [B, L] token IDs with sentinel markers.
1272
+ features: List of tensors, one per sentinel pair, each of shape
+ [num_tokens, hidden_size]. The same list is applied to every batch
+ row, so each row must contain exactly len(features) sentinel pairs.
1274
+ modality: Modality type prefix ('IMG', 'AUD', 'VID').
1275
+
1276
+ Returns:
1277
+ hidden_states with modality features injected at sentinel regions.
1278
+
1279
+ Raises:
1280
+ ValueError: If feature shape doesn't match [num_tokens, hidden_size]
1281
+ or sentinel pair count doesn't match feature count.
1282
+ """
1283
+ start_token = SENTINEL_TOKENS[f'{modality}_START']
1284
+ end_token = SENTINEL_TOKENS[f'{modality}_END']
1285
+
1286
+ for b in range(hidden_states.shape[0]):
1287
+ starts = (input_ids[b] == start_token).nonzero(as_tuple=True)[0]
1288
+ ends = (input_ids[b] == end_token).nonzero(as_tuple=True)[0]
1289
+
1290
+ if len(starts) != len(ends):
1291
+ raise ValueError(
1292
+ f"Batch {b}: mismatched {modality} sentinel pairs — "
1293
+ f"{len(starts)} {_TOKEN_NAMES_BY_ID[start_token]}(s) vs "
1294
+ f"{len(ends)} {_TOKEN_NAMES_BY_ID[end_token]}(s)."
1295
+ )
1296
+
1297
+ if len(starts) != len(features):
1298
+ raise ValueError(
1299
+ f"Batch {b}: {modality} sentinel pair count ({len(starts)}) "
1300
+ f"doesn't match feature count ({len(features)})."
1301
+ )
1302
+
1303
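+ # Convert positions to Python ints so comparisons and error messages use plain numbers
+ for s, e, feat in zip(starts.tolist(), ends.tolist(), features):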
1304
+ # T-02-06: Validate feature shape
1305
+ num_tokens = e - s - 1 # tokens between sentinels
1306
+ if feat.shape[0] != num_tokens:
1307
+ raise ValueError(
1308
+ f"Batch {b}: {modality} feature has {feat.shape[0]} tokens "
1309
+ f"but sentinel region has {num_tokens} positions "
1310
+ f"(from pos {s+1} to {e-1})."
1311
+ )
1312
+ if feat.shape[1] != hidden_states.shape[-1]:
1313
+ raise ValueError(
1314
+ f"Batch {b}: {modality} feature hidden_size {feat.shape[1]} "
1315
+ f"doesn't match model hidden_size {hidden_states.shape[-1]}."
1316
+ )
1317
+ # Replace placeholder embeddings with actual features
1318
+ hidden_states[b, s + 1:e] = feat.to(hidden_states.dtype)
1319
+
1320
+ return hidden_states
1321
+
1322
+ def forward(
1323
+ self,
1324
+ input_ids: torch.Tensor,
1325
+ attention_mask: Optional[torch.Tensor] = None,
1326
+ position_ids: Optional[torch.Tensor] = None,
1327
+ past_key_values: Optional[list] = None,
1328
+ inputs_embeds: Optional[torch.Tensor] = None,
1329
+ vision_features: Optional[list] = None,
1330
+ audio_features: Optional[list] = None,
1331
+ video_features: Optional[list] = None,
1332
+ **kwargs,
1333
+ ) -> torch.Tensor:
1334
+ """Forward pass with multimodal sentinel-gated passthrough.
1335
+
1336
+ Per D-11:
1337
+ - All positions go through embed_tokens (bytes get byte embeddings,
1338
+ sentinels get special embeddings, modality tokens get placeholder embeddings)
1339
+ - External encoder features are injected at sentinel-marked positions
1340
+ - BoundaryPredictor operates on the embedded sequence with modality_mask
1341
+ - Text bytes go through BP → downsample → recurrent → upsample → coda → logits
1342
+ - Modality tokens bypass BP and enter downsampled sequence at sentinel positions
1343
+
1344
+ Args:
1345
+ input_ids: [B, L] token IDs with optional sentinel markers.
1346
+ attention_mask: Optional attention mask (not used in simplified model).
1347
+ position_ids: Optional position IDs (not used in simplified model).
1348
+ past_key_values: Optional KV cache (not used in simplified model).
1349
+ inputs_embeds: Optional pre-computed embeddings.
1350
+ vision_features: Optional list of tensors, each [num_tokens, hidden_size].
1351
+ audio_features: Optional list of tensors, each [num_tokens, hidden_size].
1352
+ video_features: Optional list of tensors, each [num_tokens, hidden_size].
1353
+
1354
+ Returns:
1355
+ logits: [B, L, vocab_size] output logits.
1356
+ """
1357
+ B, L = input_ids.shape
1358
+
1359
+ # 1. Embed all tokens (bytes, sentinels, modality placeholders)
1360
+ if inputs_embeds is not None:
1361
+ hidden_states = inputs_embeds
1362
+ else:
1363
+ hidden_states = self.embed_tokens(input_ids) # [B, L, D]
1364
+
1365
+ # 2. Inject external modality features at sentinel positions
1366
+ if vision_features is not None:
1367
+ hidden_states = self._inject_modality_features(
1368
+ hidden_states, input_ids, vision_features, 'IMG'
1369
+ )
1370
+ if audio_features is not None:
1371
+ hidden_states = self._inject_modality_features(
1372
+ hidden_states, input_ids, audio_features, 'AUD'
1373
+ )
1374
+ if video_features is not None:
1375
+ hidden_states = self._inject_modality_features(
1376
+ hidden_states, input_ids, video_features, 'VID'
1377
+ )
1378
+
1379
+ # 3. Prelude layers
1380
+ for layer in self.prelude_layers:
1381
+ hidden_states = layer(hidden_states)
1382
+
1383
+ # 4. Boundary prediction with modality mask
1384
+ modality_mask = create_modality_mask(input_ids) # [B, L]
1385
+ soft_boundaries, hard_boundaries = self.boundary_predictor(
1386
+ hidden_states, modality_mask=modality_mask
1387
+ )
1388
+
1389
+ # 5. Downsample with boundaries
1390
+ # Apply layer norm before downsample
1391
+ hidden_states_normed = self.down_ln(hidden_states)
1392
+ null_group = self.null_group.unsqueeze(0).unsqueeze(0).expand(1, B, -1)
1393
+ shortened = downsample(hard_boundaries, hidden_states_normed, null_group)
1394
+ # shortened: [S, B, D]
1395
+
1396
+ # 6. Recurrent core with RDT looping
1397
+ # Convert shortened from SBD to BLD for recurrent layers
1398
+ hidden_states = shortened.permute(1, 0, 2) # [B, S, D]
1399
+
1400
+ n_loops = self.max_loop_iters
1401
+ input_embedding = hidden_states.clone()
1402
+
1403
+ for t in range(n_loops):
1404
+ # Loop index embedding (computed here but not yet added to hidden_states)
1405
+ loop_emb = _loop_index_embedding(hidden_states, t, self.loop_embed_dim)
1406
+
1407
+ if t > 0:
1408
+ # LTI injection
1409
+ injection = self.injection(hidden_states, input_embedding)
1410
+ hidden_states = hidden_states + injection
1411
+
1412
+ # Recurrent layers
1413
+ for i, layer in enumerate(self.recurrent_layers):
1414
+ # LoRA adaptation for this loop iteration
1415
+ lora_out = self.lora_adapters[i](hidden_states, t)
1416
+ hidden_states = layer(hidden_states + lora_out * 0.01)
1417
+
1418
+ # 7. Upsample back to original sequence length
1419
+ # Convert back to SBD for upsample
1420
+ hidden_states_sbd = hidden_states.permute(1, 0, 2) # [S, B, D]
1421
+ hidden_states = upsample(hard_boundaries, hidden_states_sbd) # [B, L, D]
1422
+
1423
+ # 8. Coda layers
1424
+ for layer in self.coda_layers:
1425
+ hidden_states = layer(hidden_states)
1426
+
1427
+ # 9. Final norm + LM head
1428
+ hidden_states = self.norm(hidden_states)
1429
+ logits = self.lm_head(hidden_states) # [B, L, vocab_size]
1430
+
1431
+ return logits
1432
+
1433
+
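The sentinel-gated injection used by `forward` can be illustrated without tensors. The following is a minimal, dependency-free sketch: the sentinel IDs `256` and `257`, the helper names `find_sentinel_spans` and `inject`, and the ordering check are assumptions made for this example only, not part of `transfer_weights.py`. It mirrors the pair-count and length validation of `_inject_modality_features` on plain Python lists.

```python
from typing import List, Tuple

# Hypothetical sentinel IDs, chosen only for this example.
IMG_START, IMG_END = 256, 257


def find_sentinel_spans(ids: List[int], start_id: int, end_id: int) -> List[Tuple[int, int]]:
    """Return half-open index ranges (lo, hi) strictly between sentinel pairs."""
    starts = [i for i, t in enumerate(ids) if t == start_id]
    ends = [i for i, t in enumerate(ids) if t == end_id]
    if len(starts) != len(ends):
        raise ValueError(f"mismatched sentinels: {len(starts)} starts vs {len(ends)} ends")
    spans = []
    for s, e in zip(starts, ends):
        # Extra ordering check added for this sketch; the original does not perform it.
        if e <= s:
            raise ValueError(f"end sentinel at {e} does not follow start sentinel at {s}")
        spans.append((s + 1, e))
    return spans


def inject(rows, ids, features, start_id, end_id):
    """Overwrite placeholder rows inside each span with the matching feature rows."""
    spans = find_sentinel_spans(ids, start_id, end_id)
    if len(spans) != len(features):
        raise ValueError(f"{len(spans)} sentinel pairs but {len(features)} features")
    out = [list(r) for r in rows]
    for (lo, hi), feat in zip(spans, features):
        if len(feat) != hi - lo:
            raise ValueError(f"feature has {len(feat)} rows but span has {hi - lo} positions")
        out[lo:hi] = [list(r) for r in feat]
    return out


# Two byte tokens, an image region with three placeholder positions, one more byte.
ids = [72, 105, IMG_START, 0, 0, 0, IMG_END, 33]
rows = [[0.0, 0.0] for _ in ids]
features = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]]
result = inject(rows, ids, features, IMG_START, IMG_END)
```

Only the positions strictly between the sentinels are overwritten; the sentinel rows themselves keep their own embeddings.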
1434
+ # ============================================================================
1435
+ # Simplified sub-modules for SpiderMoEModel
1436
+ # ============================================================================
1437
+
1438
+ class _DenseLayer(nn.Module):
1439
+ """Simplified dense layer for prelude/coda (attention + FFN)."""
1440
+
1441
+ def __init__(self, config: SpiderConfig):
1442
+ super().__init__()
1443
+ self.input_layernorm = nn.LayerNorm(config.hidden_size)
1444
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
1445
+ # Simplified self-attention (fixed at 4 heads for this demo)
1446
+ self.self_attn = nn.MultiheadAttention(
1447
+ config.hidden_size, num_heads=4, batch_first=True
1448
+ )
1449
+ # FFN with SwiGLU-like structure
1450
+ self.ffn = _SwiGLUFFN(config.hidden_size, config.prelude_coda_intermediate_size)
1451
+
1452
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1453
+ # Self-attention with residual
1454
+ residual = hidden_states
1455
+ hidden_states = self.input_layernorm(hidden_states)
1456
+ attn_out, _ = self.self_attn(
1457
+ hidden_states, hidden_states, hidden_states
1458
+ )
1459
+ hidden_states = residual + attn_out
1460
+
1461
+ # FFN with residual
1462
+ residual = hidden_states
1463
+ hidden_states = self.post_attention_layernorm(hidden_states)
1464
+ ffn_out = self.ffn(hidden_states)
1465
+ hidden_states = residual + ffn_out
1466
+
1467
+ return hidden_states
1468
+
1469
+
1470
+ class _MoELayer(nn.Module):
1471
+ """Simplified MoE layer for recurrent core."""
1472
+
1473
+ def __init__(self, config: SpiderConfig, layer_idx: int = 0):
1474
+ super().__init__()
1475
+ self.input_layernorm = nn.LayerNorm(config.hidden_size)
1476
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
1477
+ # Simplified self-attention
1478
+ self.self_attn = nn.MultiheadAttention(
1479
+ config.hidden_size, num_heads=4, batch_first=True
1480
+ )
1481
+ # MoE FFN (SharedProjectionMoE per D-20, D-21)
1482
+ self.moe = _SharedProjectionMoE(config)
1483
+
1484
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1485
+ # Self-attention with residual
1486
+ residual = hidden_states
1487
+ hidden_states = self.input_layernorm(hidden_states)
1488
+ attn_out, _ = self.self_attn(
1489
+ hidden_states, hidden_states, hidden_states
1490
+ )
1491
+ hidden_states = residual + attn_out
1492
+
1493
+ # MoE FFN with residual
1494
+ residual = hidden_states
1495
+ hidden_states = self.post_attention_layernorm(hidden_states)
1496
+ moe_out, _z_loss = self.moe(hidden_states)
1497
+ hidden_states = residual + moe_out
1498
+
1499
+ return hidden_states
1500
+
1501
+
1502
+ class _SwiGLUFFN(nn.Module):
1503
+ """SwiGLU FFN: gate_proj, up_proj, down_proj."""
1504
+
1505
+ def __init__(self, hidden_size: int, intermediate_size: int):
1506
+ super().__init__()
1507
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
1508
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
1509
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
1510
+
1511
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1512
+ return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
1513
+
1514
+
1515
+ class _SharedProjectionMoE(nn.Module):
1516
+ """SharedProjectionMoE matching spider.py architecture (D-20, D-21).
1517
+
1518
+ Shared up/down projections computed once per token, rank-256 expert cores
1519
+ specialize on the shared representation.
1520
+ """
1521
+
1522
+ def __init__(self, config: SpiderConfig):
1523
+ super().__init__()
1524
+ self.num_experts = config.num_experts
1525
+ self.num_experts_per_tok = config.num_experts_per_tok
1526
+ self.shared_inter = config.shared_intermediate_size
1527
+ self.expert_core_rank = config.expert_core_rank
1528
+ self.hidden_size = config.hidden_size
1529
+
1530
+ self.shared_up = nn.Linear(config.hidden_size, config.shared_intermediate_size, bias=False)
1531
+ self.shared_down = nn.Linear(config.shared_intermediate_size, config.hidden_size, bias=False)
1532
+
1533
+ self.W_gate = nn.Parameter(
1534
+ torch.randn(config.num_experts, config.hidden_size, config.expert_core_rank) * 0.02
1535
+ )
1536
+ self.W_transform = nn.Parameter(
1537
+ torch.randn(config.num_experts, config.expert_core_rank, config.shared_intermediate_size) * 0.02
1538
+ )
1539
+
1540
+ self.shared_expert = _SwiGLUFFN(config.hidden_size, config.shared_expert_intermediate_size)
1541
+
1542
+ self.router = nn.Linear(config.hidden_size, config.num_experts, bias=True)
1543
+ self.router.bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32))
1544
+
1545
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1546
+ B, L, D = x.shape
1547
+
1548
+ shared_hidden = nn.functional.silu(self.shared_up(x))
1549
+
1550
+ shared_out = self.shared_expert(x)
1551
+
1552
+ router_logits = self.router(x)
1553
+ router_probs = nn.functional.softmax(router_logits, dim=-1)
1554
+
1555
+ top2_probs, top2_indices = router_probs.topk(self.num_experts_per_tok, dim=-1)
1556
+ top2_probs = top2_probs / top2_probs.sum(dim=-1, keepdim=True)
1557
+
1558
+ x_flat = x.reshape(B * L, D)
1559
+ shared_hidden_flat = shared_hidden.reshape(B * L, self.shared_inter)
1560
+
1561
+ routed_out = torch.zeros(B * L, D, device=x.device, dtype=x.dtype)
1562
+
1563
+ for k in range(self.num_experts_per_tok):
1564
+ expert_indices = top2_indices[:, :, k].reshape(B * L)
1565
+ expert_weights = top2_probs[:, :, k].reshape(B * L)
1566
+
1567
+ for e in range(self.num_experts):
1568
+ mask = (expert_indices == e)
1569
+ if not mask.any():
1570
+ continue
1571
+ expert_input = x_flat[mask]
1572
+ expert_sh = shared_hidden_flat[mask]
1573
+
1574
+ gate = expert_input @ self.W_gate[e]
1575
+ core = gate @ self.W_transform[e]
1576
+ expert_output = self.shared_down(core * expert_sh)
1577
+
1578
+ routed_out[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
1579
+
1580
+ routed_out = routed_out.reshape(B, L, D)
1581
+
1582
+ z_loss = (router_logits.logsumexp(dim=-1) ** 2).mean()
1583
+
1584
+ return shared_out + routed_out, z_loss
1585
+
1586
+
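The routing arithmetic in `_SharedProjectionMoE.forward` (softmax, top-k selection, renormalization of the kept probabilities, and the squared log-sum-exp penalty) can be checked by hand for a single token. This is a pure-Python sketch under that assumption; `route_token` is a hypothetical helper, and it returns the per-token penalty, whereas the module averages it over all tokens.

```python
import math
from typing import List, Tuple


def route_token(logits: List[float], k: int) -> Tuple[List[int], List[float], float]:
    """Top-k routing for one token: expert indices, renormalized weights, z penalty."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [v / total for v in exps]

    # Keep the k most probable experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

    # Renormalize so the kept weights sum to 1.
    kept = sum(probs[i] for i in top)
    weights = [probs[i] / kept for i in top]

    # Squared log-sum-exp of the raw router logits for this token.
    lse = m + math.log(total)
    return top, weights, lse ** 2


experts, weights, z = route_token([2.0, 0.0, 1.0, -1.0], k=2)
```

Because the weights are renormalized after selection, the two chosen experts here split the token in the ratio e : 1 regardless of how much probability mass the dropped experts held.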
1587
+ class _LTIInjection(nn.Module):
1588
+ """Linear Time-Invariant injection module."""
1589
+
1590
+ def __init__(self, config: SpiderConfig):
1591
+ super().__init__()
1592
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
1593
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
1594
+ self.B_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1595
+
1596
+ def forward(self, h_t: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
1597
+ A = torch.exp(self.log_A)
1598
+ decay = A * self.delta_t
1599
+ B_e = self.B_proj(e)
1600
+ return decay.unsqueeze(0).unsqueeze(0) * B_e
1601
+
1602
+
1603
+ class _ACTHalting(nn.Module):
1604
+ """Adaptive Computation Time halting module."""
1605
+
1606
+ def __init__(self, config: SpiderConfig):
1607
+ super().__init__()
1608
+ self.halt_predictor = nn.Linear(config.hidden_size, 1, bias=True)
1609
+
1610
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1611
+ return torch.sigmoid(self.halt_predictor(hidden_states)).squeeze(-1)
1612
+
1613
+
1614
+ class _LoRAAdapter(nn.Module):
1615
+ """LoRA adapter for per-loop adaptation in recurrent layers."""
1616
+
1617
+ def __init__(self, config: SpiderConfig):
1618
+ super().__init__()
1619
+ self.down = nn.Linear(config.hidden_size, config.lora_rank, bias=False)
1620
+ self.up = nn.Linear(config.lora_rank, config.hidden_size, bias=False)
1621
+ # CR-01 fix: zero init the up-projection per LoRA convention
1622
+ nn.init.zeros_(self.up.weight)
1623
+ self.scale_embeddings = nn.Embedding(config.max_loop_iters, config.lora_rank)
1624
+
1625
+ def forward(self, x: torch.Tensor, loop_iter: int) -> torch.Tensor:
1626
+ down = self.down(x)
1627
+ scale = self.scale_embeddings(torch.tensor([loop_iter], device=x.device))
1628
+ scaled = down * scale.squeeze(0)
1629
+ return self.up(scaled)
1630
+
1631
+
1632
+ def _loop_index_embedding(
1633
+ hidden_states: torch.Tensor,
1634
+ loop_iter: int,
1635
+ embed_dim: int,
1636
+ ) -> torch.Tensor:
1637
+ """Generate sinusoidal loop index embedding.
1638
+
1639
+ Provides positional-like encoding for the loop iteration index,
1640
+ allowing the model to differentiate between iterations of the
1641
+ recurrent depth loop.
1642
+ """
1643
+ B, L, D = hidden_states.shape
1644
+ device = hidden_states.device
1645
+
1646
+ # Sinusoidal embedding for loop iteration
1647
+ pos = torch.tensor([loop_iter], device=device, dtype=hidden_states.dtype)
1648
+ dim = torch.arange(embed_dim, device=device, dtype=hidden_states.dtype)
1649
+ freq = pos / (10000 ** (2 * dim / embed_dim))
1650
+
1651
+ # Interleave sin and cos
1652
+ emb = torch.zeros(embed_dim, device=device, dtype=hidden_states.dtype)
1653
+ emb[0::2] = torch.sin(freq[::2][:emb[0::2].shape[0]])
1654
+ emb[1::2] = torch.cos(freq[1::2][:emb[1::2].shape[0]])
1655
+
1656
+ # Broadcast to [B, L, embed_dim] and pad to D if needed
1657
+ emb = emb.unsqueeze(0).unsqueeze(0).expand(B, L, -1)
1658
+ if embed_dim < D:
1659
+ padding = torch.zeros(B, L, D - embed_dim, device=device, dtype=hidden_states.dtype)
1660
+ emb = torch.cat([emb, padding], dim=-1)
1661
+ elif embed_dim > D:
1662
+ emb = emb[:, :, :D]
1663
+
1664
+ return emb
1665
+
1666
+
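For a quick sanity check of `_loop_index_embedding`, the same formula can be evaluated for a single position in plain Python. This sketch assumes the indexing used above (angle `loop_iter / 10000 ** (2 * i / embed_dim)`, sine on even indices, cosine on odd indices, then zero-padding or truncation to the hidden size); the function name and list-based return type are choices made for the example.

```python
import math
from typing import List


def loop_index_embedding(loop_iter: int, embed_dim: int, hidden_size: int) -> List[float]:
    """Sinusoidal loop-index vector for one position, padded or cut to hidden_size."""
    angles = [loop_iter / (10000 ** (2 * i / embed_dim)) for i in range(embed_dim)]
    # Even indices take the sine of their angle, odd indices the cosine.
    emb = [math.sin(a) if i % 2 == 0 else math.cos(a) for i, a in enumerate(angles)]
    if embed_dim < hidden_size:
        emb = emb + [0.0] * (hidden_size - embed_dim)  # zero-pad up to hidden_size
    else:
        emb = emb[:hidden_size]  # truncate when embed_dim exceeds hidden_size
    return emb


first_iter = loop_index_embedding(0, embed_dim=4, hidden_size=6)
second_iter = loop_index_embedding(1, embed_dim=4, hidden_size=6)
```

At iteration 0 every angle is zero, so the vector alternates 0 and 1 before the padding; later iterations differ, which is what lets the recurrent layers tell loop steps apart.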
1667
+ # ============================================================================
1668
+ # Save & Config Export
1669
+ # ============================================================================
1670
+
1671
+ def save_spider_model(
1672
+ spider_state_dict: Dict[str, torch.Tensor],
1673
+ config: SpiderConfig,
1674
+ output_dir: Path,
1675
+ ):
1676
+ """Save Spider model state dict and config to output directory.
1677
+
1678
+ Handles weight tying per safetensors pattern from init_spiderportal.py.
1679
+ """
1680
+ output_dir = Path(output_dir)
1681
+ output_dir.mkdir(parents=True, exist_ok=True)
1682
+
1683
+ # Copy tensors for saving; the tied lm_head.weight is dropped below because safetensors refuses shared tensors
1684
+ save_sd = {}
1685
+ for name, param in spider_state_dict.items():
1686
+ # Ensure tensor is contiguous (required by safetensors)
1687
+ # Transposes (.T) and slices can produce non-contiguous tensors
1688
+ save_sd[name] = param.contiguous()
1689
+
1690
+ if config.tie_word_embeddings and "lm_head.weight" in save_sd:
1691
+ del save_sd["lm_head.weight"]
1692
+ print(" Note: lm_head.weight tied to embed_tokens.weight (saved once)")
1693
+
1694
+ # Save as safetensors
1695
+ try:
1696
+ from safetensors.torch import save_file
1697
+ save_file(save_sd, output_dir / "model.safetensors")
1698
+ except ImportError:
1699
+ # Fallback to PyTorch save
1700
+ torch.save(save_sd, output_dir / "model.pt")
1701
+ print(" Warning: safetensors not available, saved as model.pt")
1702
+
1703
+ # Save config
1704
+ cfg_dict = {
1705
+ "architectures": ["SpiderForConditionalGeneration"],
1706
+ "model_type": config.model_type,
1707
+ "vocab_size": config.vocab_size,
1708
+ "hidden_size": config.hidden_size,
1709
+ "num_hidden_layers": config.num_hidden_layers,
1710
+ "num_attention_heads": config.num_attention_heads,
1711
+ "num_key_value_heads": config.num_key_value_heads,
1712
+ "intermediate_size": config.intermediate_size,
1713
+ "hidden_act": config.hidden_act,
1714
+ "max_position_embeddings": config.max_position_embeddings,
1715
+ "rope_theta": config.rope_theta,
1716
+ "rope_scaling": config.rope_scaling,
1717
+ "sliding_window": config.sliding_window,
1718
+ "rms_norm_eps": config.rms_norm_eps,
1719
+ "initializer_range": config.initializer_range,
1720
+ "tie_word_embeddings": config.tie_word_embeddings,
1721
+ "torch_dtype": config.torch_dtype,
1722
+ # MoE
1723
+ "num_experts": config.num_experts,
1724
+ "num_experts_per_tok": config.num_experts_per_tok,
1725
+ "num_shared_experts": config.num_shared_experts,
1726
+ "router_aux_loss_coef": config.router_aux_loss_coef,
1727
+ "shared_intermediate_size": config.shared_intermediate_size,
1728
+ "expert_core_rank": config.expert_core_rank,
1729
+ "shared_expert_intermediate_size": config.shared_expert_intermediate_size,
1730
+ "prelude_coda_intermediate_size": config.prelude_coda_intermediate_size,
1731
+ # MLA
1732
+ "kv_lora_rank": config.kv_lora_rank,
1733
+ "q_lora_rank": config.q_lora_rank,
1734
+ "qk_rope_head_dim": config.qk_rope_head_dim,
1735
+ "qk_nope_head_dim": config.qk_nope_head_dim,
1736
+ "v_head_dim": config.v_head_dim,
1737
+ # RDT
1738
+ "max_loop_iters": config.max_loop_iters,
1739
+ "act_threshold": config.act_threshold,
1740
+ "prelude_layers": config.prelude_layers,
1741
+ "coda_layers": config.coda_layers,
1742
+ "lora_rank": config.lora_rank,
1743
+ # BoundaryPredictor
1744
+ "bp_d_inner": config.bp_d_inner,
1745
+ # Multimodal
1746
+ "vision_hidden_size": config.vision_hidden_size,
1747
+ "audio_hidden_size": config.audio_hidden_size,
1748
+ "vision_num_frames": config.vision_num_frames,
1749
+ "vision_tokens_per_frame": config.vision_tokens_per_frame,
1750
+ "vision_temporal_tokens": config.vision_temporal_tokens,
1751
+ "vision_temporal_layers": config.vision_temporal_layers,
1752
+ }
1753
+ with open(output_dir / "config.json", "w") as f:
1754
+ json.dump(cfg_dict, f, indent=2)
1755
+
1756
+ # Compute SHA256 of model file for integrity check (T-02-03 mitigation)
1757
+ model_file = output_dir / "model.safetensors"
1758
+ if not model_file.exists():
1759
+ model_file = output_dir / "model.pt"
1760
+ if model_file.exists():
1761
+ sha256 = hashlib.sha256()
1762
+ with open(model_file, "rb") as f:
1763
+ for chunk in iter(lambda: f.read(8192), b""):
1764
+ sha256.update(chunk)
1765
+ print(f" Model SHA256: {sha256.hexdigest()[:16]}...")
1766
+ with open(output_dir / "model.sha256", "w") as f:
1767
+ f.write(sha256.hexdigest())
1768
+
1769
+ print(f" Saved to {output_dir}")
1770
+ if model_file.exists():
1771
+ print(f" Model file size: {model_file.stat().st_size / 1e6:.1f} MB")
1772
+
1773
+
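The `model.sha256` file written by `save_spider_model` is only useful if something later compares it against the checkpoint. A minimal sketch of that check follows, using only the standard library; `file_sha256` and `verify_checkpoint` are hypothetical helper names, and the temporary files stand in for a real output directory.

```python
import hashlib
import tempfile
from pathlib import Path


def file_sha256(path: Path, chunk_size: int = 8192) -> str:
    """Hash a file in fixed-size chunks so large checkpoints are never fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(model_file: Path, digest_file: Path) -> bool:
    """Compare the file's current hash with the hex digest stored next to it."""
    expected = digest_file.read_text().strip()
    return file_sha256(model_file) == expected


with tempfile.TemporaryDirectory() as tmp:
    model_file = Path(tmp) / "model.safetensors"
    model_file.write_bytes(b"not a real checkpoint")  # placeholder bytes for the demo
    digest_file = Path(tmp) / "model.sha256"
    digest_file.write_text(file_sha256(model_file))

    ok_before = verify_checkpoint(model_file, digest_file)
    model_file.write_bytes(b"tampered")  # simulate corruption after saving
    ok_after = verify_checkpoint(model_file, digest_file)
```

Running the check before loading weights catches truncated uploads and accidental overwrites; it does not authenticate the file, since anyone able to replace the checkpoint can also replace the digest.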
1774
+ # ============================================================================
1775
+ # CLI Entry Point
1776
+ # ============================================================================
1777
+
1778
+ def main():
1779
+ parser = argparse.ArgumentParser(
1780
+ description="Transfer weights from Qwen3.5-2B to Spider-FLEXITOKENS"
1781
+ )
1782
+ parser.add_argument(
1783
+ "--donor", type=str, default="Qwen/Qwen3.5-2B",
1784
+ help="HuggingFace model ID or local path for donor model"
1785
+ )
1786
+ parser.add_argument(
1787
+ "--output", type=str, default="models/Spider-FLEXITOKENS-init/",
1788
+ help="Output directory for Spider model"
1789
+ )
1790
+ parser.add_argument(
1791
+ "--config", type=str, default="spider_flexitokens_997m",
1792
+ help="Spider model configuration name"
1793
+ )
1794
+ parser.add_argument(
1795
+ "--noise-scale", type=float, default=0.02,
1796
+ help="Noise scale for MoE expert perturbation"
1797
+ )
1798
+ parser.add_argument(
1799
+ "--dry-run", action="store_true",
1800
+ help="Run with dummy donor (no download required)"
1801
+ )
1802
+ args = parser.parse_args()
1803
+
1804
+ # Select config
1805
+ config_map = {
1806
+ "spider_flexitokens_997m": spider_flexitokens_997m(),
1807
+ }
1808
+ spider_config = config_map.get(args.config, spider_flexitokens_997m())
1809
+
1810
+ if args.dry_run:
1811
+ print("DRY RUN: Using dummy donor (no download)")
1812
+ donor = create_dummy_donor(num_layers=10, full_attention_layers=list(range(10)))
1813
+ donor_sd = donor["state_dict"]
1814
+ donor_cfg = donor["config"]
1815
+ else:
1816
+ # Load actual Qwen3.5-2B from HuggingFace
1817
+ print(f"Loading donor model: {args.donor}")
1818
+ try:
1819
+ from transformers import AutoModelForCausalLM, AutoConfig
1820
+ donor_model = AutoModelForCausalLM.from_pretrained(
1821
+ args.donor, torch_dtype=torch.bfloat16, device_map="cpu"
1822
+ )
1823
+ donor_cfg_obj = AutoConfig.from_pretrained(args.donor)
1824
+
1825
+ # Extract full_attention layers from Qwen3.5-2B config
1826
+ # Qwen3.5-2B has hybrid attention: some full, some linear
1827
+ full_attention_layers = getattr(
1828
+ donor_cfg_obj, "full_attention_layers", None
1829
+ )
1830
+ if full_attention_layers is None:
1831
+ # Fallback: collect layers whose per-layer config reports attention_type == "full"
1832
+ # Qwen3.5-2B: 18 linear + 6 full attention in 24 layers
1833
+ full_attention_layers = []
1834
+ for i in range(donor_cfg_obj.num_hidden_layers):
1835
+ layer_cfg = getattr(donor_cfg_obj, f"layer_{i}", None)
1836
+ if layer_cfg and getattr(layer_cfg, "attention_type", "full") == "full":
1837
+ full_attention_layers.append(i)
1838
+ if not full_attention_layers:
1839
+ # If no layer-level info, use known pattern for Qwen3.5-2B
1840
+ full_attention_layers = [3, 7, 11, 15, 19, 23]
1841
+
1842
+ donor_sd = donor_model.state_dict()
1843
+ donor_cfg = {
1844
+ "hidden_size": donor_cfg_obj.hidden_size,
1845
+ "num_attention_heads": donor_cfg_obj.num_attention_heads,
1846
+ "num_key_value_heads": getattr(donor_cfg_obj, "num_key_value_heads", 2),
1847
+ "head_dim": getattr(donor_cfg_obj, "head_dim",
1848
+ donor_cfg_obj.hidden_size // donor_cfg_obj.num_attention_heads),
1849
+ "intermediate_size": donor_cfg_obj.intermediate_size,
1850
+ "vocab_size": donor_cfg_obj.vocab_size,
1851
+ "num_hidden_layers": donor_cfg_obj.num_hidden_layers,
1852
+ "full_attention_layers": full_attention_layers,
1853
+ "model_type": getattr(donor_cfg_obj, "model_type", "qwen3"),
1854
+ }
1855
+ except ImportError:
1856
+ print("Error: transformers library required for loading donor model.")
1857
+ print("Install with: pip install transformers")
1858
+ sys.exit(1)
1859
+ except Exception as e:
1860
+ print(f"Error loading donor model: {e}")
1861
+ print("Use --dry-run for testing without download.")
1862
+ sys.exit(1)
1863
+
1864
+ # Run transfer
1865
+ result = transfer_qwen_to_spider(
1866
+ donor_state_dict=donor_sd,
1867
+ donor_config=donor_cfg,
1868
+ spider_config=spider_config,
1869
+ noise_scale=args.noise_scale,
1870
+ )
1871
+
1872
+ # Save
1873
+ save_spider_model(
1874
+ spider_state_dict=result["spider_state_dict"],
1875
+ config=spider_config,
1876
+ output_dir=Path(args.output),
1877
+ )
1878
+
1879
+ print("\nWeight transfer complete!")
1880
+
1881
+
1882
+ if __name__ == "__main__":
1883
+ main()