| |
| """Weight transfer from Qwen3.5-2B donor to Spider-FLEXITOKENS architecture. |
| |
| Implements the weight transfer pipeline per D-09 and D-10: |
| - Loads Qwen3.5-2B via HF transformers |
| - Filters to full_attention layers only (discards linear_attention) |
| - SVD decomposition converts standard GQA attention to MLA format |
| - Direct copies where shapes match (o_proj, layer norms) |
| - Reinitializes incompatible weights (embeddings, boundary predictor, FFN) |
| - Reports transfer coverage as percentage |
| |
| Usage: |
| python scripts/transfer_weights.py --donor Qwen/Qwen3.5-2B --output models/Spider-FLEXITOKENS-init/ --config spider_flexitokens_997m |
| """ |
|
|
| import argparse |
| import hashlib |
| import json |
| import math |
| import os |
| import sys |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| from spider import ( |
| SENTINEL_TOKENS, |
| is_sentinel_token, |
| create_modality_mask, |
| BoundaryPredictor, |
| downsample, |
| upsample, |
| SpiderConfig as _CanonicalSpiderConfig, |
| spider_flexitokens_997m as _canonical_config_fn, |
| ) |
|
|
| |
# Reverse lookup for SENTINEL_TOKENS: maps each value of that dict back to
# the key it was stored under.
_TOKEN_NAMES_BY_ID = {v: k for k, v in SENTINEL_TOKENS.items()}


# (start, end) sentinel values for the IMG, AUD and VID entries of
# SENTINEL_TOKENS, in that order.
_SENTINEL_PAIRS = [
    (SENTINEL_TOKENS['IMG_START'], SENTINEL_TOKENS['IMG_END']),
    (SENTINEL_TOKENS['AUD_START'], SENTINEL_TOKENS['AUD_END']),
    (SENTINEL_TOKENS['VID_START'], SENTINEL_TOKENS['VID_END']),
]
# Hard-coded id set. Presumably these are the same six ids referenced by
# _SENTINEL_PAIRS -- TODO confirm against spider.SENTINEL_TOKENS; nothing in
# this module keeps the two in sync.
_MODALITY_SENTINEL_IDS = {259, 260, 261, 262, 263, 264}
|
|
|
|
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class SpiderConfig:
    """Spider-FLEXITOKENS model configuration (hidden_size=2048).

    Based on mythos-fineweb-moe.py SpiderPortalConfig with byte-level
    tokenization and MLA attention. Mirrors canonical spider.py config.

    All fields have defaults, so ``SpiderConfig()`` yields the reference
    configuration. ``head_dim`` is derived from the two ``qk_*_head_dim``
    fields, and ``engram_layers`` falls back to ``[1, 4]`` when left as
    ``None`` (see ``__post_init__``).
    """

    # Core transformer dimensions.
    vocab_size: int = 272
    hidden_size: int = 2048
    num_hidden_layers: int = 6
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    intermediate_size: int = 1024
    hidden_act: str = "silu"

    # Mixture-of-experts settings. The size fields below determine the
    # moe.* shapes listed by get_spider_param_shapes.
    num_experts: int = 32
    num_experts_per_tok: int = 2
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05
    shared_intermediate_size: int = 6144
    expert_core_rank: int = 256
    shared_expert_intermediate_size: int = 7424
    prelude_coda_intermediate_size: int = 4096

    # Layer-stack layout and per-layer adapter/halting settings (grouping
    # inferred from field names; semantics live in the model code).
    max_loop_iters: int = 16
    act_threshold: float = 0.5
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 128
    loop_embed_dim: int = 128

    # Low-rank attention dimensions (q/kv compression ranks and per-head
    # sub-dimensions).
    kv_lora_rank: int = 128
    q_lora_rank: int = 256
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 64
    v_head_dim: int = 64

    # Positional encoding, attention window, normalisation and init settings.
    max_position_embeddings: int = 262144
    rope_theta: float = 10000000.0
    # default_factory gives every instance its own dict.
    rope_scaling: Optional[Dict] = field(default_factory=lambda: {
        "type": "yarn",
        "factor": 8.0,
        "original_max_position_embeddings": 32768,
    })
    sliding_window: int = 8192
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02

    # Embedding tying flag.
    tie_word_embeddings: bool = True

    # Model identification / serialisation dtype.
    model_type: str = "spider"
    torch_dtype: str = "bfloat16"

    # Inner width of the boundary predictor (boundary_predictor.* shapes).
    bp_d_inner: int = 8192

    # Engram settings. engram_layers holds recurrent-layer indices; None is
    # replaced by [1, 4] in __post_init__ so instances never share a list.
    engram_layers: list = None
    engram_table_size: int = 8191
    engram_heads: int = 4
    engram_dim: int = 128
    engram_offload: bool = True

    # Vision / audio settings (not referenced by the visible transfer code).
    vision_hidden_size: int = 2048
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2

    @property
    def head_dim(self) -> int:
        """Per-head query/key width: ``qk_nope_head_dim + qk_rope_head_dim``."""
        return self.qk_nope_head_dim + self.qk_rope_head_dim

    def __post_init__(self) -> None:
        """Fill in the default ``engram_layers`` when none was supplied."""
        if self.engram_layers is None:
            self.engram_layers = [1, 4]
|
|
|
|
def spider_flexitokens_997m() -> SpiderConfig:
    """Build the Spider-FLEXITOKENS 997M configuration.

    Returns:
        A fresh ``SpiderConfig`` with every field left at its declared default.
    """
    default_config = SpiderConfig()
    return default_config
