""" Sparse Mixture-of-Experts for Chimera (CPU-first). Key design choices: * Routing is computed in the model's compute dtype (no fp32 promotion): the original draft cast every router input to fp32 which doubled memory bandwidth for nothing on CPUs without dedicated softmax units. * Dispatch uses ``index_select`` + boolean masks per expert. No global ``argsort`` of the routing pairs and no ``bincount`` table. This keeps the path ``torch.compile``-friendly even when expert counts vary. * All experts share an :class:`SwiGLUMLP` topology so weights can be packed ternary identically to the rest of the model. """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from .layers import SwiGLUMLP class NoAuxMoEGate(nn.Module): """Top-k softmax router with optional bias-only correction (no aux loss).""" __constants__ = ["n_routed_experts", "num_experts_per_tok"] def __init__(self, hidden_size: int, n_routed_experts: int, num_experts_per_tok: int = 2): super().__init__() self.n_routed_experts = int(n_routed_experts) self.num_experts_per_tok = int(num_experts_per_tok) self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size)) nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5) # Buffer (not a Parameter): bias correction updated by training scripts. self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) def forward(self, x: torch.Tensor): # x: [N, D] in arbitrary dtype. Routing is stable enough in bf16/fp32. scores = F.linear(x, self.weight) + self.e_score_correction_bias probs = F.softmax(scores, dim=-1) weights, indices = torch.topk(probs, self.num_experts_per_tok, dim=-1) weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9) return indices, weights class MoELayer(nn.Module): """Sparse MoE block with grouped expert dispatch.""" def __init__(self, hidden_size: int, moe_intermediate_size: int, n_routed_experts: int = 16, n_shared_experts: int = 1, num_experts_per_tok: int = 2, use_ternary: bool = True): super().__init__() self.hidden_size = int(hidden_size) self.n_routed_experts = int(n_routed_experts) self.n_shared_experts = int(n_shared_experts) self.num_experts_per_tok = int(num_experts_per_tok) self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts, self.num_experts_per_tok) self.experts = nn.ModuleList([ SwiGLUMLP(self.hidden_size, moe_intermediate_size, use_ternary=use_ternary) for _ in range(self.n_routed_experts) ]) if self.n_shared_experts > 0: shared_inter = max(1, moe_intermediate_size * self.n_shared_experts) self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter, use_ternary=use_ternary) else: self.shared_experts = None def forward(self, x: torch.Tensor) -> torch.Tensor: orig_shape = x.shape flat = x.reshape(-1, self.hidden_size) N = flat.size(0) topk_idx, topk_w = self.gate(flat) # [N, k] out = torch.zeros_like(flat) # Per-expert dispatch via boolean masks: avoids the global argsort and # ``bincount`` of the previous draft and keeps the structure compatible # with torch.compile. for e in range(self.n_routed_experts): match = (topk_idx == e) if not match.any(): continue # Token positions and per-pair weights for this expert. tok_pos, slot_pos = match.nonzero(as_tuple=True) w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype) y = self.experts[e](flat.index_select(0, tok_pos)) out.index_add_(0, tok_pos, y * w) if self.shared_experts is not None: out = out + self.shared_experts(flat) return out.reshape(orig_shape) __all__ = ["NoAuxMoEGate", "MoELayer", "SwiGLUMLP"]