import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Optional, Tuple, List import warnings class RotaryPositionEmbedding(nn.Module): """RoPE implementation without traditional position embeddings""" def __init__(self, dim: int, base: int = 10000): super().__init__() self.dim = dim self.base = base # Only compute frequencies for half the dimensions (complex form) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq, persistent=False) def forward(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]: seq_len = x.shape[seq_dim] device = x.device dtype = x.dtype t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) # Create cosine and sine components cos = torch.cos(freqs).to(dtype) sin = torch.sin(freqs).to(dtype) return cos, sin def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """Apply rotary position embedding to input tensor""" # x shape: [batch_size, num_heads, seq_len, head_dim] # cos, sin shape: [seq_len, head_dim//2] batch_size, num_heads, seq_len, head_dim = x.shape half_dim = head_dim // 2 # Reshape x to separate real and imaginary parts x_reshaped = x.view(batch_size, num_heads, seq_len, half_dim, 2) x_real = x_reshaped[..., 0] x_imag = x_reshaped[..., 1] # Expand cos and sin to match dimensions cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, half_dim] sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, half_dim] # Apply rotation x_real_rot = x_real * cos - x_imag * sin x_imag_rot = x_real * sin + x_imag * cos # Combine back x_rotated = torch.stack([x_real_rot, x_imag_rot], dim=-1) x_rotated = x_rotated.view(batch_size, num_heads, seq_len, head_dim) return x_rotated.type_as(x) class VariableGroupedQueryAttention(nn.Module): """Variable Grouped Query Attention with layer-specific head grouping and optional RoPE/NoPE""" def __init__(self, dim: int, num_heads: int = 8, layer_idx: int = 0, num_layers: int = 12, variable_groups: bool = True, use_rope: bool = True): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.variable_groups = variable_groups self.layer_idx = layer_idx self.num_layers = num_layers self.use_rope = use_rope # Variable group calculation - different KV heads for each layer if variable_groups: # Create progressive pattern: more KV heads in deeper layers # Early layers: fewer KV heads (more compression) # Later layers: more KV heads (more detail) # Normalized layer position (0 to 1) layer_ratio = layer_idx / max(1, num_layers - 1) # Calculate KV heads with progressive scaling # Start with fewer KV heads (e.g., 2-3) and increase toward end min_kv_heads = max(1, num_heads // 6) # Minimum 1/6 of heads max_kv_heads = max(2, num_heads // 3) # Maximum 1/3 of heads # Progressive scaling: early layers use fewer, later use more raw_kv_heads = int(min_kv_heads + (max_kv_heads - min_kv_heads) * layer_ratio) # Ensure it's a valid divisor self.num_kv_heads = raw_kv_heads if self.num_heads % self.num_kv_heads != 0: # Find the nearest valid num_kv_heads for i in range(self.num_kv_heads, 0, -1): if self.num_heads % i == 0: self.num_kv_heads = i break # If that didn't work, try going up if self.num_heads % self.num_kv_heads != 0: for i in range(self.num_kv_heads + 1, max_kv_heads + 1): if self.num_heads % i == 0: self.num_kv_heads = i break else: self.num_kv_heads = max(2, num_heads // 2) # Final validation assert self.num_heads % self.num_kv_heads == 0, \ f"Layer {layer_idx}: num_heads ({num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})" # Query projections self.q_proj = nn.Linear(dim, dim, bias=False) # Key-Value projections with grouped attention self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) # Output projection self.out_proj = nn.Linear(dim, dim, bias=False) # RoPE - only create if using positional embeddings # NoPE layers (every 4th layer) skip positional embeddings entirely if self.use_rope: self.rope = RotaryPositionEmbedding(self.head_dim) else: self.rope = None def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size, seq_len, _ = x.shape # Project queries, keys, values q = self.q_proj(x) # [batch, seq_len, dim] k = self.k_proj(x) # [batch, seq_len, num_kv_heads * head_dim] v = self.v_proj(x) # [batch, seq_len, num_kv_heads * head_dim] # Reshape for multi-head attention q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # Apply RoPE to queries and keys (NoPE layers skip this) # NoPE layers rely on causal attention mask for positional information if self.use_rope and self.rope is not None: cos, sin = self.rope(q) q = apply_rotary_pos_emb(q, cos, sin) k = apply_rotary_pos_emb(k, cos, sin) # else: NoPE - no positional embeddings applied, causal mask provides ordering # Expand KV heads for grouped query attention if self.num_kv_heads != self.num_heads: k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) # Compute attention scores attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # Apply attention mask if provided if attention_mask is not None: attn_scores = attn_scores + attention_mask attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) # Apply attention to values attn_output = torch.matmul(attn_weights, v) # Reshape and project back attn_output = attn_output.transpose(1, 2).contiguous().view( batch_size, seq_len, self.dim ) return self.out_proj(attn_output) class Expert(nn.Module): """Single expert in the MOE layer""" def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class MOELayer(nn.Module): """Mixture of Experts Layer with adaptive routing based on input complexity""" def __init__(self, dim: int, hidden_dim: int, num_experts: int = 4, capacity_factor: float = 1.0, noisy_gating: bool = True, adaptive_routing: bool = True): super().__init__() self.dim = dim self.num_experts = num_experts self.capacity_factor = capacity_factor self.noisy_gating = noisy_gating self.adaptive_routing = adaptive_routing # Create experts self.experts = nn.ModuleList([ Expert(dim, hidden_dim) for _ in range(num_experts) ]) # Standard gate network self.gate = nn.Linear(dim, num_experts) # NOVEL: Adaptive complexity-based routing # Learns to route tokens based on their complexity/importance if adaptive_routing: # Complexity encoder: estimates how "complex" each token representation is self.complexity_encoder = nn.Sequential( nn.Linear(dim, dim // 4), nn.GELU(), nn.Linear(dim // 4, 1), nn.Sigmoid() # Output: 0 (simple) to 1 (complex) ) # Adaptive temperature: dynamically adjusts expert selection based on complexity self.complexity_proj = nn.Linear(dim, 1) # Learnable bias for complexity-aware routing self.complexity_bias = nn.Parameter(torch.zeros(1)) def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape # Flatten for expert routing x_flat = x.reshape(-1, dim) num_tokens = x_flat.shape[0] # Compute standard gate scores gate_scores = self.gate(x_flat) # NOVEL: Adaptive routing based on token complexity if self.adaptive_routing: # Estimate complexity of each token (0 = simple, 1 = complex) complexity_scores = self.complexity_encoder(x_flat) # [num_tokens, 1] # Compute adaptive temperature based on complexity # Complex tokens get lower temperature (sharper distribution) # Simple tokens get higher temperature (softer distribution) complexity_temp = self.complexity_proj(x_flat) + self.complexity_bias # Temperature in range [0.5, 2.0] - inverse relationship with complexity adaptive_temp = 0.5 + 1.5 * (1.0 - complexity_scores.squeeze(-1)) # Apply adaptive temperature scaling to gate scores # Lower temp = sharper = focus on fewer experts # Higher temp = softer = distribute more evenly scaled_scores = gate_scores / (adaptive_temp.unsqueeze(-1) + 1e-8) if self.noisy_gating and self.training: # Reduced noise for complex tokens (they should be more confident) noise_scale = (1.0 / self.num_experts) * (1.0 - complexity_scores.squeeze(-1) * 0.5) noise = torch.randn_like(gate_scores) * noise_scale.unsqueeze(-1) scaled_scores = scaled_scores + noise else: scaled_scores = gate_scores if self.noisy_gating and self.training: noise = torch.randn_like(gate_scores) * (1.0 / self.num_experts) scaled_scores = scaled_scores + noise # Get top-2 experts using adaptive scores top_k = 2 top_scores, top_indices = torch.topk(scaled_scores, k=top_k, dim=-1) top_gates = F.softmax(top_scores, dim=-1, dtype=torch.float32).to(x_flat.dtype) # Create placeholder for final output final_output = torch.zeros_like(x_flat) # Compute auxiliary loss for load balancing (use original gate_scores, not scaled) self.aux_loss = self._load_balancing_loss(gate_scores, top_indices) # Route tokens to experts for i in range(top_k): # Process tokens for the i-th choice expert expert_indices = top_indices[:, i] gate_values = top_gates[:, i].unsqueeze(-1) for expert_idx, expert in enumerate(self.experts): token_indices = (expert_indices == expert_idx).nonzero(as_tuple=True)[0] if token_indices.numel() > 0: selected_tokens = x_flat[token_indices] selected_gates = gate_values[token_indices] expert_output = expert(selected_tokens) final_output.index_add_(0, token_indices, expert_output * selected_gates) # Reshape back to original dimensions return final_output.reshape(batch_size, seq_len, dim) def _load_balancing_loss(self, gate_scores: torch.Tensor, top_indices: torch.Tensor) -> torch.Tensor: """Compute load balancing auxiliary loss""" if not self.training: return torch.tensor(0.0, device=gate_scores.device) # Compute fraction of tokens routed to each expert (based on top-1 choice) top1_indices = top_indices[:, 0] expert_mask = F.one_hot(top1_indices, num_classes=self.num_experts).float() routing_fraction = expert_mask.mean(dim=0) # Compute fraction of gate probability for each expert gate_prob = F.softmax(gate_scores, dim=-1) gate_fraction = gate_prob.mean(dim=0) # Load balancing loss load_balance_loss = self.num_experts * torch.sum(routing_fraction * gate_fraction) return load_balance_loss class SlimMoETransformerBlock(nn.Module): """Transformer block with VGQA and MOE""" def __init__(self, dim: int, num_heads: int, hidden_dim: int, num_experts: int = 4, dropout: float = 0.1, layer_idx: int = 0, num_layers: int = 12, adaptive_routing: bool = True): super().__init__() self.dim = dim self.adaptive_routing = adaptive_routing # Attention components with layer-specific KV heads self.attn_norm = nn.LayerNorm(dim) # NoPE every 4th layer (layers 3, 7, 11, ...), RoPE for all others # Pattern: layer_idx % 4 == 3 means it's the 4th layer (0-indexed: 3rd, 7th, etc.) use_rope = (layer_idx % 4 != 3) self.attention = VariableGroupedQueryAttention( dim, num_heads, layer_idx=layer_idx, num_layers=num_layers, variable_groups=True, use_rope=use_rope ) # Dense transformer feed-forward (before MoE) self.dense_ffn_norm = nn.LayerNorm(dim) self.dense_ffn = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) # MOE components self.moe_norm = nn.LayerNorm(dim) self.moe = MOELayer(dim, hidden_dim, num_experts, adaptive_routing=adaptive_routing) # Dropout self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # Attention branch with residual attn_norm_out = self.attn_norm(x) attn_out = self.attention(attn_norm_out, attention_mask) x = x + self.dropout(attn_out) # Dense transformer feed-forward branch with residual dense_ffn_norm_out = self.dense_ffn_norm(x) dense_ffn_out = self.dense_ffn(dense_ffn_norm_out) x = x + dense_ffn_out # MOE branch with residual moe_norm_out = self.moe_norm(x) moe_out = self.moe(moe_norm_out) x = x + self.dropout(moe_out) return x class SlimMOETransformer(nn.Module): """Complete MOE Transformer with Variable Grouped Query Attention and RoPE""" def __init__(self, vocab_size: int = 50257, dim: int = 768, num_layers: int = 12, num_heads: int = 12, hidden_dim: int = 2048, num_experts: int = 4, max_seq_len: int = 2048, dropout: float = 0.1, adaptive_routing: bool = True): super().__init__() self.vocab_size = vocab_size self.dim = dim self.num_layers = num_layers self.max_seq_len = max_seq_len self.token_embedding = nn.Embedding(vocab_size, dim) self.dropout = nn.Dropout(dropout) self.layers = nn.ModuleList([ SlimMoETransformerBlock( dim=dim, num_heads=num_heads, hidden_dim=hidden_dim, num_experts=num_experts, dropout=dropout, layer_idx=i, num_layers=num_layers, adaptive_routing=adaptive_routing ) for i in range(num_layers) ]) self.norm = nn.LayerNorm(dim) # --- FIX: Remove the lm_head from the base transformer model --- # self.lm_head = nn.Linear(dim, vocab_size, bias=False) # The CausalLM wrapper will handle the final projection. self.apply(self._init_weights) self._calculate_parameters() # This will now show a smaller number def _init_weights(self, module): """Initialize weights""" if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) def _calculate_parameters(self): # ... (this method is unchanged) ... total_params = sum(p.numel() for p in self.parameters()) print(f"Total Parameters: {total_params:,}") def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None) -> dict: # Note: labels are ignored here now batch_size, seq_len = input_ids.shape causal_mask = torch.triu( torch.full((1, 1, seq_len, seq_len), -torch.finfo(torch.get_default_dtype()).max, device=input_ids.device), diagonal=1 ) if attention_mask is not None: padding_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -torch.finfo( torch.get_default_dtype()).max extended_attention_mask = causal_mask + padding_mask else: extended_attention_mask = causal_mask x = self.token_embedding(input_ids) * math.sqrt(self.dim) x = self.dropout(x) total_aux_loss = 0.0 for layer in self.layers: x = layer(x, extended_attention_mask) if self.training: total_aux_loss += layer.moe.aux_loss x = self.norm(x) # --- FIX: Return hidden states and aux loss, not logits --- return { 'last_hidden_state': x, 'aux_loss': total_aux_loss } def create_moe_model(vocab_size: int = 50257) -> SlimMOETransformer: """ Create a MOE model with approximately 300M parameters. Configuration: - dim=768, num_layers=16, num_heads=12 - hidden_dim=1536, num_experts=4 - This yields ~280-290M parameters, safely under 300M """ return SlimMOETransformer( vocab_size=vocab_size, dim=768, num_layers=16, num_heads=12, hidden_dim=1536, num_experts=4, max_seq_len=2048, dropout=0.1 )