|
|
|
|
| |
| |
| |
|
|
def create_dummy_donor(num_layers: int = 4, full_attention_layers: Optional[List[int]] = None, mini: bool = False):
    """Build a random stand-in for a Qwen3.5-2B checkpoint.

    The full-size layout uses hidden_size=2048, 8 query heads, 2 KV heads,
    head_dim=256, intermediate_size=6144 and vocab_size=248320. With
    ``mini=True`` the hidden size, head counts and intermediate size are
    divided by 8 (head counts floored at 1), the vocabulary is capped at
    1024 rows, and head_dim stays at 256.

    Args:
        num_layers: Number of decoder layers to create.
        full_attention_layers: Indices recorded in the config as
            full-attention layers; defaults to every layer index.
        mini: If True, use the shrunken tensor sizes for fast testing.

    Returns:
        Dict with a "state_dict" entry (parameter name -> tensor) and a
        "config" entry (plain dict describing the sizes actually used).
    """
    # Reference (full-size) dimensions; head_dim is never shrunk.
    ref_hidden, ref_heads, ref_kv_heads = 2048, 8, 2
    ref_inter, ref_vocab = 6144, 248320
    per_head = 256

    if full_attention_layers is None:
        full_attention_layers = list(range(num_layers))

    if mini:
        shrink = 8
        vocab = min(ref_vocab, 1024)
    else:
        shrink = 1
        vocab = ref_vocab
    hidden = ref_hidden // shrink
    heads = max(ref_heads // shrink, 1)
    kv_heads = max(ref_kv_heads // shrink, 1)
    inter = ref_inter // shrink

    def _noise(rows, cols):
        # Small Gaussian weights, std 0.02.
        return torch.randn(rows, cols) * 0.02

    def _unit(size):
        # Norm weights start at one.
        return torch.ones(size, dtype=torch.float32)

    # Dict literals evaluate their values in order, so the random draws
    # happen in the same sequence as the key order shown here.
    tensors = {"model.embed_tokens.weight": _noise(vocab, hidden)}

    for idx in range(num_layers):
        layer = f"model.layers.{idx}"
        tensors.update({
            f"{layer}.self_attn.q_proj.weight": _noise(heads * per_head, hidden),
            f"{layer}.self_attn.k_proj.weight": _noise(kv_heads * per_head, hidden),
            f"{layer}.self_attn.v_proj.weight": _noise(kv_heads * per_head, hidden),
            f"{layer}.self_attn.o_proj.weight": _noise(hidden, hidden),
            f"{layer}.input_layernorm.weight": _unit(hidden),
            f"{layer}.post_attention_layernorm.weight": _unit(hidden),
            f"{layer}.mlp.gate_proj.weight": _noise(inter, hidden),
            f"{layer}.mlp.up_proj.weight": _noise(inter, hidden),
            f"{layer}.mlp.down_proj.weight": _noise(hidden, inter),
        })

    tensors["model.norm.weight"] = _unit(hidden)
    tensors["lm_head.weight"] = _noise(vocab, hidden)

    donor_config = {
        "hidden_size": hidden,
        "num_attention_heads": heads,
        "num_key_value_heads": kv_heads,
        "head_dim": per_head,
        "intermediate_size": inter,
        "vocab_size": vocab,
        "num_hidden_layers": num_layers,
        "full_attention_layers": full_attention_layers,
        "model_type": "qwen3",
        "mini": mini,
    }

    return {"state_dict": tensors, "config": donor_config}
|
|
|
|
| |
| |
| |
|
|
def decompose_attention_svd(
    weight: torch.Tensor,
    lora_rank: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Factor a 2D weight into a rank-limited product via truncated SVD.

    Per D-10: Decompression (b_proj) matrices initialized from SVD;
    compression (a_proj) matrices are reinitialized with Kaiming init by
    the caller.

    For a ``weight`` of shape ``[m, n]`` the result satisfies
    ``a_proj @ b_proj ~= weight``, keeping the ``k`` largest singular
    values, where ``k = min(lora_rank, m, n)``. The axis meaning
    (in/out features) is whatever the caller passed in: this function does
    not transpose. A PyTorch ``Linear`` weight is stored as
    ``[out_features, in_features]``.

    Args:
        weight: 2D tensor of any floating dtype; it is promoted to float32
            before the decomposition.
        lora_rank: Requested rank. Values larger than ``min(m, n)`` are
            clamped to that bound, so the factors may be narrower than
            requested.

    Returns:
        Tuple ``(a_proj, b_proj)`` of float32 tensors:
            - a_proj: ``[m, k]`` -- left singular vectors scaled by the
              singular values.
            - b_proj: ``[k, n]`` -- leading rows of ``Vh``.

    Raises:
        ValueError: If ``weight`` is not 2D, or ``lora_rank`` is negative
            (a negative bound would otherwise slice from the end and
            silently return factors of the wrong rank).
    """
    if weight.dim() != 2:
        raise ValueError(f"Expected 2D weight, got {weight.dim()}D")
    if lora_rank < 0:
        raise ValueError(f"lora_rank must be non-negative, got {lora_rank}")

    # Half-precision inputs are promoted: run the SVD in float32.
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)

    # Reduced SVD yields min(m, n) singular values; make the clamp explicit.
    rank = min(lora_rank, S.shape[0])

    # Scaling column j of U by S[j] equals U @ diag(S) without building
    # the dense k x k diagonal matrix.
    a_proj = (U[:, :rank] * S[:rank]).contiguous()
    b_proj = Vh[:rank, :]

    return a_proj, b_proj
|
|
|
|
| |
| |
| |
|
|
def split_dense_to_moe(
    spider_state_dict: Dict[str, torch.Tensor],
    config: SpiderConfig,
    noise_scale: float = 0.02,
) -> Dict[str, torch.Tensor]:
    """Fill in missing SharedProjectionMoE expert cores and router per D-20/D-21.

    Per D-21: W_gate and W_transform receive small normal noise
    (``noise_scale``, default std 0.02) to break symmetry. The router weight
    is drawn with std ``config.initializer_range`` and the router bias is
    zero. Entries already present in the state dict are left untouched;
    shared_up, shared_down and shared_expert are expected to have been
    populated by transfer_qwen_to_spider.

    Args:
        spider_state_dict: Spider model state dict (mutated in-place).
        config: Spider model config.
        noise_scale: Noise std for the expert core tensors.

    Returns:
        The same state dict object, now holding the MoE entries for every
        recurrent layer.
    """
    # (key suffix, factory) pairs in draw order. Factories run lazily, so an
    # existing entry is neither overwritten nor consumes random numbers.
    factories = (
        (
            "W_gate",
            lambda: torch.randn(
                config.num_experts, config.hidden_size, config.expert_core_rank
            ) * noise_scale,
        ),
        (
            "W_transform",
            lambda: torch.randn(
                config.num_experts, config.expert_core_rank, config.shared_intermediate_size
            ) * noise_scale,
        ),
        (
            "router.weight",
            lambda: torch.randn(config.num_experts, config.hidden_size)
            * config.initializer_range,
        ),
        (
            "router.bias",
            lambda: torch.zeros(config.num_experts, dtype=torch.float32),
        ),
    )

    for layer in range(config.num_hidden_layers):
        moe_prefix = f"model.recurrent_layers.{layer}.moe"
        for suffix, build in factories:
            name = f"{moe_prefix}.{suffix}"
            if name in spider_state_dict:
                continue
            spider_state_dict[name] = build()

    return spider_state_dict
|
|
|
|
| |
| |
| |
|
|
def get_spider_param_shapes(config: SpiderConfig) -> Dict[str, Tuple[int, ...]]:
    """List the expected shape of every Spider parameter.

    Used to validate that all shapes match after weight transfer. Entries
    are inserted in a fixed order: top-level modules, prelude layers, coda
    layers, recurrent layers, then the model-level injection, final norm
    and halting predictor.

    Args:
        config: Spider model config supplying all dimensions.

    Returns:
        Dict mapping parameter name to its shape tuple.
    """
    hidden = config.hidden_size
    vocab = config.vocab_size
    bp_inner = config.bp_d_inner

    # Top-level modules (no "model." prefix on these keys).
    shapes = {
        "embed_tokens.weight": (vocab, hidden),
        "lm_head.weight": (vocab, hidden),
        "boundary_predictor.0.weight": (bp_inner, hidden),
        "boundary_predictor.0.bias": (bp_inner,),
        "boundary_predictor.2.weight": (1, bp_inner),
        "boundary_predictor.2.bias": (1,),
        "null_group.weight": (hidden,),
        "down_ln.weight": (hidden,),
        "down_ln.bias": (hidden,),
    }

    head_dim = config.head_dim

    def _attention_and_norms(prefix):
        # Attention projections plus the two layer norms; identical for
        # prelude, coda and recurrent layers.
        heads = config.num_attention_heads
        q_rank = config.q_lora_rank
        kv_rank = config.kv_lora_rank
        attn = f"{prefix}.self_attn"
        return {
            f"{attn}.q_a_proj.weight": (q_rank, hidden),
            f"{attn}.q_a_layernorm.weight": (q_rank,),
            f"{attn}.q_b_proj.weight": (heads * head_dim, q_rank),
            f"{attn}.kv_a_proj_with_mqa.weight": (kv_rank + config.qk_rope_head_dim, hidden),
            f"{attn}.kv_a_layernorm.weight": (kv_rank,),
            f"{attn}.kv_b_proj.weight": (
                heads * (config.qk_nope_head_dim + config.v_head_dim),
                kv_rank,
            ),
            f"{attn}.o_proj.weight": (hidden, heads * config.v_head_dim),
            f"{prefix}.input_layernorm.weight": (hidden,),
            f"{prefix}.post_attention_layernorm.weight": (hidden,),
        }

    # Dense sections: prelude first, then coda.
    dense_sections = (
        ("prelude_layers", config.prelude_layers),
        ("coda_layers", config.coda_layers),
    )
    for section, layer_count in dense_sections:
        for i in range(layer_count):
            prefix = f"model.{section}.{i}"
            shapes.update(_attention_and_norms(prefix))

            ffn_inter = config.prelude_coda_intermediate_size
            shapes.update({
                f"{prefix}.ffn.gate_proj.weight": (ffn_inter, hidden),
                f"{prefix}.ffn.up_proj.weight": (ffn_inter, hidden),
                f"{prefix}.ffn.down_proj.weight": (hidden, ffn_inter),
            })

    # Recurrent layers: attention + MoE + LoRA adapter + halting (+ engram).
    for i in range(config.num_hidden_layers):
        prefix = f"model.recurrent_layers.{i}"
        shapes.update(_attention_and_norms(prefix))

        experts = config.num_experts
        core_rank = config.expert_core_rank
        shared_inter = config.shared_intermediate_size
        se_inter = config.shared_expert_intermediate_size
        lora_rank = config.lora_rank
        moe = f"{prefix}.moe"
        shapes.update({
            f"{moe}.shared_up.weight": (shared_inter, hidden),
            f"{moe}.shared_down.weight": (hidden, shared_inter),
            f"{moe}.W_gate": (experts, hidden, core_rank),
            f"{moe}.W_transform": (experts, core_rank, shared_inter),
            f"{moe}.shared_expert.gate_proj.weight": (se_inter, hidden),
            f"{moe}.shared_expert.up_proj.weight": (se_inter, hidden),
            f"{moe}.shared_expert.down_proj.weight": (hidden, se_inter),
            f"{moe}.router.weight": (experts, hidden),
            f"{moe}.router.bias": (experts,),
            f"{prefix}.lora_adapter.down.weight": (lora_rank, hidden),
            f"{prefix}.lora_adapter.B": (lora_rank, hidden),
            f"{prefix}.lora_adapter.scale.weight": (config.max_loop_iters, lora_rank),
            f"{prefix}.act_halting.halt_predictor.weight": (1, hidden),
            f"{prefix}.act_halting.halt_predictor.bias": (1,),
        })

        if i not in config.engram_layers:
            continue

        # Engram parameters exist only on the configured recurrent layers.
        mem_dim = config.engram_heads * config.engram_dim
        engram = f"{prefix}.engram"
        shapes.update({
            f"{engram}.W_k.weight": (hidden, mem_dim * 2),
            f"{engram}.W_v.weight": (hidden, mem_dim * 2),
            f"{engram}.conv.weight": (hidden, 1, 4),
            f"{engram}.conv.bias": (hidden,),
            f"{engram}.q_norm.weight": (hidden,),
            f"{engram}.k_norm.weight": (hidden,),
            f"{engram}.embed": (
                2,
                config.engram_heads,
                config.engram_table_size,
                config.engram_dim,
            ),
            f"{engram}.hash_seeds": (config.engram_heads * 2,),
            f"{prefix}.post_engram_layernorm.weight": (hidden,),
        })

    # Model-level modules.
    shapes.update({
        "model.injection.log_A": (hidden,),
        "model.injection.delta_t": (),
        "model.injection.B.weight": (hidden, hidden),
        "model.norm.weight": (hidden,),
        "model.act_halting.halt_predictor.weight": (1, hidden),
        "model.act_halting.halt_predictor.bias": (1,),
    })

    return shapes
|
|
|
|
| |
| |
| |
|
|
| def _adapt_weight(weight, target_out, target_in): |
| """Adapt a donor weight matrix to Spider dimensions via padding/cropping. |
| |
| When donor hidden_size differs from Spider's (e.g., in mini test mode), |
| we pad or crop the weight matrix to match target dimensions. |
| |
| Args: |
| weight: [out_features, in_features] weight tensor from donor |
| target_out: Target output dimension |
| target_in: Target input dimension |
| |
| Returns: |
| Adapted weight tensor of shape [target_out, target_in] |
| """ |
| out_dim, in_dim = weight.shape |
|
|
| |
| adapted = torch.empty(target_out, target_in) |
| nn.init.kaiming_uniform_(adapted, a=math.sqrt(5)) |
|
|
| |
| copy_out = min(out_dim, target_out) |
| copy_in = min(in_dim, target_in) |
| adapted[:copy_out, :copy_in] = weight[:copy_out, :copy_in] |
|
|
| return adapted |
|
|
|
|
| |
| |
| |
|
|
| def transfer_qwen_to_spider( |
| donor_state_dict: Dict[str, torch.Tensor], |
| donor_config: Dict, |
| spider_config: SpiderConfig, |
| noise_scale: float = 0.02, |
| ) -> Dict: |
| """Transfer weights from Qwen3.5-2B donor to Spider-FLEXITOKENS architecture. |
| |
| Per D-09: Qwen3.5-2B is the weight donor. Per D-10: SVD decomposition |
| converts standard GQA attention to MLA format. |
| |
| Transfer rules: |
| - o_proj [2048, 2048]: direct copy from donor |
| - q_proj → SVD → q_b_proj (q_a_proj reinitialized with Kaiming) |
| - k_proj + v_proj → SVD → kv_b_proj (kv_a_proj reinitialized with Kaiming) |
| - Layer norms [2048]: direct copy |
| - Embeddings: REINIT [272, 2048] (byte-level) |
| - BoundaryPredictor: REINIT (no pre-trained source) |
| - FFN: REINIT (intermediate_size mismatch 6144 vs 1024) |
| - LoRA, ACT, LTI: REINIT (Spider-specific modules) |
| |
| Args: |
| donor_state_dict: Qwen3.5-2B state dict |
| donor_config: Donor model config dict |
| spider_config: Spider model config |
| noise_scale: Noise scale for MoE expert perturbation |
| |
| Returns: |
| Dict with "spider_state_dict", "transfer_coverage", "layer_mapping" |
| """ |
| hidden_size = spider_config.hidden_size |
| q_lora_rank = spider_config.q_lora_rank |
| kv_lora_rank = spider_config.kv_lora_rank |
| num_heads = spider_config.num_attention_heads |
| head_dim = spider_config.head_dim |
| qk_nope_head_dim = spider_config.qk_nope_head_dim |
| qk_rope_head_dim = spider_config.qk_rope_head_dim |
| v_head_dim = spider_config.v_head_dim |
|
|
| |
| donor_hidden_size = donor_config.get("hidden_size", hidden_size) |
| donor_num_heads = donor_config.get("num_attention_heads", 8) |
| donor_num_kv_heads = donor_config.get("num_key_value_heads", 2) |
| donor_head_dim = donor_config.get("head_dim", 256) |
| donor_intermediate_size = donor_config.get("intermediate_size", 6144) |
|
|
| |
| donor_param_count = 0 |
| reinit_param_count = 0 |
| donor_params = set() |
| reinit_params = set() |
|
|
| spider_sd = {} |
|
|
| |
| full_attention_layers = donor_config.get("full_attention_layers", []) |
| num_donor_layers = donor_config.get("num_hidden_layers", 24) |
|
|
| |
| |
| |
| available_fa = list(full_attention_layers) |
|
|
| |
| layer_mapping = {} |
| required_layers = ( |
| spider_config.prelude_layers |
| + spider_config.num_hidden_layers |
| + spider_config.coda_layers |
| ) |
|
|
| |
| donor_pool = list(available_fa) |
| if len(donor_pool) < required_layers: |
| |
| all_layers = list(range(num_donor_layers)) |
| for l in all_layers: |
| if l not in donor_pool: |
| donor_pool.append(l) |
|
|
| for i in range(required_layers): |
| if i < len(donor_pool): |
| layer_mapping[i] = donor_pool[i] |
| else: |
| layer_mapping[i] = None |
|
|
| def _kaiming_init(shape): |
| """Kaiming uniform initialization for new parameters.""" |
| tensor = torch.empty(shape) |
| nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) |
| return tensor |
|
|
| def _zeros_init(shape): |
| """Zero initialization.""" |
| return torch.zeros(shape, dtype=torch.float32) |
|
|
| def _ones_init(shape): |
| """Ones initialization for layer norm weights.""" |
| return torch.ones(shape, dtype=torch.float32) |
|
|
| |
| embed_weight = _kaiming_init((spider_config.vocab_size, hidden_size)) |
| spider_sd["embed_tokens.weight"] = embed_weight |
| reinit_param_count += embed_weight.numel() |
| reinit_params.add("embed_tokens.weight") |
|
|
| lm_head_weight = _kaiming_init((spider_config.vocab_size, hidden_size)) |
| spider_sd["lm_head.weight"] = lm_head_weight |
| reinit_param_count += lm_head_weight.numel() |
| reinit_params.add("lm_head.weight") |
|
|
| |
| bp_0_weight = _kaiming_init((spider_config.bp_d_inner, hidden_size)) |
| bp_0_bias = _zeros_init((spider_config.bp_d_inner,)) |
| bp_2_weight = _kaiming_init((1, spider_config.bp_d_inner)) |
| bp_2_bias = _zeros_init((1,)) |
| spider_sd["boundary_predictor.0.weight"] = bp_0_weight |
| spider_sd["boundary_predictor.0.bias"] = bp_0_bias |
| spider_sd["boundary_predictor.2.weight"] = bp_2_weight |
| spider_sd["boundary_predictor.2.bias"] = bp_2_bias |
| reinit_param_count += bp_0_weight.numel() + bp_0_bias.numel() |
| reinit_param_count += bp_2_weight.numel() + bp_2_bias.numel() |
| reinit_params.add("boundary_predictor.0.weight") |
| reinit_params.add("boundary_predictor.2.weight") |
|
|
| |
| null_group = _zeros_init((hidden_size,)) |
| spider_sd["null_group.weight"] = null_group |
| reinit_param_count += null_group.numel() |
| reinit_params.add("null_group.weight") |
|
|
| down_ln_w = torch.ones(hidden_size, dtype=torch.float32) |
| down_ln_b = _zeros_init((hidden_size,)) |
| spider_sd["down_ln.weight"] = down_ln_w |
| spider_sd["down_ln.bias"] = down_ln_b |
| reinit_param_count += down_ln_w.numel() + down_ln_b.numel() |
| reinit_params.add("down_ln.weight") |
|
|
| |
| for section_name, num_layers in [ |
| ("prelude_layers", spider_config.prelude_layers), |
| ("recurrent_layers", spider_config.num_hidden_layers), |
| ("coda_layers", spider_config.coda_layers), |
| ]: |
| is_recurrent = section_name == "recurrent_layers" |
|
|
| for layer_idx in range(num_layers): |
| |
| |
| |
| spider_layer_idx = ({ |
| "prelude_layers": 0, |
| "recurrent_layers": spider_config.prelude_layers, |
| "coda_layers": spider_config.prelude_layers + spider_config.num_hidden_layers, |
| }[section_name] + layer_idx) |
| donor_layer_idx = layer_mapping.get(spider_layer_idx) |
|
|
| prefix = f"model.{section_name}.{layer_idx}" |
|
|
| if donor_layer_idx is not None: |
| donor_prefix = f"model.layers.{donor_layer_idx}" |
| else: |
| donor_prefix = None |
|
|
| |
| |
| if donor_prefix is not None: |
| donor_q_key = f"{donor_prefix}.self_attn.q_proj.weight" |
| donor_q = donor_state_dict.get(donor_q_key) |
| else: |
| donor_q = None |
|
|
| if donor_q is not None and donor_q.shape[0] == donor_num_heads * donor_head_dim and donor_q.shape[1] == donor_hidden_size: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| effective_q = donor_q |
| if donor_hidden_size != hidden_size: |
| |
| effective_q = _adapt_weight(donor_q, donor_num_heads * donor_head_dim, hidden_size) |
|
|
| q_a_svd, q_b_svd = decompose_attention_svd(effective_q, q_lora_rank) |
| |
| q_a_proj = _kaiming_init((q_lora_rank, hidden_size)) |
| q_b_proj = q_b_svd.T |
| donor_param_count += q_b_proj.numel() |
| reinit_param_count += q_a_proj.numel() |
| donor_params.add(f"{prefix}.self_attn.q_b_proj.weight") |
| reinit_params.add(f"{prefix}.self_attn.q_a_proj.weight") |
| else: |
| q_a_proj = _kaiming_init((q_lora_rank, hidden_size)) |
| q_b_proj = _kaiming_init((num_heads * head_dim, q_lora_rank)) |
| reinit_param_count += q_a_proj.numel() + q_b_proj.numel() |
| reinit_params.add(f"{prefix}.self_attn.q_a_proj.weight") |
| reinit_params.add(f"{prefix}.self_attn.q_b_proj.weight") |
|
|
| spider_sd[f"{prefix}.self_attn.q_a_proj.weight"] = q_a_proj |
| spider_sd[f"{prefix}.self_attn.q_b_proj.weight"] = q_b_proj |
|
|
| |
| q_a_ln = torch.ones(q_lora_rank, dtype=torch.float32) |
| spider_sd[f"{prefix}.self_attn.q_a_layernorm.weight"] = q_a_ln |
| reinit_param_count += q_a_ln.numel() |
| reinit_params.add(f"{prefix}.self_attn.q_a_layernorm.weight") |
|
|
| |
| if donor_prefix is not None: |
| donor_k_key = f"{donor_prefix}.self_attn.k_proj.weight" |
| donor_v_key = f"{donor_prefix}.self_attn.v_proj.weight" |
| donor_k = donor_state_dict.get(donor_k_key) |
| donor_v = donor_state_dict.get(donor_v_key) |
| else: |
| donor_k = None |
| donor_v = None |
|
|
| if donor_k is not None and donor_v is not None: |
| |
| |
| |
| |
| combined_kv = torch.cat([donor_k, donor_v], dim=0) |
|
|
| |
| if donor_hidden_size != hidden_size: |
| combined_kv_out = donor_num_kv_heads * donor_head_dim * 2 |
| combined_kv = _adapt_weight(combined_kv, combined_kv_out, hidden_size) |
|
|
| |
| kv_a_svd, kv_b_svd = decompose_attention_svd(combined_kv.T, kv_lora_rank) |
| |
| |
| |
|
|
| |
| |
| kv_a_with_mqa = _kaiming_init( |
| (kv_lora_rank + qk_rope_head_dim, hidden_size) |
| ) |
|
|
| |
| |
| |
| |
| kv_b_proj_weight = _kaiming_init( |
| (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) |
| ) |
| svd_contribution = kv_b_svd.T |
| |
| rows_to_copy = min(svd_contribution.shape[0], kv_b_proj_weight.shape[0]) |
| kv_b_proj_weight[:rows_to_copy, :] = svd_contribution[:rows_to_copy] |
|
|
| |
| donor_rows = rows_to_copy |
| reinit_rows = kv_b_proj_weight.shape[0] - donor_rows |
| donor_param_count += donor_rows * kv_b_proj_weight.shape[1] |
| reinit_param_count += reinit_rows * kv_b_proj_weight.shape[1] |
|
|
| reinit_param_count += kv_a_with_mqa.numel() |
| donor_params.add(f"{prefix}.self_attn.kv_b_proj.weight") |
| reinit_params.add(f"{prefix}.self_attn.kv_a_proj_with_mqa.weight") |
| else: |
| kv_a_with_mqa = _kaiming_init( |
| (kv_lora_rank + qk_rope_head_dim, hidden_size) |
| ) |
| kv_b_proj_weight = _kaiming_init( |
| (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) |
| ) |
| reinit_param_count += kv_a_with_mqa.numel() + kv_b_proj_weight.numel() |
| reinit_params.add(f"{prefix}.self_attn.kv_a_proj_with_mqa.weight") |
| reinit_params.add(f"{prefix}.self_attn.kv_b_proj.weight") |
|
|
| spider_sd[f"{prefix}.self_attn.kv_a_proj_with_mqa.weight"] = kv_a_with_mqa |
| spider_sd[f"{prefix}.self_attn.kv_b_proj.weight"] = kv_b_proj_weight |
|
|
| |
| kv_a_ln = torch.ones(kv_lora_rank, dtype=torch.float32) |
| spider_sd[f"{prefix}.self_attn.kv_a_layernorm.weight"] = kv_a_ln |
| reinit_param_count += kv_a_ln.numel() |
| reinit_params.add(f"{prefix}.self_attn.kv_a_layernorm.weight") |
|
|
| |
| |
| |
| if donor_prefix is not None: |
| donor_o_key = f"{donor_prefix}.self_attn.o_proj.weight" |
| donor_o = donor_state_dict.get(donor_o_key) |
| else: |
| donor_o = None |
|
|
| o_proj_shape = (hidden_size, num_heads * v_head_dim) |
| o_proj = _kaiming_init(o_proj_shape) |
| if donor_o is not None: |
| |
| rows_to_copy = min(donor_o.shape[0], o_proj.shape[0]) |
| cols_to_copy = min(donor_o.shape[1], o_proj.shape[1]) |
| o_proj[:rows_to_copy, :cols_to_copy] = donor_o[:rows_to_copy, :cols_to_copy] |
| donor_param_count += rows_to_copy * cols_to_copy |
| remaining = o_proj.numel() - rows_to_copy * cols_to_copy |
| if remaining > 0: |
| reinit_param_count += remaining |
| donor_params.add(f"{prefix}.self_attn.o_proj.weight") |
| else: |
| reinit_param_count += o_proj.numel() |
| reinit_params.add(f"{prefix}.self_attn.o_proj.weight") |
| spider_sd[f"{prefix}.self_attn.o_proj.weight"] = o_proj |
|
|
| |
| for norm_name in ["input_layernorm.weight", "post_attention_layernorm.weight"]: |
| if donor_prefix is not None: |
| donor_norm_key = f"{donor_prefix}.{norm_name}" |
| donor_norm = donor_state_dict.get(donor_norm_key) |
| else: |
| donor_norm = None |
|
|
| if donor_norm is not None and donor_norm.shape == (hidden_size,): |
| spider_sd[f"{prefix}.{norm_name}"] = donor_norm.clone() |
| donor_param_count += donor_norm.numel() |
| donor_params.add(f"{prefix}.{norm_name}") |
| elif donor_norm is not None and donor_norm.shape[0] != hidden_size: |
| |
| adapted_norm = torch.ones(hidden_size, dtype=torch.float32) |
| copy_size = min(donor_norm.shape[0], hidden_size) |
| adapted_norm[:copy_size] = donor_norm[:copy_size] |
| spider_sd[f"{prefix}.{norm_name}"] = adapted_norm |
| donor_param_count += copy_size |
| reinit_param_count += hidden_size - copy_size |
| donor_params.add(f"{prefix}.{norm_name}") |
| else: |
| ln = torch.ones(hidden_size, dtype=torch.float32) |
| spider_sd[f"{prefix}.{norm_name}"] = ln |
| reinit_param_count += ln.numel() |
| reinit_params.add(f"{prefix}.{norm_name}") |
|
|
| |
| if is_recurrent: |
| |
| |
| |
| |
| |
|
|
| |
| |
| if donor_layer_idx is not None: |
| qwen_layer_for_ffn = donor_layer_idx |
| else: |
| qwen_layer_for_ffn = None |
|
|
| |
| |
| |
| shared_up_key = f"{prefix}.moe.shared_up.weight" |
| shared_up_shape = (spider_config.shared_intermediate_size, hidden_size) |
| if qwen_layer_for_ffn is not None: |
| donor_up_key = f"model.layers.{qwen_layer_for_ffn}.mlp.up_proj.weight" |
| donor_up = donor_state_dict.get(donor_up_key) |
| else: |
| donor_up = None |
|
|
| if donor_up is not None and donor_up.shape == shared_up_shape: |
| spider_sd[shared_up_key] = donor_up.clone().float() |
| donor_param_count += donor_up.numel() |
| donor_params.add(shared_up_key) |
| elif donor_up is not None: |
| shared_up_w = _kaiming_init(shared_up_shape) |
| rows_copy = min(donor_up.shape[0], shared_up_shape[0]) |
| cols_copy = min(donor_up.shape[1], shared_up_shape[1]) |
| shared_up_w[:rows_copy, :cols_copy] = donor_up[:rows_copy, :cols_copy].float() |
| spider_sd[shared_up_key] = shared_up_w |
| donor_param_count += rows_copy * cols_copy |
| reinit_param_count += shared_up_w.numel() - rows_copy * cols_copy |
| donor_params.add(shared_up_key) |
| else: |
| spider_sd[shared_up_key] = _kaiming_init(shared_up_shape) |
| reinit_param_count += shared_up_shape[0] * shared_up_shape[1] |
| reinit_params.add(shared_up_key) |
|
|
| |
| |
| |
| shared_down_key = f"{prefix}.moe.shared_down.weight" |
| shared_down_shape = (hidden_size, spider_config.shared_intermediate_size) |
| if qwen_layer_for_ffn is not None: |
| donor_down_key = f"model.layers.{qwen_layer_for_ffn}.mlp.down_proj.weight" |
| donor_down = donor_state_dict.get(donor_down_key) |
| else: |
| donor_down = None |
|
|
| if donor_down is not None and donor_down.shape == shared_down_shape: |
| spider_sd[shared_down_key] = donor_down.clone().float() |
| donor_param_count += donor_down.numel() |
| donor_params.add(shared_down_key) |
| elif donor_down is not None: |
| shared_down_w = _kaiming_init(shared_down_shape) |
| rows_copy = min(donor_down.shape[0], shared_down_shape[0]) |
| cols_copy = min(donor_down.shape[1], shared_down_shape[1]) |
| shared_down_w[:rows_copy, :cols_copy] = donor_down[:rows_copy, :cols_copy].float() |
| spider_sd[shared_down_key] = shared_down_w |
| donor_param_count += rows_copy * cols_copy |
| reinit_param_count += shared_down_w.numel() - rows_copy * cols_copy |
| donor_params.add(shared_down_key) |
| else: |
| spider_sd[shared_down_key] = _kaiming_init(shared_down_shape) |
| reinit_param_count += shared_down_shape[0] * shared_down_shape[1] |
| reinit_params.add(shared_down_key) |
|
|
| |
| |
| |
| shared_expert_inter = spider_config.shared_expert_intermediate_size |
| if qwen_layer_for_ffn is not None: |
| donor_gate_key = f"model.layers.{qwen_layer_for_ffn}.mlp.gate_proj.weight" |
| donor_se_up_key = f"model.layers.{qwen_layer_for_ffn}.mlp.up_proj.weight" |
| donor_se_down_key = f"model.layers.{qwen_layer_for_ffn}.mlp.down_proj.weight" |
| donor_se_gate = donor_state_dict.get(donor_gate_key) |
| donor_se_up = donor_state_dict.get(donor_se_up_key) |
| donor_se_down = donor_state_dict.get(donor_se_down_key) |
| else: |
| donor_se_gate = donor_se_up = donor_se_down = None |
|
|
| for proj_name, spider_shape in [ |
| ("gate_proj", (shared_expert_inter, hidden_size)), |
| ("up_proj", (shared_expert_inter, hidden_size)), |
| ("down_proj", (hidden_size, shared_expert_inter)), |
| ]: |
| key = f"{prefix}.moe.shared_expert.{proj_name}.weight" |
| w = _kaiming_init(spider_shape) |
|
|
| if proj_name in ("gate_proj", "up_proj"): |
| donor_src = donor_se_gate if proj_name == "gate_proj" else donor_se_up |
| if donor_src is not None: |
| rows_copy = min(donor_src.shape[0], spider_shape[0]) |
| cols_copy = min(donor_src.shape[1], spider_shape[1]) |
| w[:rows_copy, :cols_copy] = donor_src[:rows_copy, :cols_copy].float() |
| donor_param_count += rows_copy * cols_copy |
| reinit_param_count += w.numel() - rows_copy * cols_copy |
| donor_params.add(key) |
| else: |
| reinit_param_count += w.numel() |
| reinit_params.add(key) |
| else: |
| if donor_se_down is not None: |
| rows_copy = min(donor_se_down.shape[0], spider_shape[0]) |
| cols_copy = min(donor_se_down.shape[1], spider_shape[1]) |
| w[:rows_copy, :cols_copy] = donor_se_down[:rows_copy, :cols_copy].float() |
| donor_param_count += rows_copy * cols_copy |
| reinit_param_count += w.numel() - rows_copy * cols_copy |
| donor_params.add(key) |
| else: |
| reinit_param_count += w.numel() |
| reinit_params.add(key) |
|
|
| spider_sd[key] = w |
|
|
| |
|
|
| |
| lora_down = _kaiming_init((spider_config.lora_rank, hidden_size)) |
| lora_B = torch.zeros(spider_config.lora_rank, hidden_size, dtype=torch.float32) |
| lora_scale = torch.zeros(spider_config.max_loop_iters, spider_config.lora_rank, dtype=torch.float32) |
| spider_sd[f"{prefix}.lora_adapter.down.weight"] = lora_down |
| spider_sd[f"{prefix}.lora_adapter.B"] = lora_B |
| spider_sd[f"{prefix}.lora_adapter.scale.weight"] = lora_scale |
| reinit_param_count += lora_down.numel() + lora_B.numel() + lora_scale.numel() |
| reinit_params.add(f"{prefix}.lora_adapter.down.weight") |
|
|
| |
| halt_w = _kaiming_init((1, hidden_size)) |
| halt_b = _zeros_init((1,)) |
| spider_sd[f"{prefix}.act_halting.halt_predictor.weight"] = halt_w |
| spider_sd[f"{prefix}.act_halting.halt_predictor.bias"] = halt_b |
| reinit_param_count += halt_w.numel() + halt_b.numel() |
| reinit_params.add(f"{prefix}.act_halting.halt_predictor.weight") |
|
|
| |
| if layer_idx in spider_config.engram_layers: |
| engram_mem_dim = spider_config.engram_heads * spider_config.engram_dim |
| engram_W_k = _kaiming_init((hidden_size, engram_mem_dim * 2)) |
| engram_W_v = _kaiming_init((hidden_size, engram_mem_dim * 2)) |
| engram_conv_w = _kaiming_init((hidden_size, 1, 4)) |
| engram_conv_b = _zeros_init((hidden_size,)) |
| engram_q_norm = _ones_init((hidden_size,)) |
| engram_k_norm = _ones_init((hidden_size,)) |
| engram_embed = torch.zeros( |
| 2, spider_config.engram_heads, spider_config.engram_table_size, spider_config.engram_dim |
| ) |
| engram_hash = torch.arange(spider_config.engram_heads * 2, dtype=torch.float32) |
| post_engram_norm = _ones_init((hidden_size,)) |
|
|
| spider_sd[f"{prefix}.engram.W_k.weight"] = engram_W_k |
| spider_sd[f"{prefix}.engram.W_v.weight"] = engram_W_v |
| spider_sd[f"{prefix}.engram.conv.weight"] = engram_conv_w |
| spider_sd[f"{prefix}.engram.conv.bias"] = engram_conv_b |
| spider_sd[f"{prefix}.engram.q_norm.weight"] = engram_q_norm |
| spider_sd[f"{prefix}.engram.k_norm.weight"] = engram_k_norm |
| spider_sd[f"{prefix}.engram.embed"] = engram_embed |
| spider_sd[f"{prefix}.engram.hash_seeds"] = engram_hash |
| spider_sd[f"{prefix}.post_engram_layernorm.weight"] = post_engram_norm |
|
|
| engram_params = (engram_W_k.numel() + engram_W_v.numel() + engram_conv_w.numel() + |
| engram_conv_b.numel() + engram_q_norm.numel() + engram_k_norm.numel() + |
| engram_embed.numel() + engram_hash.numel() + post_engram_norm.numel()) |
| reinit_param_count += engram_params |
| else: |
| |
| |
| |
| dense_inter = spider_config.prelude_coda_intermediate_size |
| if donor_layer_idx is not None: |
| donor_gate_key = f"model.layers.{donor_layer_idx}.mlp.gate_proj.weight" |
| donor_up_key = f"model.layers.{donor_layer_idx}.mlp.up_proj.weight" |
| donor_down_key = f"model.layers.{donor_layer_idx}.mlp.down_proj.weight" |
| donor_d_gate = donor_state_dict.get(donor_gate_key) |
| donor_d_up = donor_state_dict.get(donor_up_key) |
| donor_d_down = donor_state_dict.get(donor_down_key) |
| else: |
| donor_d_gate = donor_d_up = donor_d_down = None |
|
|
| for proj_name, shape, donor_src in [ |
| ("gate_proj", (dense_inter, hidden_size), donor_d_gate), |
| ("up_proj", (dense_inter, hidden_size), donor_d_up), |
| ("down_proj", (hidden_size, dense_inter), donor_d_down), |
| ]: |
| w = _kaiming_init(shape) |
| key = f"{prefix}.ffn.{proj_name}.weight" |
|
|
| if donor_src is not None: |
| if proj_name in ("gate_proj", "up_proj"): |
| rows_copy = min(donor_src.shape[0], shape[0]) |
| cols_copy = min(donor_src.shape[1], shape[1]) |
| w[:rows_copy, :cols_copy] = donor_src[:rows_copy, :cols_copy].float() |
| else: |
| rows_copy = min(donor_src.shape[0], shape[0]) |
| cols_copy = min(donor_src.shape[1], shape[1]) |
| w[:rows_copy, :cols_copy] = donor_src[:rows_copy, :cols_copy].float() |
| donor_param_count += rows_copy * cols_copy |
| reinit_param_count += w.numel() - rows_copy * cols_copy |
| donor_params.add(key) |
| else: |
| reinit_param_count += w.numel() |
| reinit_params.add(key) |
|
|
| spider_sd[key] = w |
|
|
| |
| log_A = torch.full((hidden_size,), -2.0) |
| delta_t = torch.tensor(1.0) |
| B_weight = torch.randn(hidden_size, hidden_size) * 0.01 |
| spider_sd["model.injection.log_A"] = log_A |
| spider_sd["model.injection.delta_t"] = delta_t |
| spider_sd["model.injection.B.weight"] = B_weight |
| reinit_param_count += log_A.numel() + delta_t.numel() + B_weight.numel() |
| reinit_params.add("model.injection.B.weight") |
|
|
| |
| donor_final_norm = donor_state_dict.get("model.norm.weight") |
| if donor_final_norm is not None and donor_final_norm.shape == (hidden_size,): |
| spider_sd["model.norm.weight"] = donor_final_norm.clone() |
| donor_param_count += donor_final_norm.numel() |
| donor_params.add("model.norm.weight") |
| elif donor_final_norm is not None: |
| |
| adapted_norm = torch.ones(hidden_size, dtype=torch.float32) |
| copy_size = min(donor_final_norm.shape[0], hidden_size) |
| adapted_norm[:copy_size] = donor_final_norm[:copy_size] |
| spider_sd["model.norm.weight"] = adapted_norm |
| donor_param_count += copy_size |
| reinit_param_count += hidden_size - copy_size |
| donor_params.add("model.norm.weight") |
| else: |
| spider_sd["model.norm.weight"] = torch.ones(hidden_size, dtype=torch.float32) |
| reinit_param_count += hidden_size |
| reinit_params.add("model.norm.weight") |
|
|
| |
| halt_w = _kaiming_init((1, hidden_size)) |
| halt_b = _zeros_init((1,)) |
| spider_sd["model.act_halting.halt_predictor.weight"] = halt_w |
| spider_sd["model.act_halting.halt_predictor.bias"] = halt_b |
| reinit_param_count += halt_w.numel() + halt_b.numel() |
|
|
| |
| spider_sd = split_dense_to_moe(spider_sd, spider_config, noise_scale=noise_scale) |
|
|
| |
| for layer_idx in range(spider_config.num_hidden_layers): |
| rec_prefix = f"model.recurrent_layers.{layer_idx}.moe" |
| |
| for core_key in [f"{rec_prefix}.W_gate", f"{rec_prefix}.W_transform"]: |
| if core_key in spider_sd and core_key not in reinit_params and core_key not in donor_params: |
| reinit_param_count += spider_sd[core_key].numel() |
| reinit_params.add(core_key) |
| |
| for router_key in [f"{rec_prefix}.router.weight", f"{rec_prefix}.router.bias"]: |
| if router_key in spider_sd and router_key not in reinit_params and router_key not in donor_params: |
| reinit_param_count += spider_sd[router_key].numel() |
| reinit_params.add(router_key) |
|
|
| |
| total_params = donor_param_count + reinit_param_count |
| if total_params > 0: |
| donor_pct = (donor_param_count / total_params) * 100.0 |
| reinit_pct = (reinit_param_count / total_params) * 100.0 |
| else: |
| donor_pct = 0.0 |
| reinit_pct = 0.0 |
|
|
| transfer_coverage = { |
| "donor_params": donor_param_count, |
| "reinit_params": reinit_param_count, |
| "total_params": total_params, |
| "donor_pct": round(donor_pct, 2), |
| "reinit_pct": round(reinit_pct, 2), |
| "donor_keys": sorted(donor_params), |
| "reinit_keys": sorted(reinit_params), |
| } |
|
|
| |
| print("=" * 60) |
| print("Weight Transfer Report") |
| print("=" * 60) |
| print(f" Donor: Qwen3.5-2B ({donor_config.get('num_hidden_layers', '?')} layers)") |
| print(f" Target: Spider-FLEXITOKENS ({spider_config.prelude_layers}+{spider_config.num_hidden_layers}+{spider_config.coda_layers} layers)") |
| print(f" Full attention layers used: {len(full_attention_layers)}") |
| print(f" Layer mapping: {layer_mapping}") |
| print() |
| print(f" Total params: {total_params:>12,} ({total_params/1e6:.1f}M)") |
| print(f" Donor-originated: {donor_param_count:>12,} ({donor_param_count/1e6:.1f}M) = {donor_pct:.1f}%") |
| print(f" Reinitialized: {reinit_param_count:>12,} ({reinit_param_count/1e6:.1f}M) = {reinit_pct:.1f}%") |
| print() |
| print(f" Transfer coverage: {donor_pct:.1f}% from donor, {reinit_pct:.1f}% reinitialized") |
| print("=" * 60) |
|
|
| return { |
| "spider_state_dict": spider_sd, |
| "transfer_coverage": transfer_coverage, |
| "layer_mapping": layer_mapping, |
| } |
|
|
|
|
| |
| |
| |
|
|
class SpiderMoEModel(nn.Module):
    """Spider-FLEXITOKENS model with multimodal forward pass.

    Implements the full forward pass wiring per D-11:
    - Text bytes → embed → prelude layers → BoundaryPredictor → downsample →
      recurrent core → upsample → coda layers → lm_head → logits
    - Modality tokens (vision/audio/video) are injected at sentinel-marked
      positions and bypass the BoundaryPredictor entirely.
    - Sentinel-gated passthrough: modality_mask forces boundary=1.0 at
      sentinel+modality positions, preventing cross-modality merges.

    This is a simplified model that implements the forward pass logic
    without the full SpiderPortalMLA attention (which requires position
    IDs, KV cache, etc.). It uses simple linear projections to demonstrate
    the multimodal wiring and parameter budget.
    """

    def __init__(self, config: SpiderConfig):
        """Build all sub-modules from ``config``.

        Sub-modules are created in a fixed order; keep that order stable,
        since it determines both ``state_dict`` key order and the order in
        which random initialisation consumes the RNG.
        """
        super().__init__()
        self.config = config

        # Token embedding and output projection over the byte-level vocabulary.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Boundary predictor imported from the canonical ``spider`` module.
        self.boundary_predictor = BoundaryPredictor(config)

        # Learned vector handed to ``downsample`` as its ``null_group`` argument.
        self.null_group = nn.Parameter(torch.zeros(config.hidden_size, dtype=torch.float32))

        # Normalisation applied to hidden states right before downsampling.
        self.down_ln = nn.LayerNorm(config.hidden_size)

        # Dense layers that run before the boundary predictor.
        self.prelude_layers = nn.ModuleList([
            self._make_dense_layer(config) for _ in range(config.prelude_layers)
        ])

        # MoE layers applied repeatedly inside the recurrent loop.
        self.recurrent_layers = nn.ModuleList([
            self._make_moe_layer(config, i) for i in range(config.num_hidden_layers)
        ])

        # Dense layers that run after upsampling.
        self.coda_layers = nn.ModuleList([
            self._make_dense_layer(config) for _ in range(config.coda_layers)
        ])

        # Final normalisation before the LM head.
        self.norm = nn.LayerNorm(config.hidden_size)

        # Re-injects the loop's input embedding on every iteration after the first.
        self.injection = _LTIInjection(config)

        # Halting head; constructed here but not called by ``forward``.
        self.act_halting = _ACTHalting(config)

        # One per-loop LoRA adapter for each recurrent layer.
        self.lora_adapters = nn.ModuleList([
            _LoRAAdapter(config) for _ in range(config.num_hidden_layers)
        ])

        self.loop_embed_dim = config.loop_embed_dim
        self.max_loop_iters = config.max_loop_iters

    def _make_dense_layer(self, config):
        """Create a simplified dense layer (prelude/coda)."""
        return _DenseLayer(config)

    def _make_moe_layer(self, config, layer_idx):
        """Create a simplified MoE layer (recurrent)."""
        return _MoELayer(config, layer_idx)

    def _inject_modality_features(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        features: list,
        modality: str = 'IMG',
    ) -> torch.Tensor:
        """Replace placeholder embeddings with actual encoder features at modality regions.

        Per D-11: Modality tokens (vision, audio, video) are injected at
        sentinel-marked positions in the hidden_states sequence. The caller
        constructs input_ids with sentinel tokens (e.g., IMG_START, IMG_END)
        marking modality regions. Between sentinel pairs, the initial
        embeddings are placeholders — this method replaces them with the
        actual encoder features.

        T-02-06 mitigation: Validates feature shape and sentinel pair count.

        ``hidden_states`` is modified in place and also returned.

        Args:
            hidden_states: [B, L, D] hidden states after embedding.
            input_ids: [B, L] token IDs with sentinel markers.
            features: List of tensors, one per sentinel pair, each of shape
                [num_tokens, hidden_size]. The same list is matched against
                the sentinel pairs of every batch item.
            modality: Modality type prefix ('IMG', 'AUD', 'VID').

        Returns:
            hidden_states with modality features injected at sentinel regions.

        Raises:
            ValueError: If a feature is not [num_tokens, hidden_size] for its
                sentinel region, or the sentinel pair count doesn't match the
                feature count.
        """
        start_token = SENTINEL_TOKENS[f'{modality}_START']
        end_token = SENTINEL_TOKENS[f'{modality}_END']
        model_dim = hidden_states.shape[-1]

        for b in range(hidden_states.shape[0]):
            # Plain Python ints keep slicing simple and error messages readable
            # (no "tensor(3)" in the text).
            starts = (input_ids[b] == start_token).nonzero(as_tuple=True)[0].tolist()
            ends = (input_ids[b] == end_token).nonzero(as_tuple=True)[0].tolist()

            if len(starts) != len(ends):
                raise ValueError(
                    f"Batch {b}: mismatched {modality} sentinel pairs — "
                    f"{len(starts)} {_TOKEN_NAMES_BY_ID[start_token]}(s) vs "
                    f"{len(ends)} {_TOKEN_NAMES_BY_ID[end_token]}(s)."
                )

            if len(starts) != len(features):
                raise ValueError(
                    f"Batch {b}: {modality} sentinel pair count ({len(starts)}) "
                    f"doesn't match feature count ({len(features)})."
                )

            for s, e, feat in zip(starts, ends, features):
                if feat.dim() != 2:
                    raise ValueError(
                        f"Batch {b}: {modality} feature must be 2-D "
                        f"[num_tokens, hidden_size], got shape {tuple(feat.shape)}."
                    )
                # Positions strictly between the START and END sentinels.
                num_tokens = e - s - 1
                if feat.shape[0] != num_tokens:
                    raise ValueError(
                        f"Batch {b}: {modality} feature has {feat.shape[0]} tokens "
                        f"but sentinel region has {num_tokens} positions "
                        f"(from pos {s+1} to {e-1})."
                    )
                if feat.shape[1] != model_dim:
                    raise ValueError(
                        f"Batch {b}: {modality} feature hidden_size {feat.shape[1]} "
                        f"doesn't match model hidden_size {model_dim}."
                    )
                # Overwrite the placeholders; the sentinels themselves are kept.
                hidden_states[b, s + 1:e] = feat.to(
                    device=hidden_states.device, dtype=hidden_states.dtype
                )

        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        vision_features: Optional[list] = None,
        audio_features: Optional[list] = None,
        video_features: Optional[list] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass with multimodal sentinel-gated passthrough.

        Per D-11:
        - All positions go through embed_tokens (bytes get byte embeddings,
          sentinels get special embeddings, modality tokens get placeholder embeddings)
        - External encoder features are injected at sentinel-marked positions
        - BoundaryPredictor operates on the embedded sequence with modality_mask
        - Text bytes go through BP → downsample → recurrent → upsample → coda → logits
        - Modality tokens bypass BP and enter downsampled sequence at sentinel positions

        Args:
            input_ids: [B, L] token IDs with optional sentinel markers.
            attention_mask: Optional attention mask (not used in simplified model).
            position_ids: Optional position IDs (not used in simplified model).
            past_key_values: Optional KV cache (not used in simplified model).
            inputs_embeds: Optional pre-computed embeddings. Never modified:
                a copy is taken when modality features have to be injected.
            vision_features: Optional list of tensors, each [num_tokens, hidden_size].
            audio_features: Optional list of tensors, each [num_tokens, hidden_size].
            video_features: Optional list of tensors, each [num_tokens, hidden_size].

        Returns:
            logits: [B, L, vocab_size] output logits.
        """
        batch_size, _ = input_ids.shape

        modality_inputs = (
            (vision_features, 'IMG'),
            (audio_features, 'AUD'),
            (video_features, 'VID'),
        )
        needs_injection = any(feats is not None for feats, _tag in modality_inputs)

        if inputs_embeds is not None:
            # Injection writes in place. Copy first so the caller's tensor is
            # left untouched (and so a leaf tensor requiring grad does not
            # trip autograd's in-place check).
            hidden_states = inputs_embeds.clone() if needs_injection else inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)

        # Swap placeholder embeddings for encoder features, one modality at a time.
        for feats, tag in modality_inputs:
            if feats is not None:
                hidden_states = self._inject_modality_features(
                    hidden_states, input_ids, feats, tag
                )

        # Prelude runs before any downsampling.
        for layer in self.prelude_layers:
            hidden_states = layer(hidden_states)

        # Boundary prediction, gated by the sentinel/modality mask.
        modality_mask = create_modality_mask(input_ids)
        soft_boundaries, hard_boundaries = self.boundary_predictor(
            hidden_states, modality_mask=modality_mask
        )

        # NOTE(review): downsample/upsample come from the ``spider`` module; the
        # permutes below suggest they use a sequence-first layout — confirm.
        hidden_states_normed = self.down_ln(hidden_states)
        null_group = self.null_group.unsqueeze(0).unsqueeze(0).expand(1, batch_size, -1)
        shortened = downsample(hard_boundaries, hidden_states_normed, null_group)

        # Swap the first two axes before entering the recurrent core.
        hidden_states = shortened.permute(1, 0, 2)

        n_loops = self.max_loop_iters
        input_embedding = hidden_states.clone()

        # NOTE(review): the original computed _loop_index_embedding(...) on every
        # iteration but never used the result; that dead computation was dropped.
        # If the loop-index signal is meant to be added to hidden_states, wire it
        # in here deliberately.
        for t in range(n_loops):
            if t > 0:
                # Re-inject the loop's input on every iteration after the first.
                injection = self.injection(hidden_states, input_embedding)
                hidden_states = hidden_states + injection

            for i, layer in enumerate(self.recurrent_layers):
                # Per-iteration LoRA correction, scaled down before the layer.
                lora_out = self.lora_adapters[i](hidden_states, t)
                hidden_states = layer(hidden_states + lora_out * 0.01)

        # Swap the axes back and restore the original sequence length.
        hidden_states_sbd = hidden_states.permute(1, 0, 2)
        hidden_states = upsample(hard_boundaries, hidden_states_sbd)

        for layer in self.coda_layers:
            hidden_states = layer(hidden_states)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        return logits
|
|
|
|
| |
| |
| |
|
|
class _DenseLayer(nn.Module):
    """Pre-norm residual block used for prelude and coda layers.

    Self-attention followed by a SwiGLU feed-forward network, each wrapped in
    a residual connection with a LayerNorm in front of it.
    """

    def __init__(self, config: SpiderConfig):
        super().__init__()
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width)
        self.post_attention_layernorm = nn.LayerNorm(width)
        # Simplified stand-in attention: fixed 4 heads, batch-first layout.
        self.self_attn = nn.MultiheadAttention(width, num_heads=4, batch_first=True)
        self.ffn = _SwiGLUFFN(width, config.prelude_coda_intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward residual branches in sequence."""
        # Attention branch: the normed input serves as query, key and value.
        attn_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(attn_in, attn_in, attn_in)[0]
        after_attn = hidden_states + attn_out

        # Feed-forward branch on top of the attention result.
        ffn_in = self.post_attention_layernorm(after_attn)
        return after_attn + self.ffn(ffn_in)
|
|
|
|
class _MoELayer(nn.Module):
    """Pre-norm residual block used inside the recurrent core.

    Same shape as the dense layer, but the feed-forward branch is a
    shared-projection mixture of experts. ``layer_idx`` is accepted for
    interface compatibility and is not used.
    """

    def __init__(self, config: SpiderConfig, layer_idx: int = 0):
        super().__init__()
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width)
        self.post_attention_layernorm = nn.LayerNorm(width)
        # Simplified stand-in attention: fixed 4 heads, batch-first layout.
        self.self_attn = nn.MultiheadAttention(width, num_heads=4, batch_first=True)
        self.moe = _SharedProjectionMoE(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MoE residual branches in sequence."""
        # Attention branch: the normed input serves as query, key and value.
        attn_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(attn_in, attn_in, attn_in)[0]
        after_attn = hidden_states + attn_out

        # MoE branch; the router z-loss it also returns is discarded here.
        moe_out, _ = self.moe(self.post_attention_layernorm(after_attn))
        return after_attn + moe_out
|
|
|
|
class _SwiGLUFFN(nn.Module):
    """Bias-free SwiGLU feed-forward network.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up through the gated branch pair, then back down."""
        gate = F.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
|
|
|
|
class _SharedProjectionMoE(nn.Module):
    """Mixture of experts built around shared up/down projections (D-20, D-21).

    The up projection is evaluated once per token; each routed expert only
    contributes a low-rank core (``W_gate`` followed by ``W_transform``) that
    modulates the shared hidden activation before the shared down
    projection. An always-on SwiGLU shared expert is added to the routed
    output.
    """

    def __init__(self, config: SpiderConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.shared_inter = config.shared_intermediate_size
        self.expert_core_rank = config.expert_core_rank
        self.hidden_size = config.hidden_size

        # Projections shared by every routed expert.
        self.shared_up = nn.Linear(config.hidden_size, config.shared_intermediate_size, bias=False)
        self.shared_down = nn.Linear(config.shared_intermediate_size, config.hidden_size, bias=False)

        # Per-expert low-rank cores, stacked along the leading expert axis.
        self.W_gate = nn.Parameter(
            torch.randn(config.num_experts, config.hidden_size, config.expert_core_rank) * 0.02
        )
        self.W_transform = nn.Parameter(
            torch.randn(config.num_experts, config.expert_core_rank, config.shared_intermediate_size) * 0.02
        )

        # Dense expert applied to every token regardless of routing.
        self.shared_expert = _SwiGLUFFN(config.hidden_size, config.shared_expert_intermediate_size)

        # Router whose bias is replaced by an explicit float32 zero vector.
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=True)
        self.router.bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts.

        Args:
            x: Input unpacked as ``(batch, seq_len, hidden)``.

        Returns:
            A pair ``(output, z_loss)``: the shared-expert output plus the
            weighted routed output, and the mean squared log-sum-exp of the
            router logits.
        """
        batch, seq_len, width = x.shape
        n_tokens = batch * seq_len

        shared_hidden = F.silu(self.shared_up(x))
        shared_out = self.shared_expert(x)

        router_logits = self.router(x)
        gate_probs, gate_idx = F.softmax(router_logits, dim=-1).topk(
            self.num_experts_per_tok, dim=-1
        )
        # Renormalise so the selected experts' weights sum to one per token.
        gate_probs = gate_probs / gate_probs.sum(dim=-1, keepdim=True)

        tokens = x.reshape(n_tokens, width)
        shared_hidden_tokens = shared_hidden.reshape(n_tokens, self.shared_inter)
        routed = torch.zeros(n_tokens, width, device=x.device, dtype=x.dtype)

        for slot in range(self.num_experts_per_tok):
            slot_expert = gate_idx[:, :, slot].reshape(n_tokens)
            slot_weight = gate_probs[:, :, slot].reshape(n_tokens)

            for expert_id in range(self.num_experts):
                chosen = slot_expert == expert_id
                if chosen.any():
                    # Low-rank core modulates the shared hidden activation.
                    core = (tokens[chosen] @ self.W_gate[expert_id]) @ self.W_transform[expert_id]
                    expert_out = self.shared_down(core * shared_hidden_tokens[chosen])
                    routed[chosen] += slot_weight[chosen].unsqueeze(-1) * expert_out

        z_loss = (router_logits.logsumexp(dim=-1) ** 2).mean()

        return shared_out + routed.reshape(batch, seq_len, width), z_loss
|
|
|
|
class _LTIInjection(nn.Module):
    """Injection term for the recurrent loop.

    Produces ``exp(log_A) * delta_t * B_proj(e)``, a per-channel scaled
    linear projection of the loop's input embedding.
    """

    def __init__(self, config: SpiderConfig):
        super().__init__()
        width = config.hidden_size
        self.log_A = nn.Parameter(torch.full((width,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B_proj = nn.Linear(width, width, bias=False)

    def forward(self, h_t: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the injection computed from ``e``.

        ``h_t`` is accepted for interface compatibility; the result does not
        depend on it.
        """
        per_channel_scale = torch.exp(self.log_A) * self.delta_t
        projected = self.B_proj(e)
        # Two leading singleton axes let the scale broadcast over the input.
        return per_channel_scale[None, None] * projected
|
|
|
|
class _ACTHalting(nn.Module):
    """Halting head for Adaptive Computation Time.

    Maps each hidden vector to a single sigmoid-squashed halting score.
    """

    def __init__(self, config: SpiderConfig):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return halting scores with the trailing size-1 axis removed."""
        halt_logits = self.halt_predictor(hidden_states)
        return halt_logits.sigmoid().squeeze(-1)
|
|
|
|
class _LoRAAdapter(nn.Module):
    """Low-rank adapter whose bottleneck is rescaled per loop iteration.

    The up projection starts at zero, so a freshly constructed adapter
    contributes nothing until it is trained.
    """

    def __init__(self, config: SpiderConfig):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.up = nn.Linear(rank, config.hidden_size, bias=False)
        # Zero-initialised up projection makes the adapter a no-op at start.
        nn.init.zeros_(self.up.weight)
        # One learned scale vector (length = rank) per loop iteration.
        self.scale_embeddings = nn.Embedding(config.max_loop_iters, rank)

    def forward(self, x: torch.Tensor, loop_iter: int) -> torch.Tensor:
        """Project down, scale by the vector for ``loop_iter``, project back up."""
        iter_index = torch.tensor([loop_iter], device=x.device)
        iter_scale = self.scale_embeddings(iter_index)[0]
        bottleneck = self.down(x) * iter_scale
        return self.up(bottleneck)
|
|
|
|
def _loop_index_embedding(
    hidden_states: torch.Tensor,
    loop_iter: int,
    embed_dim: int,
) -> torch.Tensor:
    """Build a sinusoidal encoding of the recurrent loop index.

    Channel ``i`` uses the angle ``loop_iter / 10000 ** (2 * i / embed_dim)``;
    even channels hold its sine and odd channels its cosine. The resulting
    vector is broadcast over the first two axes of ``hidden_states`` and then
    zero-padded or truncated along the last axis to match its width.

    Args:
        hidden_states: 3-D tensor that supplies the target shape, device and
            dtype.
        loop_iter: Index of the current loop iteration.
        embed_dim: Number of sinusoidal channels to generate.

    Returns:
        Tensor with the same shape, device and dtype as ``hidden_states``.
    """
    batch, seq_len, width = hidden_states.shape
    tensor_opts = {"device": hidden_states.device, "dtype": hidden_states.dtype}

    # One angle per channel.
    step = torch.tensor([loop_iter], **tensor_opts)
    channel = torch.arange(embed_dim, **tensor_opts)
    angle = step / (10000 ** (2 * channel / embed_dim))

    # Interleave sine (even channels) and cosine (odd channels).
    table = torch.zeros(embed_dim, **tensor_opts)
    table[0::2] = torch.sin(angle[0::2])
    table[1::2] = torch.cos(angle[1::2])

    # Broadcast to every position, then fit the channel axis to the model width.
    tiled = table.unsqueeze(0).unsqueeze(0).expand(batch, seq_len, -1)
    if embed_dim < width:
        filler = torch.zeros(batch, seq_len, width - embed_dim, **tensor_opts)
        return torch.cat([tiled, filler], dim=-1)
    if embed_dim > width:
        return tiled[:, :, :width]
    return tiled
|
|
|
|
| |
| |
| |
|
|
def save_spider_model(
    spider_state_dict: Dict[str, torch.Tensor],
    config: SpiderConfig,
    output_dir: Path,
):
    """Save Spider model state dict and config to output directory.

    Handles weight tying per safetensors pattern from init_spiderportal.py.

    Writes three files into ``output_dir`` (created if missing):
    ``model.safetensors`` (or ``model.pt`` when safetensors is not
    installed), ``config.json`` and ``model.sha256``. The checksum always
    covers the weights file written by this call, never a leftover file from
    an earlier run.

    Args:
        spider_state_dict: Parameter name -> tensor mapping. Not modified.
        config: Model configuration; the attributes listed in
            ``config_fields`` below are serialised to ``config.json``.
        output_dir: Destination directory.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Serialisation wants contiguous tensors; build a separate dict so the
    # caller's mapping is left untouched.
    save_sd = {name: param.contiguous() for name, param in spider_state_dict.items()}

    if config.tie_word_embeddings and "lm_head.weight" in save_sd:
        del save_sd["lm_head.weight"]
        print(" Note: lm_head.weight tied to embed_tokens.weight (saved once)")

    # Only the import is guarded, and the path actually written is remembered
    # so the checksum below cannot pick up a stale file.
    try:
        from safetensors.torch import save_file
    except ImportError:
        model_file = output_dir / "model.pt"
        torch.save(save_sd, model_file)
        print(" Warning: safetensors not available, saved as model.pt")
    else:
        model_file = output_dir / "model.safetensors"
        save_file(save_sd, model_file)

    # Config attributes copied verbatim into config.json, in this order.
    config_fields = (
        "model_type",
        "vocab_size",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "intermediate_size",
        "hidden_act",
        "max_position_embeddings",
        "rope_theta",
        "rope_scaling",
        "sliding_window",
        "rms_norm_eps",
        "initializer_range",
        "tie_word_embeddings",
        "torch_dtype",
        # Mixture of experts
        "num_experts",
        "num_experts_per_tok",
        "num_shared_experts",
        "router_aux_loss_coef",
        "shared_intermediate_size",
        "expert_core_rank",
        "shared_expert_intermediate_size",
        "prelude_coda_intermediate_size",
        # MLA attention
        "kv_lora_rank",
        "q_lora_rank",
        "qk_rope_head_dim",
        "qk_nope_head_dim",
        "v_head_dim",
        # Recurrent loop
        "max_loop_iters",
        "act_threshold",
        "prelude_layers",
        "coda_layers",
        "lora_rank",
        # Boundary predictor
        "bp_d_inner",
        # Multimodal
        "vision_hidden_size",
        "audio_hidden_size",
        "vision_num_frames",
        "vision_tokens_per_frame",
        "vision_temporal_tokens",
        "vision_temporal_layers",
    )
    cfg_dict = {"architectures": ["SpiderForConditionalGeneration"]}
    cfg_dict.update((name, getattr(config, name)) for name in config_fields)
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, indent=2)

    # Stream the weights file through SHA-256 in chunks to bound memory use.
    sha256 = hashlib.sha256()
    with open(model_file, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    hex_digest = sha256.hexdigest()
    print(f" Model SHA256: {hex_digest[:16]}...")
    with open(output_dir / "model.sha256", "w", encoding="utf-8") as f:
        f.write(hex_digest)

    print(f" Saved to {output_dir}")
    print(f" Model file size: {model_file.stat().st_size / 1e6:.1f} MB")
|
|
|
|
| |
| |
| |
|
|
def _detect_full_attention_layers(donor_cfg_obj) -> List[int]:
    """Return the indices of the donor's full-attention layers.

    Detection order (the first source that yields a result wins):
      1. an explicit ``full_attention_layers`` attribute on the config;
      2. per-layer ``layer_{i}`` sub-configs whose ``attention_type`` is "full";
      3. a ``layer_types`` list, keeping the entries equal to "full_attention";
      4. the hard-coded fallback ``[3, 7, 11, 15, 19, 23]``.
    """
    explicit = getattr(donor_cfg_obj, "full_attention_layers", None)
    if explicit is not None:
        return explicit

    detected = []
    for i in range(donor_cfg_obj.num_hidden_layers):
        layer_cfg = getattr(donor_cfg_obj, f"layer_{i}", None)
        if layer_cfg and getattr(layer_cfg, "attention_type", "full") == "full":
            detected.append(i)
    if detected:
        return detected

    # Hybrid-attention configs may expose a flat list of per-layer type names.
    layer_types = getattr(donor_cfg_obj, "layer_types", None)
    if layer_types:
        detected = [i for i, kind in enumerate(layer_types) if kind == "full_attention"]
        if detected:
            return detected

    # Nothing detectable: assume every 4th layer of a 24-layer donor.
    return [3, 7, 11, 15, 19, 23]


def main():
    """Command-line entry point: load a donor, transfer its weights, save the result.

    With ``--dry-run`` a dummy donor is built locally and nothing is
    downloaded. Otherwise the donor is loaded through Hugging Face
    transformers; a failure to import transformers or to load the model
    prints an error and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Transfer weights from Qwen3.5-2B to Spider-FLEXITOKENS"
    )
    parser.add_argument(
        "--donor", type=str, default="Qwen/Qwen3.5-2B",
        help="HuggingFace model ID or local path for donor model"
    )
    parser.add_argument(
        "--output", type=str, default="models/Spider-FLEXITOKENS-init/",
        help="Output directory for Spider model"
    )
    parser.add_argument(
        "--config", type=str, default="spider_flexitokens_997m",
        help="Spider model configuration name"
    )
    parser.add_argument(
        "--noise-scale", type=float, default=0.02,
        help="Noise scale for MoE expert perturbation"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Run with dummy donor (no download required)"
    )
    args = parser.parse_args()

    # Map config names to factories so only the selected config is built.
    config_factories = {
        "spider_flexitokens_997m": spider_flexitokens_997m,
    }
    config_factory = config_factories.get(args.config)
    if config_factory is None:
        # Keep the historical fallback, but stop doing it silently.
        print(
            f"Warning: unknown config '{args.config}', "
            f"falling back to spider_flexitokens_997m"
        )
        config_factory = spider_flexitokens_997m
    spider_config = config_factory()

    if args.dry_run:
        print("DRY RUN: Using dummy donor (no download)")
        donor = create_dummy_donor(num_layers=10, full_attention_layers=list(range(10)))
        donor_sd = donor["state_dict"]
        donor_cfg = donor["config"]
    else:
        print(f"Loading donor model: {args.donor}")
        # Guard only the import, so an ImportError raised deeper inside model
        # loading is not misreported as "transformers is missing".
        try:
            from transformers import AutoModelForCausalLM, AutoConfig
        except ImportError:
            print("Error: transformers library required for loading donor model.")
            print("Install with: pip install transformers")
            sys.exit(1)

        try:
            donor_model = AutoModelForCausalLM.from_pretrained(
                args.donor, torch_dtype=torch.bfloat16, device_map="cpu"
            )
            donor_cfg_obj = AutoConfig.from_pretrained(args.donor)

            full_attention_layers = _detect_full_attention_layers(donor_cfg_obj)

            donor_sd = donor_model.state_dict()
            donor_cfg = {
                "hidden_size": donor_cfg_obj.hidden_size,
                "num_attention_heads": donor_cfg_obj.num_attention_heads,
                "num_key_value_heads": getattr(donor_cfg_obj, "num_key_value_heads", 2),
                "head_dim": getattr(donor_cfg_obj, "head_dim",
                                    donor_cfg_obj.hidden_size // donor_cfg_obj.num_attention_heads),
                "intermediate_size": donor_cfg_obj.intermediate_size,
                "vocab_size": donor_cfg_obj.vocab_size,
                "num_hidden_layers": donor_cfg_obj.num_hidden_layers,
                "full_attention_layers": full_attention_layers,
                "model_type": getattr(donor_cfg_obj, "model_type", "qwen3"),
            }
        except Exception as e:
            # Top-level CLI boundary: report the failure and exit non-zero.
            print(f"Error loading donor model: {e}")
            print("Use --dry-run for testing without download.")
            sys.exit(1)

    # Run the transfer, then persist weights and config.
    result = transfer_qwen_to_spider(
        donor_state_dict=donor_sd,
        donor_config=donor_cfg,
        spider_config=spider_config,
        noise_scale=args.noise_scale,
    )

    save_spider_model(
        spider_state_dict=result["spider_state_dict"],
        config=spider_config,
        output_dir=Path(args.output),
    )

    print("\nWeight transfer complete!")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|