#!/usr/bin/env python3
"""
SpiderPortal v5-Dense: English pretraining on FineWeb-Edu with AdamW.

Architecture: RDT (2 prelude + 6 recurrent + 2 coda) with:
  - MLA (Multi-Latent Attention): 10.7x KV cache compression + sliding window
  - Engram conditional memory at recurrent layers 1 and 4
  - Dense FFN (all params active, MoE conversion in Phase 2)
  - LTI Injection + ACT Halting + LoRA Adapter
  - 32k context (extendable to 256k at inference via YaRN)

Config: hidden_size=2048, 6 recurrent layers, 32 experts (Phase 2), top-2 routing

Single GPU: python mythos-fineweb-dense.py
Multi-GPU:  torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") mythos-fineweb-dense.py
"""
import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import sys


# Simple print-based logging — no file rotation, no hanging
def log(msg, level="INFO"):
    """Print ``msg`` as ``"<timestamp> | <level> | <msg>"`` and flush stdout immediately."""
    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"{ts} | {level} | {msg}", flush=True)


# Speed up CUDA memory allocation. Per the PyTorch CUDA-semantics docs the caching
# allocator reads this variable when it is first initialised, so it has to be set
# before the first CUDA allocation (importing torch alone does not initialise it).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512,expandable_segments:True"

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Optional, Tuple, Dict, List
from torch.nn import CrossEntropyLoss
from datasets import load_dataset
from transformers import AutoTokenizer


# ---------------------------------------------------------------------------
# SpiderPortal Model Architecture (Dense + MLA + Engram)
# ---------------------------------------------------------------------------

@dataclass
class SpiderPortalConfig:
    """Hyper-parameters for the SpiderPortal dense model.

    Several fields (``num_experts*``, ``router_aux_loss_coef``, ``vision_*``,
    ``audio_hidden_size``) are not read by the dense model in the visible code;
    presumably they are kept for the Phase-2 MoE / multimodal variants.
    """

    vocab_size: int = 50257
    hidden_size: int = 2048
    num_hidden_layers: int = 6
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    intermediate_size: int = 8192
    hidden_act: str = "silu"
    num_experts: int = 32
    num_experts_per_tok: int = 2
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05
    max_loop_iters: int = 4
    act_threshold: float = 0.5
    max_position_embeddings: int = 32768
    rope_theta: float = 10000000.0
    rope_scaling: Optional[dict] = None  # e.g. {"type": "yarn", "factor": ..., ...}
    sliding_window: int = 4096
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    use_cache: bool = True
    tie_word_embeddings: bool = True
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 128
    loop_embed_dim: int = 128
    vision_hidden_size: int = 2048
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"
    # MLA parameters (DeepSeek-V2 style, scaled for hidden_size=2048)
    kv_lora_rank: int = 128
    q_lora_rank: int = 256
    qk_rope_head_dim: int = 64
    qk_nope_head_dim: int = 64
    v_head_dim: int = 64
    # Engram parameters (DeepSeek conditional memory)
    engram_layers: List[int] = field(default_factory=lambda: [1, 4])
    engram_ngram_orders: Tuple[int, ...] = (2, 3)
    engram_hash_heads: int = 4
    engram_table_size: int = 65537  # prime number for hash table
    engram_conv_kernel: int = 4
    engram_conv_dilation: int = 3
    engram_dim: int = 128  # per-head embedding dimension


def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of the recurrence index ``loop_t`` to ``h``.

    The encoding ``[sin(t * f), cos(t * f)]`` fills the first ``loop_dim`` channels of
    the last dimension; the remaining channels receive zero. It is broadcast over two
    leading dimensions, so ``h`` is presumably (batch, seq, hidden).

    The frequencies are computed in float32 and only the final vector is cast to
    ``h.dtype``, so a bf16 ``h`` does not degrade the angles.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32) / loop_dim))
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    emb_full[:loop_dim] = emb.to(h.dtype)
    return h + emb_full.unsqueeze(0).unsqueeze(0)


class SpiderPortalRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable per-channel scale.

    The statistics are computed in float32; the result is returned in the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(input_dtype) * hidden_states.to(input_dtype)


