| """ |
| Sparse Mixture-of-Experts for Chimera (CPU-first). |
| |
| Key design choices: |
| * Routing is computed in the model's compute dtype (no fp32 promotion): |
| the original draft cast every router input to fp32 which doubled memory |
| bandwidth for nothing on CPUs without dedicated softmax units. |
| * Dispatch uses ``index_select`` + boolean masks per expert. No global |
| ``argsort`` of the routing pairs and no ``bincount`` table. This keeps |
| the path ``torch.compile``-friendly even when expert counts vary. |
| * All experts share an :class:`SwiGLUMLP` topology so weights can be packed |
| ternary identically to the rest of the model. |
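
Example (a minimal sketch; shapes are illustrative, and ``use_ternary=False``
assumes :class:`SwiGLUMLP` provides a dense fallback)::

    >>> moe = MoELayer(hidden_size=64, moe_intermediate_size=128,
    ...                n_routed_experts=4, use_ternary=False)
    >>> moe(torch.randn(2, 5, 64)).shape
    torch.Size([2, 5, 64])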
| """ |

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import SwiGLUMLP


class NoAuxMoEGate(nn.Module):
    """Top-k softmax router with a bias-only load-balance correction (no aux loss)."""

    __constants__ = ["n_routed_experts", "num_experts_per_tok"]

    def __init__(self, hidden_size: int, n_routed_experts: int,
                 num_experts_per_tok: int = 2):
        super().__init__()
        self.n_routed_experts = int(n_routed_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size))
        nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
        # Load-balance bias: a buffer (not a Parameter) so it is checkpointed
        # but never receives gradients; it only steers top-k selection.
        self.register_buffer("e_score_correction_bias",
                             torch.zeros(self.n_routed_experts))

    def forward(self, x: torch.Tensor):
        # Stay in the compute dtype; no fp32 promotion (see module docstring).
        scores = F.linear(x, self.weight)
        probs = F.softmax(scores, dim=-1)
        # The correction bias steers *which* experts win top-k, but the gating
        # weights come from the unbiased probabilities, so the bias can be
        # adjusted for load balance without an auxiliary loss.
        bias = self.e_score_correction_bias.to(probs.dtype)
        _, indices = torch.topk(probs + bias, self.num_experts_per_tok, dim=-1)
        weights = probs.gather(-1, indices)
        weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        return indices, weights


class MoELayer(nn.Module):
    """Sparse MoE block with grouped per-expert dispatch."""

    def __init__(self, hidden_size: int, moe_intermediate_size: int,
                 n_routed_experts: int = 16, n_shared_experts: int = 1,
                 num_experts_per_tok: int = 2, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.n_routed_experts = int(n_routed_experts)
        self.n_shared_experts = int(n_shared_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
                                 self.num_experts_per_tok)
        self.experts = nn.ModuleList([
            SwiGLUMLP(self.hidden_size, moe_intermediate_size, use_ternary=use_ternary)
            for _ in range(self.n_routed_experts)
        ])
        if self.n_shared_experts > 0:
            # All shared experts are fused into one SwiGLU MLP whose hidden
            # width scales with the shared-expert count.
            shared_inter = max(1, moe_intermediate_size * self.n_shared_experts)
            self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter,
                                            use_ternary=use_ternary)
        else:
            self.shared_experts = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        # Flatten (batch, seq, hidden) -> (n_tokens, hidden) for routing.
        flat = x.reshape(-1, self.hidden_size)

        topk_idx, topk_w = self.gate(flat)
        out = torch.zeros_like(flat)

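        # Grouped per-expert dispatch (see module docstring): for each expert,
        # collect the (token, slot) pairs routed to it, run the expert on just
        # those rows, and scatter-add the gate-weighted outputs back into
        # ``out``. This avoids a global argsort/bincount over routing pairs.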
        for e in range(self.n_routed_experts):
            match = (topk_idx == e)
            if not match.any():
                continue
            tok_pos, slot_pos = match.nonzero(as_tuple=True)
            w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype)
            y = self.experts[e](flat.index_select(0, tok_pos))
            out.index_add_(0, tok_pos, y * w)

        if self.shared_experts is not None:
            # Dense path: every token also passes through the shared expert(s).
            out = out + self.shared_experts(flat)

        return out.reshape(orig_shape)


__all__ = ["NoAuxMoEGate", "MoELayer", "SwiGLUMLP"]
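

if __name__ == "__main__":  # pragma: no cover
    # Minimal CPU smoke test, a sketch rather than a real test suite. Run as
    # ``python -m <package>.moe`` (the relative import above needs package
    # context); assumes SwiGLUMLP accepts ``use_ternary=False`` for a plain
    # dense run.
    layer = MoELayer(hidden_size=64, moe_intermediate_size=128,
                     n_routed_experts=4, num_experts_per_tok=2,
                     use_ternary=False)
    x = torch.randn(2, 5, 64)
    y = layer(x)
    assert y.shape == x.shape
    print("MoELayer forward OK:", tuple(y.shape))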
|
|