CLIWorks commited on
Commit
b3b689e
Β·
verified Β·
1 Parent(s): 20f5e73

Upload spider.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. spider.py +1568 -0
spider.py ADDED
@@ -0,0 +1,1568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Spider: MoE + RDT (Recurrent-Depth Transformer) architecture v5.
3
+
4
+ Canonical architecture ported from mythos-fineweb-moe.py (SpiderPortal v5-Dense)
5
+ with the following adaptations per Phase 02 decisions:
6
+
7
+ - Full Spider rebrand (no SpiderPortal/SpiderPortal prefix) per D-07
8
+ - Byte-level vocab: 272 tokens (256 bytes + 16 specials) per D-06
9
+ - MLA (Multi-Latent Attention) with compressed KV cache per D-10
10
+ - Engram conditional memory at recurrent layers 1 and 4
11
+ - MoE: 16 routed experts + 1 shared expert, top-1 routing
12
+ - Sliding window attention (sliding_window=8192) with 256k context (YaRN factor=8.0)
13
+ - Weight-tied embeddings per v5 canonical config (tie_word_embeddings=True)
14
+ - LTI Injection + ACT Halting + LoRA Adapter for RDT loops
15
+ - BoundaryPredictor + downsample/upsample for FlexiToken integration
16
+ - 272-token byte-level vocab with sentinel tokens for multimodal (D-11)
17
+
18
+ Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
19
+ - 2x Prelude (MLA + dense FFN)
20
+ - 6x Recurrent (MLA + Engram@L1,L4 + MoE) -- with gradient checkpointing
21
+ - 2x Coda (MLA + dense FFN)
22
+ - LTI Injection + ACT Halting + LoRA Adapter
23
+
24
+ Config: hidden_size=2048, 6 recurrent layers, 16 experts, top-1 routing
25
+ """
26
+
27
+ import math
28
+ from dataclasses import dataclass, field
29
+ from typing import Dict, List, Optional, Tuple
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from torch.nn import CrossEntropyLoss
35
+
36
+
37
+ # ============================================================================
38
+ # Spider Configuration
39
+ # ============================================================================
40
+
41
+ @dataclass
42
+ class SpiderConfig:
43
+ """Spider model configuration (hidden_size=2048, byte-level vocab).
44
+
45
+ Based on mythos-fineweb-moe.py SpiderPortalConfig with byte-level
46
+ tokenization, MLA attention, and Engram memory.
47
+ """
48
+ # Core architecture
49
+ vocab_size: int = 272 # 256 bytes + 16 specials (D-06)
50
+ hidden_size: int = 2048
51
+ num_hidden_layers: int = 6 # recurrent layers
52
+ num_attention_heads: int = 16
53
+ num_key_value_heads: int = 4 # not used directly in MLA but kept for compat
54
+ intermediate_size: int = 1024
55
+ hidden_act: str = "silu"
56
+
57
+ # MoE configuration (D-20, D-21: shared-projection MoE)
58
+ num_experts: int = 32
59
+ num_experts_per_tok: int = 2
60
+ num_shared_experts: int = 1
61
+ router_aux_loss_coef: float = 0.05
62
+ shared_intermediate_size: int = 6144
63
+ expert_core_rank: int = 256
64
+ shared_expert_intermediate_size: int = 7424
65
+ prelude_coda_intermediate_size: int = 4096
66
+
67
+ # RDT configuration
68
+ max_loop_iters: int = 16
69
+ act_threshold: float = 0.5
70
+ prelude_layers: int = 2
71
+ coda_layers: int = 2
72
+ lora_rank: int = 128
73
+ loop_embed_dim: int = 128
74
+
75
+ # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
76
+ kv_lora_rank: int = 128
77
+ q_lora_rank: int = 256
78
+ qk_rope_head_dim: int = 64
79
+ qk_nope_head_dim: int = 64
80
+ v_head_dim: int = 64
81
+
82
+ # Engram parameters (DeepSeek conditional memory, offloaded to CPU)
83
+ engram_layers: List[int] = field(default_factory=lambda: [1, 4])
84
+ engram_ngram_orders: Tuple[int, ...] = (2, 3)
85
+ engram_hash_heads: int = 4
86
+ engram_table_size: int = 8191 # prime, sized for byte vocab=272
87
+ engram_conv_kernel: int = 4
88
+ engram_conv_dilation: int = 3
89
+ engram_dim: int = 128 # per-head embedding dimension
90
+ engram_offload: bool = True # offload embed table to CPU (DeepSeek style)
91
+
92
+ # Attention / RoPE
93
+ max_position_embeddings: int = 262144 # 256k context
94
+ rope_theta: float = 10000000.0
95
+ rope_scaling: Optional[Dict] = field(default_factory=lambda: {
96
+ "type": "yarn",
97
+ "factor": 8.0,
98
+ "original_max_position_embeddings": 32768,
99
+ })
100
+ sliding_window: int = 8192 # local attention window
101
+ attention_dropout: float = 0.0
102
+ rms_norm_eps: float = 1e-6
103
+ initializer_range: float = 0.02
104
+
105
+ # Embeddings / head
106
+ tie_word_embeddings: bool = True # per v5 canonical config
107
+
108
+ # Multimodal
109
+ vision_hidden_size: int = 2048
110
+ audio_hidden_size: int = 512
111
+ vision_num_frames: int = 60
112
+ vision_tokens_per_frame: int = 256
113
+ vision_temporal_tokens: int = 64
114
+ vision_temporal_layers: int = 2
115
+
116
+ # Metadata
117
+ model_type: str = "spider"
118
+ torch_dtype: str = "bfloat16"
119
+
120
+ # BoundaryPredictor (for FlexiToken integration)
121
+ bp_d_inner: int = 8192
122
+
123
+ @property
124
+ def head_dim(self):
125
+ return self.qk_nope_head_dim + self.qk_rope_head_dim # 128
126
+
127
+
128
+ def spider_flexitokens_997m() -> SpiderConfig:
129
+ """Spider-FLEXITOKENS 995.1M config per D-20."""
130
+ return SpiderConfig()
131
+
132
+
133
+ # ============================================================================
134
+ # Sentinel Token Vocabulary (D-06, D-11)
135
+ # ============================================================================
136
+
137
+ # 272-token vocab: 256 bytes + 16 specials
138
+ # Sentinel tokens at indices 259-264 mark modality region boundaries
139
+ SENTINEL_TOKENS = {
140
+ 'PAD': 256, 'BOS': 257, 'EOS': 258,
141
+ 'IMG_START': 259, 'IMG_END': 260,
142
+ 'AUD_START': 261, 'AUD_END': 262,
143
+ 'VID_START': 263, 'VID_END': 264,
144
+ 'MASK': 265, 'im_start': 266, 'im_end': 267,
145
+ 'prefix': 268, 'suffix': 269, 'middle': 270,
146
+ 'THINK': 271,
147
+ }
148
+
149
+ # Sentinel pairs for modality regions (start_id, end_id)
150
+ _SENTINEL_PAIRS = [
151
+ (SENTINEL_TOKENS['IMG_START'], SENTINEL_TOKENS['IMG_END']), # (259, 260)
152
+ (SENTINEL_TOKENS['AUD_START'], SENTINEL_TOKENS['AUD_END']), # (261, 262)
153
+ (SENTINEL_TOKENS['VID_START'], SENTINEL_TOKENS['VID_END']), # (263, 264)
154
+ ]
155
+
156
+ # Set of modality sentinel token IDs (259-264 only)
157
+ _MODALITY_SENTINEL_IDS = {259, 260, 261, 262, 263, 264}
158
+
159
+ # Reverse mapping (computed once at module level, per IN-01)
160
+ _TOKEN_NAMES_BY_ID = {v: k for k, v in SENTINEL_TOKENS.items()}
161
+
162
+
163
+ def is_sentinel_token(token_id: int) -> bool:
164
+ """Return True if token_id is one of the 6 modality sentinel tokens (259-264).
165
+
166
+ These are the sentinel tokens that mark modality region boundaries:
167
+ IMG_START/END, AUD_START/END, VID_START/END.
168
+ Other special tokens (PAD, BOS, EOS, MASK, etc.) are NOT modality sentinels.
169
+ """
170
+ return token_id in _MODALITY_SENTINEL_IDS
171
+
172
+
173
+ def create_modality_mask(input_ids: torch.Tensor, strict: bool = True) -> torch.Tensor:
174
+ """Create boolean mask (BΓ—L) marking sentinel and modality token positions.
175
+
176
+ Per D-11: Sentinel-gated passthrough ensures modality tokens bypass the
177
+ BoundaryPredictor entirely. This mask marks positions where:
178
+ - Sentinel tokens (IMG_START/END, AUD_START/END, VID_START/END) appear
179
+ - Modality tokens (between sentinel pairs) appear
180
+
181
+ The BoundaryPredictor uses this mask to force boundary=1.0 at these
182
+ positions, ensuring no boundary merging across modality boundaries.
183
+
184
+ Args:
185
+ input_ids: Token IDs of shape [B, L] with values in 0-271 range.
186
+ strict: If True, raise on mismatched sentinel pairs (training mode).
187
+ If False, skip mismatched pairs gracefully (generation mode).
188
+
189
+ Returns:
190
+ Boolean tensor of shape [B, L], True at sentinel+modality positions.
191
+
192
+ Raises:
193
+ ValueError: If strict=True and sentinel pairs are mismatched.
194
+ """
195
+ B, L = input_ids.shape
196
+ mask = torch.zeros(B, L, dtype=torch.bool, device=input_ids.device)
197
+
198
+ # Mark direct sentinel token positions
199
+ for sid in _MODALITY_SENTINEL_IDS:
200
+ mask |= (input_ids == sid)
201
+
202
+ # Mark regions between sentinel pairs (inclusive of sentinels)
203
+ for start_id, end_id in _SENTINEL_PAIRS:
204
+ for b in range(B):
205
+ starts = (input_ids[b] == start_id).nonzero(as_tuple=True)[0]
206
+ ends = (input_ids[b] == end_id).nonzero(as_tuple=True)[0]
207
+
208
+ # T-02-04 mitigation: validate sentinel pairs are matched (strict mode only)
209
+ if strict and len(starts) != len(ends):
210
+ raise ValueError(
211
+ f"Batch {b}: mismatched sentinel pairs β€” "
212
+ f"{len(starts)} {_TOKEN_NAMES_BY_ID[start_id]}(s) vs "
213
+ f"{len(ends)} {_TOKEN_NAMES_BY_ID[end_id]}(s). "
214
+ f"Every {_TOKEN_NAMES_BY_ID[start_id]} must have a matching "
215
+ f"{_TOKEN_NAMES_BY_ID[end_id]}."
216
+ )
217
+
218
+ # Match pairs min(starts, ends) β€” skip unmatched in non-strict mode
219
+ n_pairs = min(len(starts), len(ends))
220
+ for i in range(n_pairs):
221
+ s, e = starts[i].item(), ends[i].item()
222
+ if s > e:
223
+ if strict:
224
+ raise ValueError(
225
+ f"Batch {b}: {_TOKEN_NAMES_BY_ID[start_id]} at position {s} "
226
+ f"appears after {_TOKEN_NAMES_BY_ID[end_id]} at position {e}. "
227
+ f"Sentinel pairs must be properly ordered."
228
+ )
229
+ continue
230
+ mask[b, s:e + 1] = True
231
+
232
+ return mask
233
+
234
+
235
+ # ============================================================================
236
+ # RMSNorm
237
+ # ============================================================================
238
+
239
+ class SpiderRMSNorm(nn.Module):
240
+ """RMS normalization (bf16-only, no dtype conversions)."""
241
+
242
+ def __init__(self, hidden_size, eps=1e-6):
243
+ super().__init__()
244
+ self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32)) # IN-02: RMSNorm weight is float32 per convention
245
+ self.variance_epsilon = eps
246
+
247
+ def forward(self, hidden_states):
248
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
249
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
250
+ return self.weight * hidden_states
251
+
252
+
253
+ # ============================================================================
254
+ # MLA: Multi-Latent Attention (DeepSeek-V2 style)
255
+ # ============================================================================
256
+
257
+ class SpiderMLA(nn.Module):
258
+ """Multi-Latent Attention with compressed KV cache.
259
+
260
+ For hidden_size=2048, num_heads=16:
261
+ - qk_nope_head_dim=64, qk_rope_head_dim=64 -> total head_dim=128
262
+ - kv_lora_rank=128 -> 10.7x compression vs full 2048-dim KV
263
+ - v_head_dim=64 -> value projection
264
+ - sliding_window=8192 -> local attention window
265
+ """
266
+
267
+ def __init__(self, config: SpiderConfig):
268
+ super().__init__()
269
+ self.config = config
270
+ self.hidden_size = config.hidden_size
271
+ self.num_heads = config.num_attention_heads
272
+ self.kv_lora_rank = config.kv_lora_rank
273
+ self.q_lora_rank = config.q_lora_rank
274
+ self.qk_rope_head_dim = config.qk_rope_head_dim
275
+ self.qk_nope_head_dim = config.qk_nope_head_dim
276
+ self.v_head_dim = config.v_head_dim
277
+ self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
278
+ self.sliding_window = getattr(config, 'sliding_window', 0)
279
+
280
+ # Q projection: optional low-rank -> full Q
281
+ if self.q_lora_rank > 0:
282
+ self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
283
+ self.q_a_layernorm = SpiderRMSNorm(self.q_lora_rank)
284
+ self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
285
+ else:
286
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
287
+
288
+ # KV compression: hidden -> kv_lora_rank (shared latent)
289
+ self.kv_a_proj_with_mqa = nn.Linear(
290
+ config.hidden_size,
291
+ self.kv_lora_rank + self.qk_rope_head_dim,
292
+ bias=False,
293
+ )
294
+ self.kv_a_layernorm = SpiderRMSNorm(self.kv_lora_rank)
295
+ # Decompress: kv_lora_rank -> nope heads + v heads
296
+ self.kv_b_proj = nn.Linear(
297
+ self.kv_lora_rank,
298
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
299
+ bias=False,
300
+ )
301
+ # Output projection: [hidden_size, num_heads * v_head_dim]
302
+ # Per D-08 and MLA architecture: o_proj maps from num_heads*v_head_dim back to hidden_size
303
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)
304
+
305
+ # RoPE frequencies
306
+ rope_scaling = getattr(config, 'rope_scaling', None)
307
+ if rope_scaling and rope_scaling.get("type") == "yarn":
308
+ factor = rope_scaling.get("factor", 1.0)
309
+ orig_max_pos = rope_scaling.get(
310
+ "original_max_position_embeddings", config.max_position_embeddings
311
+ )
312
+ inv_freq = self._compute_yarn_inv_freq(
313
+ self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos
314
+ )
315
+ else:
316
+ inv_freq = 1.0 / (
317
+ config.rope_theta
318
+ ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim)
319
+ )
320
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
321
+
322
+ @staticmethod
323
+ def _compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
324
+ dim = head_dim
325
+ orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
326
+ pos_freqs = torch.arange(0, dim, 2).float() / dim
327
+ beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
328
+ scale = torch.where(
329
+ beta < beta_slow, torch.ones_like(beta),
330
+ torch.where(
331
+ beta > beta_fast, torch.ones_like(beta) / factor,
332
+ 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)
333
+ )
334
+ )
335
+ return orig_inv_freq * scale
336
+
337
+ def _rotate_half(self, x):
338
+ x1 = x[..., :x.shape[-1] // 2]
339
+ x2 = x[..., x.shape[-1] // 2:]
340
+ return torch.cat((-x2, x1), dim=-1)
341
+
342
+ def _apply_rotary(self, x, cos, sin):
343
+ return (x * cos) + (self._rotate_half(x) * sin)
344
+
345
+ def forward(
346
+ self,
347
+ hidden_states: torch.Tensor,
348
+ attention_mask=None,
349
+ position_ids=None,
350
+ past_key_value=None,
351
+ use_cache=False,
352
+ ):
353
+ bsz, q_len, _ = hidden_states.size()
354
+
355
+ # Q projection
356
+ if self.q_lora_rank > 0:
357
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
358
+ else:
359
+ q = self.q_proj(hidden_states)
360
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
361
+ q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
362
+
363
+ # KV: compress to latent, then decompress
364
+ kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
365
+ kv_latent, k_rope = torch.split(
366
+ kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
367
+ )
368
+ kv_latent_norm = self.kv_a_layernorm(kv_latent)
369
+ kv_b_out = self.kv_b_proj(kv_latent_norm)
370
+ k_nope, v = torch.split(
371
+ kv_b_out,
372
+ [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim],
373
+ dim=-1,
374
+ )
375
+
376
+ k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
377
+ v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
378
+ k_rope = k_rope.unsqueeze(1) # [B, 1, L, qk_rope_head_dim]
379
+
380
+ # RoPE on Q and K rope parts
381
+ if position_ids is None:
382
+ position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
383
+ max_pos = position_ids.max().item() + 1
384
+ seq_len = max(max_pos, q_len)
385
+ t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
386
+ freqs = torch.outer(t, self.inv_freq)
387
+ emb = torch.cat((freqs, freqs), dim=-1)
388
+ cos, sin = emb.cos(), emb.sin()
389
+ cos_full = cos[position_ids].unsqueeze(1)
390
+ sin_full = sin[position_ids].unsqueeze(1)
391
+
392
+ q_rope = self._apply_rotary(q_rope, cos_full, sin_full)
393
+ k_rope = self._apply_rotary(k_rope, cos_full, sin_full)
394
+
395
+ # Assemble full K
396
+ k_rope_expanded = k_rope.expand(-1, self.num_heads, -1, -1)
397
+ k_full = torch.cat([k_nope, k_rope_expanded], dim=-1)
398
+ q_full = torch.cat([q_nope, q_rope], dim=-1)
399
+
400
+ # KV cache
401
+ past_kv = None
402
+ if past_key_value is not None:
403
+ k_full = torch.cat([past_key_value[0], k_full], dim=2)
404
+ v = torch.cat([past_key_value[1], v], dim=2)
405
+ if use_cache:
406
+ past_kv = (k_full, v)
407
+
408
+ # Attention with SDPA
409
+ attn_mask = None
410
+ if self.sliding_window > 0 and k_full.shape[2] > self.sliding_window:
411
+ kv_len = k_full.shape[2]
412
+ q_positions = torch.arange(kv_len - q_len, kv_len, device=q_full.device)
413
+ k_positions = torch.arange(kv_len, device=q_full.device)
414
+ diff = q_positions.unsqueeze(1) - k_positions.unsqueeze(0)
415
+ causal = diff >= 0
416
+ window = diff < self.sliding_window
417
+ attn_mask = (causal & window).float().unsqueeze(0).unsqueeze(0)
418
+ attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf'))
419
+
420
+ attn_output = F.scaled_dot_product_attention(
421
+ q_full, k_full, v,
422
+ attn_mask=attn_mask,
423
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
424
+ is_causal=(attn_mask is None),
425
+ )
426
+ attn_output = attn_output.transpose(1, 2).contiguous()
427
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
428
+ return self.o_proj(attn_output), past_kv
429
+
430
+
431
+ # ============================================================================
432
+ # Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
433
+ # ============================================================================
434
+
435
+ def _tokenizer_compress(token_ids, vocab_size=272):
436
+ """Simulate NFKC + lowercase canonical ID projection.
437
+
438
+ Per D-06: vocab_size=272 for byte-level Spider vocab.
439
+ """
440
+ return token_ids % (vocab_size * 77 // 100)
441
+
442
+
443
+ class SpiderEngram(nn.Module):
444
+ """Conditional memory module via NN-gram lookup.
445
+
446
+ Applied only at specific recurrent layers (config.engram_layers).
447
+ Ported from SpiderPortalEngram in mythos-fineweb-moe.py.
448
+ """
449
+
450
+ def __init__(self, config: SpiderConfig):
451
+ super().__init__()
452
+ self.config = config
453
+ self.ngram_orders = list(config.engram_ngram_orders)
454
+ self.num_heads_per_order = config.engram_hash_heads
455
+ self.table_size = config.engram_table_size
456
+ self.d_mem = config.engram_dim
457
+
458
+ self.total_mem_dim = len(self.ngram_orders) * self.num_heads_per_order * self.d_mem
459
+
460
+ # Stacked embedding table with offsets: [orders, heads, table_size, d_mem]
461
+ # Per DeepSeek Engram: static memory, offloaded to CPU, accessed via deterministic hash.
462
+ embed_data = torch.randn(len(self.ngram_orders), self.num_heads_per_order, self.table_size, self.d_mem) * 0.02
463
+ if config.engram_offload:
464
+ self.register_buffer("embed", embed_data, persistent=True)
465
+ else:
466
+ self.embed = nn.Parameter(embed_data)
467
+
468
+ # Seeds per (order, head) in a stable head_counter ordering.
469
+ seeds = []
470
+ for _order in self.ngram_orders:
471
+ for h in range(self.num_heads_per_order):
472
+ seeds.append((h + 1) * 2654435761)
473
+ self.register_buffer("hash_seeds", torch.tensor(seeds, dtype=torch.int64), persistent=False)
474
+
475
+ self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
476
+ self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
477
+
478
+ self.conv = nn.Conv1d(
479
+ config.hidden_size, config.hidden_size,
480
+ kernel_size=config.engram_conv_kernel,
481
+ padding=config.engram_conv_kernel - 1,
482
+ groups=config.hidden_size,
483
+ )
484
+ self.conv_dilation = config.engram_conv_dilation
485
+
486
+ with torch.no_grad():
487
+ self.conv.weight.zero_()
488
+ if self.conv.bias is not None:
489
+ self.conv.bias.zero_()
490
+
491
+ self.q_norm = SpiderRMSNorm(config.hidden_size)
492
+ self.k_norm = SpiderRMSNorm(config.hidden_size)
493
+
494
+ def _compute_hash(self, compressed, n, head_counter, bsz, seq_len):
495
+ """Compute n-gram hash indices (PyTorch-only path, no Numba/CUDA dependency)."""
496
+ pad = torch.zeros(bsz, n - 1, dtype=compressed.dtype, device=compressed.device)
497
+ padded = torch.cat([pad, compressed], dim=1)
498
+ ngrams = torch.stack([padded[:, i : i + seq_len] for i in range(n)], dim=-1)
499
+ h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed.device)
500
+ for i in range(n):
501
+ h_val = h_val * 31 + ngrams[:, :, i].to(torch.int64)
502
+ h_val = h_val % self.table_size
503
+ return h_val
504
+
505
+ def _retrieve(self, token_ids):
506
+ """Retrieve memory vectors for a batch of token sequences."""
507
+ bsz, seq_len = token_ids.shape
508
+ compressed = _tokenizer_compress(token_ids)
509
+
510
+ # PyTorch fallback (CPU and GPU, no external kernel dependency)
511
+ all_parts = []
512
+ head_counter = 0
513
+ for order_idx, n in enumerate(self.ngram_orders):
514
+ h_val = self._compute_hash(compressed, n, head_counter, bsz, seq_len)
515
+ seeds_slice = self.hash_seeds[head_counter : head_counter + self.num_heads_per_order]
516
+ indices_pt = (h_val.unsqueeze(-1) * seeds_slice.view(1, 1, -1)) % self.table_size
517
+ emb_table = self.embed[order_idx]
518
+ idx = indices_pt.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, self.d_mem)
519
+ mem = torch.gather(emb_table.unsqueeze(0).expand(bsz, -1, -1, -1), dim=2, index=idx)
520
+ mem = mem.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads_per_order * self.d_mem)
521
+ all_parts.append(mem)
522
+ head_counter += self.num_heads_per_order
523
+ return torch.cat(all_parts, dim=-1)
524
+
525
+ def forward(self, hidden_states, token_ids, layer_id: int):
526
+ mem = self._retrieve(token_ids)
527
+
528
+ q = hidden_states
529
+ k = self.W_k(mem)
530
+ v = self.W_v(mem)
531
+ q_norm = self.q_norm(q)
532
+ k_norm = self.k_norm(k)
533
+ alpha = torch.sigmoid(
534
+ (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1])
535
+ )
536
+ v_gated = alpha * v
537
+ v_gated_t = v_gated.transpose(1, 2)
538
+ conv_out = self.conv(v_gated_t)
539
+ conv_out = conv_out[:, :, :v_gated_t.shape[-1]]
540
+ conv_out = conv_out.transpose(1, 2)
541
+
542
+ y = F.silu(conv_out) + v_gated
543
+ return y
544
+
545
+
546
+ # ============================================================================
547
+ # FFN Expert (SwiGLU)
548
+ # ============================================================================
549
+
550
+ class SpiderExpert(nn.Module):
551
+ """SwiGLU FFN expert for dense layers and MoE shared expert."""
552
+
553
+ def __init__(self, config: SpiderConfig, intermediate_size=None):
554
+ super().__init__()
555
+ inter_size = intermediate_size or config.intermediate_size
556
+ self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
557
+ self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
558
+ self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
559
+ self.act_fn = nn.SiLU()
560
+
561
+ def forward(self, hidden_states):
562
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
563
+
564
+
565
+ # ============================================================================
566
+ # Simple MoE (top-1 routing, no torchtitan dependency)
567
+ # ============================================================================
568
+
569
+ class SimpleMoE(nn.Module):
570
+ """Mixture of Experts with top-1 routing and shared expert.
571
+
572
+ This is a self-contained MoE implementation that does not depend on
573
+ torchtitan's MoE. Used by SpiderRecurrentLayer when torchtitan
574
+ is not available (e.g., during weight transfer and testing).
575
+ """
576
+
577
+ def __init__(self, config: SpiderConfig):
578
+ super().__init__()
579
+ self.num_experts = config.num_experts
580
+ self.num_experts_per_tok = config.num_experts_per_tok
581
+
582
+ # Shared expert
583
+ self.shared_expert = SpiderExpert(config, intermediate_size=config.intermediate_size)
584
+
585
+ # Routed experts
586
+ self.experts = nn.ModuleList([
587
+ SpiderExpert(config, intermediate_size=config.intermediate_size)
588
+ for _ in range(config.num_experts)
589
+ ])
590
+
591
+ # Router
592
+ self.router = nn.Linear(config.hidden_size, config.num_experts, bias=True)
593
+ # router.bias is named router_bias in the state dict for compatibility
594
+ self.router.bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32)) # IN-02
595
+
596
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
597
+ """Forward pass with top-1 routing.
598
+
599
+ Returns:
600
+ Tuple of (output, aux_loss) where aux_loss is the load balancing loss.
601
+ """
602
+ B, L, D = x.shape
603
+
604
+ # Shared expert output (always applied)
605
+ shared_out = self.shared_expert(x)
606
+
607
+ # Router logits
608
+ router_logits = self.router(x) # [B, L, num_experts]
609
+ router_probs = F.softmax(router_logits, dim=-1)
610
+
611
+ # Top-1 routing
612
+ top1_indices = router_probs.argmax(dim=-1) # [B, L]
613
+ top1_probs = router_probs.gather(-1, top1_indices.unsqueeze(-1)).squeeze(-1) # [B, L]
614
+
615
+ # Compute expert outputs for top-1
616
+ x_flat = x.reshape(B * L, D)
617
+ top1_flat = top1_indices.reshape(B * L)
618
+
619
+ expert_outs = torch.zeros_like(x_flat)
620
+ for e in range(self.num_experts):
621
+ mask = (top1_flat == e)
622
+ if mask.any():
623
+ expert_input = x_flat[mask]
624
+ expert_out = self.experts[e](expert_input)
625
+ expert_outs[mask] = expert_out
626
+
627
+ expert_outs = expert_outs.reshape(B, L, D)
628
+ routed_out = expert_outs * top1_probs.unsqueeze(-1)
629
+
630
+ # Aux loss: z-loss for load balancing
631
+ z_loss = (router_logits.logsumexp(dim=-1) ** 2).mean()
632
+
633
+ return shared_out + routed_out, z_loss
634
+
635
+
636
+ # ============================================================================
637
+ # Shared-Projection MoE (D-20, D-21: top-2 routing with shared projections)
638
+ # ============================================================================
639
+
640
+ class SharedProjectionMoE(nn.Module):
641
+ """Mixture of Experts with shared projections and low-rank expert cores.
642
+
643
+ Per D-20: 32 experts, top-2 routing, shared_intermediate_size=6144.
644
+ Per D-21: Shared up/down projections computed once per token, rank-192
645
+ expert cores specialize on the shared representation.
646
+
647
+ Architecture:
648
+ - shared_up: Linear(hidden, shared_inter) β€” computed once for all experts
649
+ - shared_down: Linear(shared_inter, hidden) β€” computed once for all experts
650
+ - W_gate: [num_experts, hidden, expert_core_rank] β€” per-expert gating
651
+ - W_transform: [num_experts, expert_core_rank, shared_inter] β€” per-expert transform
652
+ - shared_expert: SpiderExpert(hidden, shared_expert_inter=4096) β€” always active
653
+
654
+ Forward: shared_hidden = SiLU(shared_up(x))
655
+ routed_out = sum(top2_weights * shared_down(core_i(shared_hidden)))
656
+ output = routed_out + shared_expert(x)
657
+ """
658
+
659
+ def __init__(self, config: SpiderConfig):
660
+ super().__init__()
661
+ self.num_experts = config.num_experts
662
+ self.num_experts_per_tok = config.num_experts_per_tok
663
+ self.shared_inter = config.shared_intermediate_size
664
+ self.expert_core_rank = config.expert_core_rank
665
+ self.hidden_size = config.hidden_size
666
+
667
+ self.shared_up = nn.Linear(config.hidden_size, config.shared_intermediate_size, bias=False)
668
+ self.shared_down = nn.Linear(config.shared_intermediate_size, config.hidden_size, bias=False)
669
+
670
+ self.W_gate = nn.Parameter(
671
+ torch.randn(config.num_experts, config.hidden_size, config.expert_core_rank) * 0.02
672
+ )
673
+ self.W_transform = nn.Parameter(
674
+ torch.randn(config.num_experts, config.expert_core_rank, config.shared_intermediate_size) * 0.02
675
+ )
676
+
677
+ self.shared_expert = SpiderExpert(config, intermediate_size=config.shared_expert_intermediate_size)
678
+
679
+ self.router = nn.Linear(config.hidden_size, config.num_experts, bias=True)
680
+ self.router.bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32))
681
+
682
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
683
+ B, L, D = x.shape
684
+
685
+ shared_hidden = F.silu(self.shared_up(x))
686
+
687
+ shared_out = self.shared_expert(x)
688
+
689
+ router_logits = self.router(x)
690
+ router_probs = F.softmax(router_logits, dim=-1)
691
+
692
+ top2_probs, top2_indices = router_probs.topk(self.num_experts_per_tok, dim=-1)
693
+ top2_probs = top2_probs / top2_probs.sum(dim=-1, keepdim=True)
694
+
695
+ x_flat = x.reshape(B * L, D)
696
+ shared_hidden_flat = shared_hidden.reshape(B * L, self.shared_inter)
697
+
698
+ routed_out = torch.zeros(B * L, D, device=x.device, dtype=x.dtype)
699
+
700
+ for k in range(self.num_experts_per_tok):
701
+ expert_indices = top2_indices[:, :, k].reshape(B * L)
702
+ expert_weights = top2_probs[:, :, k].reshape(B * L)
703
+
704
+ for e in range(self.num_experts):
705
+ mask = (expert_indices == e)
706
+ if not mask.any():
707
+ continue
708
+ expert_input = x_flat[mask]
709
+ expert_sh = shared_hidden_flat[mask]
710
+
711
+ gate = expert_input @ self.W_gate[e]
712
+ core = gate @ self.W_transform[e]
713
+ expert_output = self.shared_down(core * expert_sh)
714
+
715
+ routed_out[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
716
+
717
+ routed_out = routed_out.reshape(B, L, D)
718
+
719
+ z_loss = (router_logits.logsumexp(dim=-1) ** 2).mean()
720
+
721
+ return shared_out + routed_out, z_loss
722
+
723
+
724
+ # ============================================================================
725
+ # Prelude/Coda Dense Layer (uses MLA)
726
+ # ============================================================================
727
+
728
+ class SpiderDenseLayer(nn.Module):
729
+ """Prelude/coda dense layer with MLA attention."""
730
+
731
+ def __init__(self, config: SpiderConfig):
732
+ super().__init__()
733
+ self.self_attn = SpiderMLA(config)
734
+ dense_intermediate = config.prelude_coda_intermediate_size
735
+ self.ffn = SpiderExpert(config, intermediate_size=dense_intermediate)
736
+ self.input_layernorm = SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
737
+ self.post_attention_layernorm = SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
738
+
739
+ def forward(
740
+ self,
741
+ hidden_states,
742
+ attention_mask=None,
743
+ position_ids=None,
744
+ past_key_value=None,
745
+ use_cache=False,
746
+ ):
747
+ attn_input = self.input_layernorm(hidden_states)
748
+ attn_output, past_kv = self.self_attn(
749
+ attn_input, attention_mask=attention_mask,
750
+ position_ids=position_ids,
751
+ past_key_value=past_key_value,
752
+ use_cache=use_cache,
753
+ )
754
+ hidden_states = hidden_states + attn_output
755
+ ffn_input = self.post_attention_layernorm(hidden_states)
756
+ ffn_output = self.ffn(ffn_input)
757
+ hidden_states = hidden_states + ffn_output
758
+ return hidden_states, past_kv
759
+
760
+
761
+ # ============================================================================
762
+ # Recurrent Layer (uses MLA + optional Engram + MoE)
763
+ # ============================================================================
764
+
765
+ class SpiderRecurrentLayer(nn.Module):
766
+ """Recurrent layer with MLA attention, optional Engram memory, and MoE."""
767
+
768
+ def __init__(self, config: SpiderConfig, layer_idx: int, has_engram: bool = False):
769
+ super().__init__()
770
+ self.layer_idx = layer_idx
771
+ self.has_engram = has_engram
772
+ self.self_attn = SpiderMLA(config)
773
+ if has_engram:
774
+ self.engram = SpiderEngram(config)
775
+ self.moe = SharedProjectionMoE(config)
776
+ self.input_layernorm = SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
777
+ self.post_attention_layernorm = SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
778
+ self.post_engram_layernorm = (
779
+ SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
780
+ if has_engram else None
781
+ )
782
+
783
+ def forward(
784
+ self,
785
+ hidden_states,
786
+ token_ids=None,
787
+ attention_mask=None,
788
+ position_ids=None,
789
+ past_key_value=None,
790
+ use_cache=False,
791
+ ):
792
+ attn_input = self.input_layernorm(hidden_states)
793
+ attn_output, past_kv = self.self_attn(
794
+ attn_input, attention_mask=attention_mask,
795
+ position_ids=position_ids,
796
+ past_key_value=past_key_value,
797
+ use_cache=use_cache,
798
+ )
799
+ hidden_states = hidden_states + attn_output
800
+
801
+ if self.has_engram and token_ids is not None:
802
+ engram_out = self.engram(hidden_states, token_ids, layer_id=self.layer_idx)
803
+ hidden_states = hidden_states + engram_out
804
+ if self.post_engram_layernorm is not None:
805
+ hidden_states = self.post_engram_layernorm(hidden_states)
806
+
807
+ ffn_input = self.post_attention_layernorm(hidden_states)
808
+ ffn_output, aux_loss = self.moe(ffn_input)
809
+ hidden_states = hidden_states + ffn_output
810
+ return hidden_states, aux_loss, past_kv
811
+
812
+
813
+ # ============================================================================
814
+ # BoundaryPredictor (D-04, D-11)
815
+ # ============================================================================
816
+
817
+ class BoundaryPredictor(nn.Module):
818
+ """Boundary predictor for learnable byte-level tokenization.
819
+
820
+ 2-layer MLP that predicts merge boundaries between tokens.
821
+ Per D-11: When modality_mask is provided, forces boundary=1.0 at
822
+ sentinel and modality token positions, preventing cross-modality merges.
823
+
824
+ Architecture: Linear(d_model, d_inner) -> GELU -> Linear(d_inner, 1)
825
+ Uses Gumbel-Softmax straight-through estimator for differentiable
826
+ boundary decisions (ported from FLEXITOKENS fxt.py).
827
+ """
828
+
829
+ def __init__(
830
+ self,
831
+ config: SpiderConfig,
832
+ temp: float = 1.0,
833
+ threshold: float = 0.5,
834
+ ):
835
+ super().__init__()
836
+ self.temp = temp
837
+ self.threshold = threshold
838
+
839
+ self.boundary_predictor = nn.Sequential(
840
+ nn.Linear(config.hidden_size, config.bp_d_inner),
841
+ nn.GELU(),
842
+ nn.Linear(config.bp_d_inner, 1),
843
+ )
844
+
845
+ def forward(self, hidden, modality_mask=None):
846
+ """Predict boundary decisions for token merging.
847
+
848
+ Args:
849
+ hidden: Hidden states of shape [B, L, D] (batch-first per D-08).
850
+ modality_mask: Optional boolean tensor [B, L], True at positions
851
+ where sentinel/modality tokens appear. Per D-11,
852
+ forces boundary=1.0 at these positions.
853
+
854
+ Returns:
855
+ Tuple of (soft_boundaries, hard_boundaries), each [B, L].
856
+ - soft_boundaries: Differentiable boundary probabilities
857
+ - hard_boundaries: Binary boundary decisions (straight-through)
858
+ """
859
+ boundary_logits = self.boundary_predictor(hidden).squeeze(-1)
860
+ boundary_probs = torch.sigmoid(boundary_logits)
861
+
862
+ # Gumbel-Softmax straight-through for differentiable boundary decisions
863
+ bernoulli = torch.distributions.relaxed_bernoulli.RelaxedBernoulli(
864
+ temperature=self.temp,
865
+ probs=boundary_probs,
866
+ )
867
+ soft_boundaries = bernoulli.rsample()
868
+
869
+ hard_boundaries = (soft_boundaries > self.threshold).float()
870
+ # Straight-through estimator: gradient flows through soft, forward uses hard
871
+ hard_boundaries = (
872
+ hard_boundaries - soft_boundaries.detach() + soft_boundaries
873
+ )
874
+
875
+ # Per D-11: Force boundaries at sentinel/modality positions
876
+ if modality_mask is not None:
877
+ soft_boundaries = soft_boundaries.masked_fill(modality_mask, 1.0)
878
+ hard_boundaries = hard_boundaries.masked_fill(modality_mask, 1.0)
879
+
880
+ return soft_boundaries, hard_boundaries
881
+
882
+
883
+ # ============================================================================
884
+ # Downsample / Upsample (D-05, D-08, D-11)
885
+ # ============================================================================
886
+
887
+ def _downsample_common(boundaries: torch.Tensor, upsample: bool = False):
888
+ """Common helper for downsample/upsample einsum weight computation.
889
+
890
+ Computes the assignment matrix that maps original positions to groups.
891
+ Based on FLEXITOKENS shortening.py, adapted for batch-first (B*L*D) layout.
892
+
893
+ Args:
894
+ boundaries: [B, L] binary boundary tensor (1 = new group starts)
895
+ upsample: If True, compute upsample weights; else downsample weights
896
+
897
+ Returns:
898
+ Assignment tensor [B, L, S] or None if n_segments == 0
899
+ """
900
+ boundaries = boundaries.clone()
901
+ n_segments = int(boundaries.sum(dim=-1).max().item())
902
+
903
+ if upsample:
904
+ n_segments += 1
905
+
906
+ if n_segments == 0:
907
+ return None
908
+
909
+ tmp = torch.zeros_like(boundaries).unsqueeze(2) + torch.arange(
910
+ start=0, end=n_segments, device=boundaries.device, dtype=boundaries.dtype
911
+ )
912
+ hh1 = boundaries.cumsum(dim=-1)
913
+
914
+ if not upsample:
915
+ hh1 -= boundaries # Subtract current boundary so position belongs to previous group
916
+
917
+ foo = tmp - hh1.unsqueeze(-1)
918
+
919
+ # WR-01 fix: zero out unused columns for batch items with fewer segments
920
+ # When n_segments is set to the max across the batch, items with fewer
921
+ # segments have unused columns that would produce NaN on normalization.
922
+ item_segment_counts = boundaries.sum(dim=-1)
923
+ for b in range(boundaries.shape[0]):
924
+ item_segs = int(item_segment_counts[b].item())
925
+ if upsample:
926
+ item_segs += 1
927
+ if item_segs < n_segments:
928
+ foo[b, :, item_segs:] = 0
929
+
930
+ return foo
931
+
932
+
933
+ def _downsample_final(foo: torch.Tensor, upsample: bool = False) -> torch.Tensor:
934
+ """Normalize assignment weights for downsample/upsample einsum."""
935
+ autoregressive = foo != 0
936
+ lel = 1.0 - foo.float()
937
+ lel[autoregressive] = 0.0
938
+ dim = 2 if upsample else 1
939
+ lel = lel / (lel.sum(dim=dim, keepdim=True) + 1e-9)
940
+ return lel
941
+
942
+
943
+ def downsample(boundaries: torch.Tensor, hidden: torch.Tensor, null_group: torch.Tensor) -> torch.Tensor:
944
+ """Downsample hidden states using boundary decisions.
945
+
946
+ Per D-05: Exact einsum port from FLEXITOKENS shortening.py.
947
+ Per D-08: Batch-first layout [B, L, D].
948
+ Per D-11: Sentinel tokens forced to boundary=1 by modality_mask ->
949
+ downsample treats each sentinel+modality group as a separate merge
950
+ group -> groups appear intact in shortened sequence.
951
+
952
+ Args:
953
+ boundaries: [B, L] binary boundary tensor (1 = new group starts)
954
+ hidden: [B, L, D] hidden states (batch-first per D-08)
955
+ null_group: [1, B, D] null group token prepended to output
956
+
957
+ Returns:
958
+ shortened_hidden: [S, B, D] shortened sequence (LBD format for
959
+ compatibility with FLEXITOKENS upsample which expects SBD input)
960
+ """
961
+ foo = _downsample_common(boundaries, upsample=False)
962
+ if foo is None:
963
+ return null_group.repeat(1, hidden.size(0), 1)
964
+ else:
965
+ bar = _downsample_final(foo, upsample=False)
966
+ # Einsum: B*L*D @ B*L*S -> B*S*D, then transpose to S*B*D
967
+ shortened_hidden = torch.einsum('bld,bls->bsd', hidden, bar)
968
+ shortened_hidden = shortened_hidden.permute(1, 0, 2)
969
+ # Prepend null_group: [1, B, D] -> cat along dim=0 -> [S+1, B, D]
970
+ shortened_hidden = torch.cat([null_group, shortened_hidden], dim=0)
971
+ return shortened_hidden
972
+
973
+
974
+ def upsample(boundaries: torch.Tensor, shortened_hidden: torch.Tensor) -> torch.Tensor:
975
+ """Upsample shortened hidden states back to original sequence length.
976
+
977
+ Per D-05: Exact einsum port from FLEXITOKENS shortening.py.
978
+ Per D-08: Batch-first layout.
979
+
980
+ Args:
981
+ boundaries: [B, L] binary boundary tensor
982
+ shortened_hidden: [S, B, D] shortened sequence
983
+
984
+ Returns:
985
+ upsampled_hidden: [B, L, D] upsampled sequence
986
+ """
987
+ foo = _downsample_common(boundaries, upsample=True)
988
+ bar = _downsample_final(foo, upsample=True)
989
+ upsampled_hidden = torch.einsum('sbd,bls->bld', shortened_hidden, bar)
990
+ return upsampled_hidden
991
+
992
+
993
+ # ============================================================================
994
+ # LTI Injection, ACT Halting, LoRA Adapter
995
+ # ============================================================================
996
+
997
+ class LTIInjection(nn.Module):
998
+ """Linear Time-Invariant injection module."""
999
+
1000
+ def __init__(self, config: SpiderConfig):
1001
+ super().__init__()
1002
+ self.hidden_size = config.hidden_size
1003
+ self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
1004
+ self.delta_t = nn.Parameter(torch.tensor(1.0))
1005
+ self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1006
+ with torch.no_grad():
1007
+ self.B.weight.data.normal_(mean=0.0, std=0.01)
1008
+
1009
+ def get_A(self):
1010
+ return -torch.exp(self.log_A)
1011
+
1012
+ def forward(self, h_t, e):
1013
+ A = self.get_A()
1014
+ return A * h_t + self.B(e)
1015
+
1016
+
1017
+ class ACTHalting(nn.Module):
1018
+ """Adaptive Computation Time halting module."""
1019
+
1020
+ def __init__(self, config: SpiderConfig):
1021
+ super().__init__()
1022
+ self.halt_predictor = nn.Linear(config.hidden_size, 1)
1023
+ self.threshold = config.act_threshold
1024
+
1025
+ def forward(self, hidden_states):
1026
+ return torch.sigmoid(self.halt_predictor(hidden_states))
1027
+
1028
+
1029
+ class LoRAAdapter(nn.Module):
1030
+ """LoRA adapter for per-loop adaptation in recurrent layers.
1031
+
1032
+ Per CR-01 fix: up-projection (self.B) is initialized to EXACTLY ZERO
1033
+ so that LoRA adapter output is zero at initialization -- meaning the
1034
+ model starts behaving identically to the base model. This follows
1035
+ standard LoRA convention (Hu et al., 2021).
1036
+ """
1037
+
1038
+ def __init__(self, config: SpiderConfig):
1039
+ super().__init__()
1040
+ rank = config.lora_rank
1041
+ self.down = nn.Linear(config.hidden_size, rank, bias=False)
1042
+ self.B = nn.Parameter(torch.zeros(rank, config.hidden_size, dtype=torch.float32)) # CR-01 fix: zeros, not randn*0.02; IN-02
1043
+ self.scale = nn.Embedding(config.max_loop_iters, rank)
1044
+ with torch.no_grad():
1045
+ self.scale.weight.data.zero_()
1046
+ self.down.weight.data.normal_(mean=0.0, std=0.001)
1047
+
1048
+ def forward(self, x, loop_t):
1049
+ max_t = self.scale.num_embeddings - 1
1050
+ t_idx = min(loop_t, max_t)
1051
+ s = self.scale(torch.tensor(t_idx, device=x.device))
1052
+ down = self.down(x) * s
1053
+ return down @ self.B
1054
+
1055
+
1056
+ def _loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
1057
+ """Sinusoidal loop index embedding for RDT depth differentiation."""
1058
+ freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
1059
+ angles = loop_t * freqs
1060
+ emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
1061
+ emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
1062
+ emb_full[:loop_dim] = emb
1063
+ return h + emb_full.unsqueeze(0).unsqueeze(0)
1064
+
1065
+
1066
+ def _checkpoint(func, *args, **kwargs):
1067
+ """Gradient checkpointing wrapper -- saves VRAM at ~20% compute cost."""
1068
+ if torch.is_grad_enabled():
1069
+ return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs)
1070
+ return func(*args, **kwargs)
1071
+
1072
+
1073
+ # ============================================================================
1074
+ # Full Spider Model (with FlexiToken integration)
1075
+ # ============================================================================
1076
+
1077
+ class SpiderModel(nn.Module):
1078
+ """Full RDT model with MLA attention + Engram memory + FlexiToken.
1079
+
1080
+ Architecture:
1081
+ 2x Prelude (MLA + dense FFN)
1082
+ 6x Recurrent (MLA + Engram@L1,L4 + MoE) -- with gradient checkpointing
1083
+ 2x Coda (MLA + dense FFN)
1084
+ LTI Injection + ACT Halting + LoRA Adapter
1085
+ BoundaryPredictor + downsample/upsample for FlexiToken
1086
+ """
1087
+
1088
+ def __init__(self, config: SpiderConfig):
1089
+ super().__init__()
1090
+ self.config = config
1091
+ self.prelude_layers = nn.ModuleList([
1092
+ SpiderDenseLayer(config) for _ in range(config.prelude_layers)
1093
+ ])
1094
+ self.recurrent_layers = nn.ModuleList([
1095
+ SpiderRecurrentLayer(config, i, has_engram=(i in config.engram_layers))
1096
+ for i in range(config.num_hidden_layers)
1097
+ ])
1098
+ self.coda_layers = nn.ModuleList([
1099
+ SpiderDenseLayer(config) for _ in range(config.coda_layers)
1100
+ ])
1101
+ self.norm = SpiderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1102
+ self.injection = LTIInjection(config)
1103
+ self.act_halting = ACTHalting(config)
1104
+ self.lora_adapter = LoRAAdapter(config)
1105
+ self.loop_embed_dim = config.loop_embed_dim
1106
+ self._gradient_checkpointing = False
1107
+
1108
+ def gradient_checkpointing_enable(self):
1109
+ self._gradient_checkpointing = True
1110
+
1111
+ def gradient_checkpointing_disable(self):
1112
+ self._gradient_checkpointing = False
1113
+
1114
+ def forward(
1115
+ self,
1116
+ hidden_states,
1117
+ input_embedding=None,
1118
+ attention_mask=None,
1119
+ position_ids=None,
1120
+ past_key_values=None,
1121
+ use_cache=False,
1122
+ n_loops=None,
1123
+ token_ids=None,
1124
+ hard_boundaries=None,
1125
+ ):
1126
+ n_loops = n_loops or 1
1127
+ input_embedding = input_embedding if input_embedding is not None else hidden_states
1128
+
1129
+ # Prelude layers
1130
+ for layer in self.prelude_layers:
1131
+ if self._gradient_checkpointing and torch.is_grad_enabled():
1132
+ hidden_states, _ = _checkpoint(
1133
+ layer, hidden_states,
1134
+ attention_mask=attention_mask,
1135
+ position_ids=position_ids,
1136
+ )
1137
+ else:
1138
+ hidden_states, _ = layer(
1139
+ hidden_states, attention_mask=attention_mask,
1140
+ position_ids=position_ids,
1141
+ )
1142
+
1143
+ # FlexiToken: if hard_boundaries provided, downsample before recurrent core
1144
+ if hard_boundaries is not None:
1145
+ # Apply norm before downsample
1146
+ hidden_normed = self.norm(hidden_states)
1147
+ null_group = torch.zeros(
1148
+ 1, hidden_states.shape[0], hidden_states.shape[-1],
1149
+ device=hidden_states.device, dtype=hidden_states.dtype,
1150
+ )
1151
+ shortened = downsample(hard_boundaries, hidden_normed, null_group)
1152
+ # shortened: [S, B, D] -> [B, S, D]
1153
+ hidden_states = shortened.permute(1, 0, 2)
1154
+
1155
+ # Shorten token_ids to match downsampled sequence length.
1156
+ # Take the first token in each boundary group so the Engram
1157
+ # hash-based lookup gets a representative token per group.
1158
+ # hard_boundaries: [B, L], cumsum gives group index per position.
1159
+ # Pick the first position (where boundary=1) of each group.
1160
+ if token_ids is not None:
1161
+ group_ids = hard_boundaries.cumsum(dim=-1) # [B, L], 1-based group indices
1162
+ n_groups = int(group_ids.max().item()) # number of groups
1163
+ B = hard_boundaries.shape[0]
1164
+ # For each group g (1..n_groups), find the first position where group_ids == g
1165
+ short_ids = torch.zeros(B, n_groups, device=token_ids.device, dtype=token_ids.dtype)
1166
+ for g in range(1, n_groups + 1):
1167
+ # mask of positions belonging to group g
1168
+ mask = (group_ids == g)
1169
+ # first position in group g
1170
+ first_pos = mask.float().argmax(dim=-1) # [B]
1171
+ short_ids[:, g - 1] = token_ids.gather(1, first_pos.unsqueeze(1)).squeeze(1)
1172
+ # Prepend a dummy token (0) for the null_group entry
1173
+ null_token = torch.zeros(B, 1, device=token_ids.device, dtype=token_ids.dtype)
1174
+ token_ids = torch.cat([null_token, short_ids], dim=1) # [B, S+1]
1175
+
1176
+ # After downsample, input_embedding must match the shortened sequence length
1177
+ input_embedding = hidden_states.clone()
1178
+
1179
+ # Recurrent core with RDT looping
1180
+ e = hidden_states.clone()
1181
+ B, T_seq, D = hidden_states.shape
1182
+ halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
1183
+ cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
1184
+ h_out = torch.zeros_like(hidden_states)
1185
+ total_aux_loss = 0.0
1186
+ past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
1187
+
1188
+ for t in range(n_loops):
1189
+ h_loop = _loop_index_embedding(hidden_states, t, self.loop_embed_dim)
1190
+ if t > 0:
1191
+ injection = self.injection(hidden_states, input_embedding)
1192
+ hidden_states = hidden_states + injection
1193
+
1194
+ new_past_key_values = []
1195
+ for i, layer in enumerate(self.recurrent_layers):
1196
+ hidden_states, aux_loss, past_kv = _checkpoint(
1197
+ layer, hidden_states,
1198
+ token_ids=token_ids,
1199
+ attention_mask=attention_mask,
1200
+ position_ids=position_ids,
1201
+ past_key_value=past_key_values[i] if t == 0 else None,
1202
+ use_cache=use_cache,
1203
+ )
1204
+ total_aux_loss = total_aux_loss + aux_loss
1205
+ new_past_key_values.append(past_kv)
1206
+
1207
+ lora_delta = self.lora_adapter(hidden_states, t)
1208
+ hidden_states = hidden_states + lora_delta
1209
+
1210
+ halt_prob = self.act_halting(hidden_states).squeeze(-1)
1211
+ still_running = ~halted
1212
+ remainder = (1.0 - cumulative_p).clamp(min=0)
1213
+ weight = torch.where(
1214
+ cumulative_p + halt_prob >= self.config.act_threshold,
1215
+ remainder, halt_prob,
1216
+ )
1217
+ weight = weight * still_running.to(hidden_states.dtype)
1218
+ h_out = h_out + weight.unsqueeze(-1) * hidden_states
1219
+ cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
1220
+ halted = halted | (cumulative_p >= self.config.act_threshold)
1221
+ if halted.all() and not self.training:
1222
+ break
1223
+
1224
+ never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
1225
+ hidden_states = h_out + never_halted * hidden_states
1226
+
1227
+ # FlexiToken: if hard_boundaries provided, upsample after recurrent core
1228
+ if hard_boundaries is not None:
1229
+ hidden_states_sbd = hidden_states.permute(1, 0, 2) # [S, B, D]
1230
+ hidden_states = upsample(hard_boundaries, hidden_states_sbd) # [B, L, D]
1231
+
1232
+ # Coda layers
1233
+ for layer in self.coda_layers:
1234
+ if self._gradient_checkpointing and torch.is_grad_enabled():
1235
+ hidden_states, _ = _checkpoint(
1236
+ layer, hidden_states,
1237
+ attention_mask=attention_mask,
1238
+ position_ids=position_ids,
1239
+ )
1240
+ else:
1241
+ hidden_states, _ = layer(
1242
+ hidden_states, attention_mask=attention_mask,
1243
+ position_ids=position_ids,
1244
+ )
1245
+
1246
+ hidden_states = self.norm(hidden_states)
1247
+ return hidden_states, total_aux_loss, new_past_key_values
1248
+
1249
+
1250
+ # ============================================================================
1251
+ # SpiderForConditionalGeneration
1252
+ # ============================================================================
1253
+
1254
+ class SpiderForConditionalGeneration(nn.Module):
1255
+ """Spider model with embedding, LM head, and FlexiToken boundary prediction.
1256
+
1257
+ Forward flow:
1258
+ 1. embed_tokens(input_ids) -> hidden_states
1259
+ 2. Inject modality features at sentinel positions
1260
+ 3. Prelude layers
1261
+ 4. BoundaryPredictor with modality_mask -> boundaries
1262
+ 5. SpiderModel (downsample -> recurrent -> upsample -> coda)
1263
+ 6. lm_head -> logits
1264
+ """
1265
+
1266
+ def __init__(self, config: SpiderConfig):
1267
+ super().__init__()
1268
+ self.config = config
1269
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
1270
+ self.boundary_predictor = BoundaryPredictor(config)
1271
+ self.model = SpiderModel(config)
1272
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1273
+ if config.tie_word_embeddings:
1274
+ self.lm_head.weight = self.embed_tokens.weight
1275
+ self.apply(self._init_weights)
1276
+
1277
+ def gradient_checkpointing_enable(self):
1278
+ self.model.gradient_checkpointing_enable()
1279
+
1280
+ def gradient_checkpointing_disable(self):
1281
+ self.model.gradient_checkpointing_disable()
1282
+
1283
+ def enable_input_require_grads(self):
1284
+ def _make_inputs_require_grad(module, input, output):
1285
+ output.requires_grad_(True)
1286
+ self.embed_tokens.register_forward_hook(_make_inputs_require_grad)
1287
+
1288
+ def _init_weights(self, module):
1289
+ if isinstance(module, nn.Linear):
1290
+ if hasattr(self, 'model') and module is self.model.injection.B:
1291
+ return # LTI injection B has its own init
1292
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1293
+ if module.bias is not None:
1294
+ module.bias.data.zero_()
1295
+ elif isinstance(module, nn.Embedding):
1296
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1297
+
1298
+ def _inject_modality_features(
1299
+ self,
1300
+ hidden_states: torch.Tensor,
1301
+ input_ids: torch.Tensor,
1302
+ features: list,
1303
+ modality: str = 'IMG',
1304
+ ) -> torch.Tensor:
1305
+ """Replace placeholder embeddings with actual encoder features at modality regions.
1306
+
1307
+ Per D-11: Modality tokens (vision, audio, video) are injected at
1308
+ sentinel-marked positions. Between sentinel pairs, the initial
1309
+ embeddings are placeholders -- this method replaces them with the
1310
+ actual encoder features.
1311
+
1312
+ T-02-06 mitigation: Validates feature shape and sentinel pair count.
1313
+ """
1314
+ start_token = SENTINEL_TOKENS[f'{modality}_START']
1315
+ end_token = SENTINEL_TOKENS[f'{modality}_END']
1316
+
1317
+ for b in range(hidden_states.shape[0]):
1318
+ starts = (input_ids[b] == start_token).nonzero(as_tuple=True)[0]
1319
+ ends = (input_ids[b] == end_token).nonzero(as_tuple=True)[0]
1320
+
1321
+ if len(starts) != len(ends):
1322
+ raise ValueError(
1323
+ f"Batch {b}: mismatched {modality} sentinel pairs -- "
1324
+ f"{len(starts)} {_TOKEN_NAMES_BY_ID[start_token]}(s) vs "
1325
+ f"{len(ends)} {_TOKEN_NAMES_BY_ID[end_token]}(s)."
1326
+ )
1327
+ if len(starts) != len(features):
1328
+ raise ValueError(
1329
+ f"Batch {b}: {modality} sentinel pair count ({len(starts)}) "
1330
+ f"doesn't match feature count ({len(features)})."
1331
+ )
1332
+
1333
+ for s, e, feat in zip(starts, ends, features):
1334
+ num_tokens = e - s - 1
1335
+ if feat.shape[0] != num_tokens:
1336
+ raise ValueError(
1337
+ f"Batch {b}: {modality} feature has {feat.shape[0]} tokens "
1338
+ f"but sentinel region has {num_tokens} positions "
1339
+ f"(from pos {s+1} to {e-1})."
1340
+ )
1341
+ if feat.shape[1] != hidden_states.shape[-1]:
1342
+ raise ValueError(
1343
+ f"Batch {b}: {modality} feature hidden_size {feat.shape[1]} "
1344
+ f"doesn't match model hidden_size {hidden_states.shape[-1]}."
1345
+ )
1346
+ hidden_states[b, s + 1:e] = feat.to(hidden_states.dtype)
1347
+
1348
+ return hidden_states
1349
+
1350
+ def forward(
1351
+ self,
1352
+ input_ids: torch.Tensor,
1353
+ attention_mask=None,
1354
+ position_ids=None,
1355
+ labels=None,
1356
+ n_loops=None,
1357
+ use_cache=False,
1358
+ vision_features=None,
1359
+ audio_features=None,
1360
+ video_features=None,
1361
+ **kwargs,
1362
+ ):
1363
+ hidden_states = self.embed_tokens(input_ids)
1364
+ model_dtype = next(self.model.parameters()).dtype
1365
+ hidden_states = hidden_states.to(model_dtype)
1366
+ input_embedding = hidden_states.clone()
1367
+
1368
+ # Inject modality features at sentinel positions
1369
+ if vision_features is not None:
1370
+ hidden_states = self._inject_modality_features(
1371
+ hidden_states, input_ids, vision_features, 'IMG'
1372
+ )
1373
+ if audio_features is not None:
1374
+ hidden_states = self._inject_modality_features(
1375
+ hidden_states, input_ids, audio_features, 'AUD'
1376
+ )
1377
+ if video_features is not None:
1378
+ hidden_states = self._inject_modality_features(
1379
+ hidden_states, input_ids, video_features, 'VID'
1380
+ )
1381
+
1382
+ # Create modality mask and predict boundaries
1383
+ modality_mask = create_modality_mask(input_ids, strict=(labels is not None))
1384
+ soft_boundaries, hard_boundaries = self.boundary_predictor(
1385
+ hidden_states, modality_mask=modality_mask
1386
+ )
1387
+
1388
+ # Run model with FlexiToken boundaries
1389
+ hidden_states, aux_loss, past_kv = self.model(
1390
+ hidden_states,
1391
+ input_embedding=input_embedding,
1392
+ attention_mask=None,
1393
+ position_ids=position_ids,
1394
+ use_cache=use_cache,
1395
+ n_loops=n_loops,
1396
+ token_ids=input_ids,
1397
+ hard_boundaries=hard_boundaries,
1398
+ )
1399
+
1400
+ logits = self.lm_head(hidden_states)
1401
+ loss = None
1402
+ if labels is not None:
1403
+ shift_logits = logits[..., :-1, :].contiguous()
1404
+ shift_labels = labels[..., 1:].contiguous()
1405
+ loss_fct = CrossEntropyLoss()
1406
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1407
+
1408
+ return {
1409
+ "loss": loss,
1410
+ "logits": logits,
1411
+ "aux_loss": aux_loss,
1412
+ "past_key_values": past_kv,
1413
+ "soft_boundaries": soft_boundaries,
1414
+ "hard_boundaries": hard_boundaries,
1415
+ }
1416
+
1417
+ @torch.inference_mode()
1418
+ def generate(
1419
+ self,
1420
+ input_ids: torch.Tensor,
1421
+ max_new_tokens: int = 100,
1422
+ temperature: float = 1.0,
1423
+ top_k: Optional[int] = None,
1424
+ n_loops: int = 1,
1425
+ use_cache: bool = True,
1426
+ boundary_mode: str = 'adaptive',
1427
+ ) -> torch.Tensor:
1428
+ """Token-level generation with compressed-prefix KV cache per D-28.
1429
+
1430
+ Strategy: Encode the prefix through prelude + BP + downsample to get
1431
+ a compressed KV cache, then autoregressively decode byte-by-byte using
1432
+ that cached prefix. The speedup comes from the prefix being shorter in
1433
+ the KV cache (~3.3x fewer entries for English text).
1434
+
1435
+ Flow:
1436
+ 1. Embed prefix β†’ prelude layers β†’ BP β†’ downsample β†’ recurrent core
1437
+ β†’ collect KV cache for compressed prefix
1438
+ 2. Coda + lm_head on last position β†’ sample first new byte
1439
+ 3. For each subsequent byte: embed β†’ recurrent (with KV cache) β†’ coda
1440
+ β†’ lm_head β†’ sample β†’ append
1441
+ 4. Stop at max_new_tokens or EOS
1442
+
1443
+ Args:
1444
+ input_ids: Prefix token IDs [B, L] (byte values 0-255 + BOS/EOS)
1445
+ max_new_tokens: Maximum number of new bytes to generate
1446
+ temperature: Sampling temperature (0 = greedy, 1.0 = default)
1447
+ top_k: If set, only sample from top-k logits
1448
+ n_loops: Number of recurrent loops during generation
1449
+ use_cache: Use KV cache for incremental decoding
1450
+ boundary_mode: 'adaptive' (threshold) or 'fixed' (top-k) for BP
1451
+
1452
+ Returns:
1453
+ Generated token IDs [B, N] where N ≀ max_new_tokens
1454
+ """
1455
+ B = input_ids.shape[0]
1456
+ device = input_ids.device
1457
+ model_dtype = next(self.model.parameters()).dtype
1458
+
1459
+ # --- Step 1: Encode prefix and collect KV cache ---
1460
+ hidden_states = self.embed_tokens(input_ids).to(model_dtype)
1461
+
1462
+ # Prelude layers (byte-level, no compression)
1463
+ for layer in self.model.prelude_layers:
1464
+ hidden_states, _ = layer(hidden_states)
1465
+
1466
+ # Boundary prediction on prefix (strict=False for generation)
1467
+ modality_mask = create_modality_mask(input_ids, strict=False)
1468
+ soft_boundaries, hard_boundaries = self.boundary_predictor(
1469
+ hidden_states, modality_mask=modality_mask
1470
+ )
1471
+
1472
+ # Apply boundary mode
1473
+ if boundary_mode == 'adaptive':
1474
+ hard_boundaries = (soft_boundaries > 0.5).float()
1475
+ hard_boundaries = hard_boundaries - soft_boundaries.detach() + soft_boundaries
1476
+ elif boundary_mode == 'fixed':
1477
+ k = max(1, int(soft_boundaries.shape[-1] / 3.3))
1478
+ topk_vals, topk_idx = soft_boundaries.topk(k, dim=-1)
1479
+ hard_boundaries = torch.zeros_like(soft_boundaries)
1480
+ hard_boundaries.scatter_(-1, topk_idx, 1.0)
1481
+ hard_boundaries = hard_boundaries - soft_boundaries.detach() + soft_boundaries
1482
+
1483
+ # Downsample prefix for compressed KV cache
1484
+ hidden_normed = self.model.norm(hidden_states)
1485
+ null_group = torch.zeros(
1486
+ 1, B, hidden_states.shape[-1], device=device, dtype=hidden_states.dtype
1487
+ )
1488
+ shortened = downsample(hard_boundaries, hidden_normed, null_group)
1489
+ hidden_states = shortened.permute(1, 0, 2) # [B, S, D]
1490
+ input_embedding = hidden_states.clone()
1491
+
1492
+ # Run through recurrent core + coda (hard_boundaries=None skips downsample/upsample)
1493
+ hidden_states, _, past_key_values = self.model(
1494
+ hidden_states,
1495
+ input_embedding=input_embedding,
1496
+ use_cache=use_cache,
1497
+ n_loops=n_loops,
1498
+ hard_boundaries=None,
1499
+ )
1500
+
1501
+ # Get logits for last position of prefix (norm + lm_head only, coda already applied)
1502
+ logits = self.lm_head(hidden_states[:, -1:, :]) # [B, 1, vocab]
1503
+ next_token = self._sample_token(logits, temperature, top_k) # [B, 1]
1504
+
1505
+ generated = [next_token]
1506
+
1507
+ # --- Step 2: Autoregressive byte-level decoding with KV cache ---
1508
+ for _ in range(max_new_tokens - 1):
1509
+ # Check EOS
1510
+ if (next_token == SENTINEL_TOKENS['EOS']).all():
1511
+ break
1512
+
1513
+ # Embed the last generated token
1514
+ hidden_states = self.embed_tokens(next_token).to(model_dtype) # [B, 1, D]
1515
+ input_embedding = hidden_states.clone()
1516
+
1517
+ if use_cache:
1518
+ # Incremental forward: 1 new token, cached prefix in past_key_values
1519
+ hidden_states, _, past_key_values = self.model(
1520
+ hidden_states,
1521
+ input_embedding=input_embedding,
1522
+ past_key_values=past_key_values,
1523
+ use_cache=True,
1524
+ n_loops=n_loops,
1525
+ hard_boundaries=None,
1526
+ )
1527
+ else:
1528
+ # Naive: re-run full forward from scratch (no KV cache)
1529
+ all_ids = torch.cat([input_ids, torch.cat(generated, dim=1)], dim=1)
1530
+ output = self.forward(
1531
+ all_ids, n_loops=n_loops, use_cache=False,
1532
+ )
1533
+ logits_full = output['logits']
1534
+ next_logits = logits_full[:, -1, :] / max(temperature, 1e-8)
1535
+ if top_k is not None and top_k > 0:
1536
+ v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
1537
+ next_logits = next_logits.masked_fill(next_logits < v[:, [-1]], float('-inf'))
1538
+ if temperature < 1e-8:
1539
+ next_token = next_logits.argmax(dim=-1, keepdim=True)
1540
+ else:
1541
+ probs = torch.softmax(next_logits, dim=-1)
1542
+ next_token = torch.multinomial(probs, num_samples=1)
1543
+ generated.append(next_token)
1544
+ continue
1545
+
1546
+ # lm_head on last position (coda + norm already applied by self.model)
1547
+ logits = self.lm_head(hidden_states[:, -1:, :]) # [B, 1, vocab]
1548
+ next_token = self._sample_token(logits, temperature, top_k)
1549
+ generated.append(next_token)
1550
+
1551
+ return torch.cat(generated, dim=1) # [B, N]
1552
+
1553
+ @staticmethod
1554
+ def _sample_token(logits: torch.Tensor, temperature: float, top_k: Optional[int]) -> torch.Tensor:
1555
+ """Sample next token from logits with temperature and top-k."""
1556
+ logits = logits.squeeze(1) # [B, vocab]
1557
+ if temperature < 1e-8:
1558
+ return logits.argmax(dim=-1, keepdim=True) # greedy
1559
+ logits = logits / temperature
1560
+ if top_k is not None and top_k > 0:
1561
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
1562
+ logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
1563
+ probs = torch.softmax(logits, dim=-1)
1564
+ return torch.multinomial(probs, num_samples=1) # [B, 1]
1565
+
1566
+ def get_num_params(self):
1567
+ total = sum(p.numel() for p in self.parameters())
1568
+ return {"total": total, "trainable": total}