def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return RoPE inverse frequencies rescaled for context extension by ``factor``.

    Each frequency is multiplied by a ramp between 1 (``beta < beta_slow``) and
    ``1 / factor`` (``beta > beta_fast``).

    NOTE(review): ``beta`` here is ``log(1 / inv_freq) / log(orig_max)``. This is not
    the ratio of context length to wavelength used by the reference YaRN formulation,
    and no attention-temperature correction is applied. With the default betas the
    ``1 / factor`` end of the ramp is never reached for rope_theta=1e7, orig_max=32768.
    Verify before enabling ``rope_scaling``.
    """
    dim = head_dim
    orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
    pos_freqs = torch.arange(0, dim, 2).float() / dim
    beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
    scale = torch.where(
        beta < beta_slow,
        torch.ones_like(beta),
        torch.where(
            beta > beta_fast,
            torch.ones_like(beta) / factor,
            1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor),
        ),
    )
    return orig_inv_freq * scale


# ---------------------------------------------------------------------------
# MLA: Multi-Latent Attention (DeepSeek-V2 style) + Sliding Window
# ---------------------------------------------------------------------------

class SpiderPortalMLA(nn.Module):
    """Multi-Latent Attention with compressed KV cache and sliding window.

    For hidden_size=2048, num_heads=16:
      - qk_nope_head_dim=64, qk_rope_head_dim=64 → total head_dim=128
      - kv_lora_rank=128 → 10.7x compression vs full 2048-dim KV
      - v_head_dim=64 → value projection
      - sliding_window=4096 → local attention range

    NOTE(review): the cache returned by ``forward`` holds the decompressed per-head
    ``(k_full, v)``, not the latent, so the compression above does not apply to the
    cache as currently implemented.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim
        self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.sliding_window = getattr(config, 'sliding_window', None)

        # Q projection: optional low-rank → full Q
        if self.q_lora_rank > 0:
            self.q_a_proj = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False)
            self.q_a_layernorm = SpiderPortalRMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # KV compression: hidden → kv_lora_rank (shared latent) + one RoPE key shared by all heads
        self.kv_a_proj_with_mqa = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_a_layernorm = SpiderPortalRMSNorm(self.kv_lora_rank)
        # Decompress: kv_lora_rank → per-head nope keys + values
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )
        # Output projection
        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False)

        # RoPE frequencies
        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.qk_rope_head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.qk_rope_head_dim, 2).float() / self.qk_rope_head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        """Map the halves ``(x1, x2)`` of the last dim to ``(-x2, x1)``."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        """Apply the rotary embedding given precomputed ``cos`` / ``sin`` tables."""
        return (x * cos) + (self._rotate_half(x) * sin)

    def _make_sliding_window_mask(self, q_len, kv_len, device, dtype):
        """Return an additive causal sliding-window mask of shape (q_len, kv_len), or None.

        The queries are taken to be the last ``q_len`` of the ``kv_len`` key positions
        (cached keys come first). Query ``i`` therefore sits at absolute position
        ``kv_len - q_len + i``. It may see keys that are not in its future and are at
        most ``sliding_window - 1`` positions behind it. Visible entries are 0 and
        hidden entries are ``finfo(dtype).min``.

        Returns None when the window is disabled (None or <= 0).
        """
        if self.sliding_window is None or self.sliding_window <= 0:
            return None
        q_pos = torch.arange(q_len, device=device) + (kv_len - q_len)
        k_pos = torch.arange(kv_len, device=device)
        distance = q_pos.unsqueeze(1) - k_pos.unsqueeze(0)  # (q_len, kv_len)
        visible = (distance >= 0) & (distance < self.sliding_window)
        mask = torch.zeros((q_len, kv_len), device=device, dtype=dtype)
        return mask.masked_fill(~visible, torch.finfo(dtype).min)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Run attention over ``hidden_states`` (presumably (bsz, q_len, hidden_size)).

        Args:
            attention_mask: optional additive float mask that broadcasts against
                (bsz, num_heads, q_len, kv_len). It is summed with the sliding-window mask.
            position_ids: optional (bsz, q_len) absolute positions for RoPE. They default
                to ``past_len .. past_len + q_len - 1``, where ``past_len`` is the cached
                key length (0 without a cache).
            past_key_value: optional ``(k, v)`` from a previous call, each shaped
                (bsz, num_heads, past_len, dim).
            use_cache: if True, also return the concatenated ``(k, v)``.

        Returns:
            ``(output, past_kv)``. ``output`` is (bsz, q_len, hidden_size). ``past_kv``
            is None unless ``use_cache`` is set.
        """
        bsz, q_len, _ = hidden_states.size()
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0

        # Q projection (optionally through a low-rank bottleneck)
        if self.q_lora_rank > 0:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        else:
            q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # KV: compress to the shared latent, then decompress into per-head keys/values
        kv_hidden = self.kv_a_proj_with_mqa(hidden_states)
        kv_latent, k_rope = torch.split(kv_hidden, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_b_out = self.kv_b_proj(self.kv_a_layernorm(kv_latent))
        k_nope, v = torch.split(
            kv_b_out,
            [self.num_heads * self.qk_nope_head_dim, self.num_heads * self.v_head_dim],
            dim=-1,
        )
        k_nope = k_nope.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_rope = k_rope.unsqueeze(1)  # one RoPE key shared across heads

        # RoPE tables straight from the positions, in float32. This avoids a .item()
        # host sync and keeps positions exact even when FSDP casts the inv_freq buffer
        # to bf16 (bf16 cannot represent integers above 256 exactly).
        if position_ids is None:
            position_ids = torch.arange(
                past_len, past_len + q_len, device=hidden_states.device
            ).unsqueeze(0).expand(bsz, -1)
        freqs = position_ids.to(torch.float32).unsqueeze(-1) * self.inv_freq.to(torch.float32)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().unsqueeze(1)  # (bsz, 1, q_len, rope_dim)
        sin = emb.sin().unsqueeze(1)

        # Rotate in float32, then return to the projection dtype so Q/K/V dtypes agree.
        q_rope = self._apply_rotary(q_rope.float(), cos, sin).to(q_nope.dtype)
        k_rope = self._apply_rotary(k_rope.float(), cos, sin).to(k_nope.dtype)

        # Assemble the full per-head Q and K
        k_full = torch.cat([k_nope, k_rope.expand(-1, self.num_heads, -1, -1)], dim=-1)
        q_full = torch.cat([q_nope, q_rope], dim=-1)

        # KV cache
        if past_key_value is not None:
            k_full = torch.cat([past_key_value[0], k_full], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_kv = (k_full, v) if use_cache else None

        # Build the attention mask: user mask + sliding window.
        final_mask = attention_mask
        sw_mask = self._make_sliding_window_mask(q_len, k_full.size(2), hidden_states.device, hidden_states.dtype)
        if sw_mask is not None:
            if final_mask is None:
                final_mask = sw_mask
            else:
                # min + min overflows to -inf. Clamp so that a fully masked row cannot
                # turn into NaN inside the softmax.
                combined = final_mask + sw_mask
                final_mask = combined.clamp(min=torch.finfo(combined.dtype).min)

        # is_causal is only used when there is no mask at all. SDPA aligns it top-left,
        # which is only right when kv_len == q_len.
        attn_output = F.scaled_dot_product_attention(
            q_full, k_full, v,
            attn_mask=final_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=(final_mask is None),
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        return self.o_proj(attn_output), past_kv


# ---------------------------------------------------------------------------
# Engram: Conditional Memory via Scalable Lookup (DeepSeek style)
# ---------------------------------------------------------------------------

def _tokenizer_compress(token_ids, vocab_size=50257):
    """Simulate NFKC + lowercase canonical ID projection.

    This is only an approximation: ids are folded modulo ``77%`` of ``vocab_size``.
    """
    return token_ids % (vocab_size * 77 // 100)


class SpiderPortalEngram(nn.Module):
    """Conditional memory module via hashed N-gram lookup.

    Applied only at specific recurrent layers (config.engram_layers). Each
    (n-gram order, hash head) pair owns a learnable table of shape
    (engram_table_size, engram_dim). Rows are selected by hashing the trailing
    n-gram of compressed token ids at every position.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.ngram_orders = config.engram_ngram_orders
        self.num_heads = config.engram_hash_heads
        self.table_size = config.engram_table_size
        self.d_mem = config.engram_dim
        self.total_mem_dim = len(self.ngram_orders) * self.num_heads * self.d_mem

        self.embed_tables = nn.ParameterDict()
        for n in self.ngram_orders:
            for h in range(self.num_heads):
                self.embed_tables[f"e_{n}_{h}"] = nn.Parameter(
                    torch.randn(self.table_size, self.d_mem) * 0.02
                )

        # One multiplicative seed per (order, head), flattened in the same order as
        # embed_tables. The same seed values are reused for every order.
        self.register_buffer("hash_seeds", torch.tensor([
            (h + 1) * 2654435761 for _ in self.ngram_orders for h in range(self.num_heads)
        ], dtype=torch.int64))

        self.W_k = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)
        self.W_v = nn.Linear(self.total_mem_dim, config.hidden_size, bias=False)

        # Depthwise conv over the sequence. Padding by (kernel - 1) * dilation and then
        # keeping only the first T outputs (as forward does) makes it causal.
        # engram_conv_dilation was previously stored but never passed to Conv1d.
        # Weight shapes are unchanged, so old state dicts still load.
        kernel = config.engram_conv_kernel
        self.conv_dilation = config.engram_conv_dilation
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel,
            dilation=self.conv_dilation,
            padding=(kernel - 1) * self.conv_dilation,
            groups=config.hidden_size,
        )
        # Zero init: the conv branch contributes nothing at the start of training.
        with torch.no_grad():
            self.conv.weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()

        self.q_norm = SpiderPortalRMSNorm(config.hidden_size)
        self.k_norm = SpiderPortalRMSNorm(config.hidden_size)

    def _compute_indices(self, compressed_ids, n, head_idx):
        """Return int64 table indices (bsz, seq_len) for n-gram order ``n`` and flat head ``head_idx``.

        Positions before the start of the sequence are padded with id 0. A base
        polynomial hash (base 31, reduced mod ``table_size`` at each step) is
        multiplied by the head's seed, again mod ``table_size``.

        NOTE(review): the base hash is already reduced mod the prime ``table_size``
        before the seed is applied. Multiplying by a seed is then a bijection mod p,
        so two n-grams that collide in one head collide in every head of that order.
        Changing this would remap the tables of existing checkpoints.
        """
        bsz, seq_len = compressed_ids.shape
        padded = torch.cat([compressed_ids.new_zeros(bsz, n - 1), compressed_ids], dim=1)

        h_val = torch.zeros(bsz, seq_len, dtype=torch.int64, device=compressed_ids.device)
        for i in range(n):
            h_val = (h_val * 31 + padded[:, i:i + seq_len]) % self.table_size

        # Index the seed buffer as a tensor instead of calling .item(): no host sync.
        # h_val < table_size, so the product fits comfortably in int64.
        return (h_val * self.hash_seeds[head_idx]) % self.table_size
def _retrieve(self, token_ids): """Retrieve memory vectors for a batch of token sequences.""" bsz, seq_len = token_ids.shape compressed = _tokenizer_compress(token_ids) all_parts = [] head_counter = 0 for n in self.ngram_orders: for h in range(self.num_heads): key = f"e_{n}_{h}" table = self.embed_tables[key] indices = self._compute_indices(compressed, n, head_counter) emb = table[indices.view(-1)] all_parts.append(emb.view(bsz, seq_len, self.d_mem)) head_counter += 1 memory = torch.cat(all_parts, dim=-1) return memory def forward(self, hidden_states, token_ids): mem = self._retrieve(token_ids) q = hidden_states k = self.W_k(mem) v = self.W_v(mem) q_norm = self.q_norm(q) k_norm = self.k_norm(k) alpha = torch.sigmoid( (q_norm * k_norm).sum(dim=-1, keepdim=True) / math.sqrt(q.shape[-1]) ) v_gated = alpha * v v_gated_t = v_gated.transpose(1, 2) conv_out = self.conv(v_gated_t) conv_out = conv_out[:, :, :v_gated_t.shape[-1]] conv_out = conv_out.transpose(1, 2) y = F.silu(conv_out) + v_gated return y # --------------------------------------------------------------------------- # FFN Expert (dense) # --------------------------------------------------------------------------- class SpiderPortalExpert(nn.Module): def __init__(self, config, intermediate_size=None): super().__init__() inter_size = intermediate_size or config.intermediate_size self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False) self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False) self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, hidden_states): return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) # --------------------------------------------------------------------------- # Prelude/Coda Dense Layer (uses MLA) # --------------------------------------------------------------------------- class SpiderPortalDenseLayer(nn.Module): """Prelude/coda dense layer with MLA 
attention.""" def __init__(self, config): super().__init__() self.self_attn = SpiderPortalMLA(config) dense_intermediate = config.hidden_size * 4 // 3 self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate) self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False): attn_input = self.input_layernorm(hidden_states) attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache) hidden_states = hidden_states + attn_output ffn_input = self.post_attention_layernorm(hidden_states) ffn_output = self.ffn(ffn_input) hidden_states = hidden_states + ffn_output return hidden_states, past_kv # --------------------------------------------------------------------------- # Recurrent Dense Layer (uses MLA + optional Engram) # --------------------------------------------------------------------------- class SpiderPortalRecurrentDenseLayer(nn.Module): """Recurrent layer with MLA attention and optional Engram memory.""" def __init__(self, config, layer_idx, has_engram=False): super().__init__() self.layer_idx = layer_idx self.has_engram = has_engram self.self_attn = SpiderPortalMLA(config) if has_engram: self.engram = SpiderPortalEngram(config) self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size) self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_engram_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if has_engram else None def forward(self, hidden_states, token_ids=None, attention_mask=None, position_ids=None, past_key_value=None, 
use_cache=False): attn_input = self.input_layernorm(hidden_states) attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache) hidden_states = hidden_states + attn_output if self.has_engram and token_ids is not None: engram_out = self.engram(hidden_states, token_ids) hidden_states = hidden_states + engram_out if self.post_engram_layernorm is not None: hidden_states = self.post_engram_layernorm(hidden_states) ffn_input = self.post_attention_layernorm(hidden_states) ffn_output = self.ffn(ffn_input) hidden_states = hidden_states + ffn_output return hidden_states, 0.0, past_kv # --------------------------------------------------------------------------- # LTI Injection, ACT Halting, LoRA Adapter # --------------------------------------------------------------------------- class LTIInjection(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0)) self.delta_t = nn.Parameter(torch.tensor(1.0)) self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False) with torch.no_grad(): self.B.weight.data.normal_(mean=0.0, std=0.01) def get_A(self): return -torch.exp(self.log_A) def forward(self, h_t, e): A = self.get_A() return A * h_t + self.B(e) class ACTHalting(nn.Module): def __init__(self, config): super().__init__() self.halt_predictor = nn.Linear(config.hidden_size, 1) self.threshold = config.act_threshold def forward(self, hidden_states): return torch.sigmoid(self.halt_predictor(hidden_states)) class LoRAAdapter(nn.Module): def __init__(self, config): super().__init__() rank = config.lora_rank self.down = nn.Linear(config.hidden_size, rank, bias=False) self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02) self.scale = nn.Embedding(config.max_loop_iters, rank) with torch.no_grad(): self.scale.weight.data.zero_() 
self.down.weight.data.normal_(mean=0.0, std=0.001) def forward(self, x, loop_t): max_t = self.scale.num_embeddings - 1 t_idx = min(loop_t, max_t) s = self.scale(torch.tensor(t_idx, device=x.device)) down = self.down(x) * s return down @ self.B def checkpoint(func, *args, **kwargs): """Gradient checkpointing wrapper — saves VRAM at ~20% compute cost.""" if torch.is_grad_enabled(): return torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=False, **kwargs) return func(*args, **kwargs) # --------------------------------------------------------------------------- # Full Model # --------------------------------------------------------------------------- class SpiderPortalDenseModel(nn.Module): """Full RDT model with MLA attention + Engram memory at layers 1,4. Architecture: 2x Prelude (MLA + dense FFN) 6x Recurrent (MLA + Engram@L1,L4 + dense FFN) — with gradient checkpointing 2x Coda (MLA + dense FFN) LTI Injection + ACT Halting + LoRA Adapter """ def __init__(self, config): super().__init__() self.config = config self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)]) self.recurrent_layers = nn.ModuleList([ SpiderPortalRecurrentDenseLayer(config, i, has_engram=(i in config.engram_layers)) for i in range(config.num_hidden_layers) ]) self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)]) self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.injection = LTIInjection(config) self.act_halting = ACTHalting(config) self.lora_adapter = LoRAAdapter(config) self.loop_embed_dim = config.loop_embed_dim def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None, token_ids=None): n_loops = n_loops or self.config.max_loop_iters input_embedding = input_embedding if input_embedding is not None else hidden_states for layer in self.prelude_layers: hidden_states, _ = 
layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids) e = hidden_states.clone() B, T_seq, D = hidden_states.shape halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool) cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype) h_out = torch.zeros_like(hidden_states) past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers) for t in range(n_loops): h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim) if t > 0: injection = self.injection(hidden_states, input_embedding) hidden_states = hidden_states + injection new_past_key_values = [] for i, layer in enumerate(self.recurrent_layers): hidden_states, aux_loss, past_kv = checkpoint( layer, hidden_states, token_ids=token_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache ) new_past_key_values.append(past_kv) lora_delta = self.lora_adapter(hidden_states, t) hidden_states = hidden_states + lora_delta halt_prob = self.act_halting(hidden_states).squeeze(-1) still_running = ~halted remainder = (1.0 - cumulative_p).clamp(min=0) weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob) weight = weight * still_running.to(hidden_states.dtype) h_out = h_out + weight.unsqueeze(-1) * hidden_states cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype) halted = halted | (cumulative_p >= self.config.act_threshold) if halted.all() and not self.training: break never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1) hidden_states = h_out + never_halted * hidden_states for layer in self.coda_layers: hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids) hidden_states = self.norm(hidden_states) return hidden_states, 0.0, new_past_key_values class SpiderPortalForConditionalGeneration(nn.Module): 
def __init__(self, config): super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.model = SpiderPortalDenseModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): if hasattr(self, 'model') and module is self.model.injection.B: return module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False): hidden_states = self.embed_tokens(input_ids) model_dtype = next(self.model.parameters()).dtype hidden_states = hidden_states.to(model_dtype) input_embedding = hidden_states.clone() if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device) causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min) causal_mask = causal_mask.triu(1) hidden_states, aux_loss, past_kv = self.model( hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops, token_ids=input_ids ) logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return {"loss": loss, "logits": logits, "aux_loss": aux_loss, 
"past_key_values": past_kv} def get_num_params(self): total = sum(p.numel() for p in self.parameters()) return {"total": total, "trainable": total} # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- class FineWebEduDataset(IterableDataset): def __init__(self, tokenizer, seq_len: int, subset: str, rank: int, world_size: int, local_token_file=None): self.tokenizer = tokenizer self.seq_len = seq_len self.subset = subset self.rank = rank self.world_size = world_size # Local tokenized data - USE mmapped binary for speed if local_token_file and os.path.exists(local_token_file): import numpy as np self.use_local = True self.local_file = local_token_file self.mm = np.memmap(local_token_file, dtype='= self.seq_len + 1: chunk = buf[: self.seq_len + 1] buf = buf[self.seq_len + 1 :] yield ( torch.tensor(chunk[:-1], dtype=torch.long), torch.tensor(chunk[1:], dtype=torch.long), ) # --------------------------------------------------------------------------- # LR schedule # --------------------------------------------------------------------------- def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float: if step < warmup: return max_lr * step / warmup if step >= total: return min_lr decay = (step - warmup) / (total - warmup) return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay)) # --------------------------------------------------------------------------- # Checkpointing # --------------------------------------------------------------------------- def save_weights_only(model, step, epoch, ckpt_dir, ddp): if ddp: with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): model_state = model.state_dict() else: model_state = model.state_dict() ckpt_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-ep{epoch}-step{step}.pt") tmp_path = ckpt_path + 
".tmp" torch.save(model_state, tmp_path) os.replace(tmp_path, ckpt_path) size_mb = os.path.getsize(ckpt_path) / (1024 * 1024) return ckpt_path, size_mb def save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, ckpt_name="full"): if ddp: with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): model_state = model.state_dict() optim_state = FSDP.optim_state_dict(model, optimizer) else: model_state = model.state_dict() optim_state = optimizer.state_dict() if not master: return None, 0 os.makedirs(ckpt_dir, exist_ok=True) final_path = os.path.join(ckpt_dir, f"spiderportal-v5-dense-{ckpt_name}.pt") tmp_path = final_path + ".tmp" torch.save( { "step": step, "epoch": epoch, "model_state_dict": model_state, "optimizer_state_dict": optim_state, "cfg": cfg, "vocab_size": vocab_size, }, tmp_path, ) os.replace(tmp_path, final_path) size_mb = os.path.getsize(final_path) / (1024 * 1024) return final_path, size_mb def load_checkpoint(model, optimizer, path, ddp): ckpt = torch.load(path, map_location="cpu", weights_only=False) if ddp: with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=False), ): model.load_state_dict(ckpt["model_state_dict"]) optim_state = FSDP.optim_state_dict_to_load( model=model, optim=optimizer, optim_state_dict=ckpt["optimizer_state_dict"], ) optimizer.load_state_dict(optim_state) else: model.load_state_dict(ckpt["model_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) return int(ckpt["step"]), int(ckpt.get("epoch", 0)) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): # ------------------------------------------------------------------ # Distributed init # ------------------------------------------------------------------ ddp = 
int(os.environ.get("RANK", -1)) != -1 if ddp: dist.init_process_group("nccl") rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) device = f"cuda:{local_rank}" torch.cuda.set_device(device) else: rank = local_rank = 0 world_size = 1 device = "cuda" if torch.cuda.is_available() else "cpu" master = rank == 0 if master: log( f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" ) # ------------------------------------------------------------------ # Tokenizer # ------------------------------------------------------------------ tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token vocab_size = tokenizer.vocab_size if master: log(f"Tokenizer: gpt2 | Vocab size: {vocab_size:,}") # ------------------------------------------------------------------ # Hyperparameters # ------------------------------------------------------------------ seq_len = 2048 micro_batch = 32 # Increased — 96GB VRAM can handle this target_tokens = 20_000_000_000 grad_accum = 2 global_batch_tok = world_size * micro_batch * grad_accum * seq_len total_steps = target_tokens // global_batch_tok warmup_steps = 200 lr = 3e-4 wd = 0.1 log_every = 10 ckpt_every = 500 ckpt_dir = "checkpoints-dense" dataset_subset = "sample-10BT" if master: log( f"[DENSE MLA+Engram] hidden=2048 | layers=6 | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | " f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}" ) log( f"Attention: MLA (kv_lora_rank=128, sliding_window=4096) | " f"Engram: layers [1,4] | Context: 32k | " f"Gradient checkpointing: enabled" ) # ------------------------------------------------------------------ # Model # ------------------------------------------------------------------ cfg = SpiderPortalConfig( hidden_size=2048, num_hidden_layers=6, num_attention_heads=16, num_key_value_heads=4, intermediate_size=8192, num_experts=32, 
num_experts_per_tok=2, num_shared_experts=1, router_aux_loss_coef=0.05, max_loop_iters=4, prelude_layers=2, coda_layers=2, lora_rank=128, rope_theta=10000000.0, rope_scaling=None, max_position_embeddings=32768, sliding_window=4096, tie_word_embeddings=True, kv_lora_rank=128, q_lora_rank=256, qk_rope_head_dim=64, qk_nope_head_dim=64, v_head_dim=64, engram_layers=[1, 4], engram_ngram_orders=(2, 3), engram_hash_heads=4, engram_table_size=65537, engram_dim=128, ) cfg.vocab_size = vocab_size bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 model = SpiderPortalForConditionalGeneration(cfg) if ddp: mp_policy = MixedPrecision( param_dtype=amp_dtype, reduce_dtype=amp_dtype, buffer_dtype=amp_dtype, ) wrap_policy = ModuleWrapPolicy({SpiderPortalDenseLayer, SpiderPortalRecurrentDenseLayer}) model = FSDP( model, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mp_policy, auto_wrap_policy=wrap_policy, device_id=local_rank, ) amp_ctx = nullcontext() else: model = model.to(device) amp_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype) if torch.cuda.is_available() else nullcontext() # Enable torch.compile for 20-30% speedup try: model = torch.compile(model, mode="reduce-overhead") if master: log("torch.compile: enabled (reduce-overhead)") except Exception as e: if master: log(f"torch.compile failed ({e}), using eager mode") if master: n_params = sum(p.numel() for p in model.parameters()) engram_params = sum(p.numel() for n, p in model.named_parameters() if 'engram' in n) mla_params = sum(p.numel() for n, p in model.named_parameters() if 'self_attn' in n) embed_params = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n or 'lm_head' in n) ffn_params = sum(p.numel() for n, p in model.named_parameters() if 'ffn' in n or 'gate_proj' in n or 'up_proj' in n or 'down_proj' in n) other_params = n_params - engram_params - mla_params - embed_params - ffn_params log( 
f"Parameters: {n_params:,} (all active) | " f"Embeddings: {embed_params:,} | MLA: {mla_params:,} | " f"FFN: {ffn_params:,} | Engram: {engram_params:,} | " f"Other: {other_params:,} | AMP dtype: {amp_dtype}" ) # ------------------------------------------------------------------ # Optimizer — dual optimizer for Engram embeddings # ------------------------------------------------------------------ engram_params_list = [p for n, p in model.named_parameters() if 'engram' in n and 'embed_tables' in n] backbone_params = [p for n, p in model.named_parameters() if 'engram' not in n or 'embed_tables' not in n] optimizer = torch.optim.AdamW( backbone_params, lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True ) if engram_params_list: engram_optimizer = torch.optim.Adam( engram_params_list, lr=lr * 5, betas=(0.9, 0.95), eps=1e-8 ) else: engram_optimizer = None # ------------------------------------------------------------------ # Resume from latest checkpoint # ------------------------------------------------------------------ start_step = 0 start_epoch = 1 best_loss = float("inf") existing_ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("spiderportal-v5-dense-ep") and f.endswith(".pt") and "-step" not in f] if os.path.isdir(ckpt_dir) else [] if existing_ckpts: latest = os.path.join(ckpt_dir, sorted(existing_ckpts)[-1]) if master: log(f"Resuming from checkpoint: {latest}") start_step, start_epoch = load_checkpoint(model, optimizer, latest, ddp) if master: log(f"Resumed at step {start_step}, epoch {start_epoch}") # ------------------------------------------------------------------ # Dataset + DataLoader # ------------------------------------------------------------------ # Check for pre-tokenized binary file local_token_file = os.environ.get("TOKEN_FILE", "data/fineweb-edu-sample-10BT.bin") dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file) num_workers = 16 if dataset.use_local else 4 prefetch = 8 if 
dataset.use_local else 2 loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch) if master: log(f"DataLoader: num_workers={num_workers}, prefetch={prefetch}, use_local={dataset.use_local}") # ------------------------------------------------------------------ # Training loop # ------------------------------------------------------------------ if master: os.makedirs(ckpt_dir, exist_ok=True) model.train() data_iter = iter(loader) t0 = time.perf_counter() step = start_step epoch = start_epoch step_ckpt_files = [] tokens_in_epoch = 0 tokens_per_epoch = target_tokens while step < total_steps: cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) for g in optimizer.param_groups: g["lr"] = cur_lr if engram_optimizer: for g in engram_optimizer.param_groups: g["lr"] = cur_lr * 5 optimizer.zero_grad() if engram_optimizer: engram_optimizer.zero_grad() loss_accum = 0.0 for micro_step in range(grad_accum): try: x, y = next(data_iter) except StopIteration: # Dataset exhausted — reshuffle and restart if master: log(f"Dataset exhausted at step {step}, restarting DataLoader") dataset = FineWebEduDataset(tokenizer, seq_len, dataset_subset, rank, world_size, local_token_file=local_token_file) loader = DataLoader(dataset, batch_size=micro_batch, num_workers=num_workers, pin_memory=True, prefetch_factor=prefetch) data_iter = iter(loader) x, y = next(data_iter) x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) sync = ( nullcontext() if (not ddp or micro_step == grad_accum - 1) else model.no_sync() ) with sync, amp_ctx: output = model(x) if isinstance(output, dict): logits = output["logits"] else: logits = output loss = nn.functional.cross_entropy( logits.view(-1, vocab_size), y.view(-1) ) loss = loss / grad_accum loss.backward() loss_accum += loss.item() if ddp: grad_norm = model.clip_grad_norm_(1.0) else: grad_norm = 
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() if engram_optimizer: engram_optimizer.step() step += 1 tokens_in_epoch += global_batch_tok if master and step % log_every == 0: dt = time.perf_counter() - t0 tok_per_sec = global_batch_tok * log_every / dt tokens_seen = step * global_batch_tok log( f"Epoch {epoch} | step {step:6d}/{total_steps} | loss {loss_accum:.4f} " f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " f"| {tok_per_sec / 1e6:.2f}M tok/s " f"| {tokens_seen / 1e9:.2f}B tokens seen" ) t0 = time.perf_counter() if step % ckpt_every == 0 and master: ckpt_path, size_mb = save_weights_only(model, step, epoch, ckpt_dir, ddp) step_ckpt_files.append(ckpt_path) log(f"Saved weights-only: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") if tokens_in_epoch >= tokens_per_epoch: epoch_loss = loss_accum if master: epoch_time = (time.perf_counter() - t0) / 60 log(f"Epoch {epoch} complete | loss={epoch_loss:.4f} | Time: {epoch_time:.1f}min") for f in step_ckpt_files: if os.path.exists(f): os.remove(f) log(f" Deleted step checkpoint: {os.path.basename(f)}") step_ckpt_files.clear() ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"ep{epoch}") if ckpt_path: log(f"Saved epoch checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") if epoch_loss < best_loss: best_loss = epoch_loss ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, "best") if ckpt_path: log(f"Saved best checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") epoch += 1 tokens_in_epoch = 0 if step > start_step and master: ckpt_path, size_mb = save_full_checkpoint(model, optimizer, step, epoch, cfg, vocab_size, ckpt_dir, ddp, master, f"final-ep{epoch}") if ckpt_path: log(f"Saved final checkpoint: {os.path.basename(ckpt_path)} ({size_mb:.0f}MB)") if ddp: dist.barrier() dist.destroy_process_group() if master: log("Training complete.") if __name__ 
== "__main__": main()