"""
Sparse Mixture-of-Experts for Chimera (CPU-first).
Key design choices:
* Routing is computed in the model's compute dtype (no fp32 promotion):
the original draft cast every router input to fp32 which doubled memory
bandwidth for nothing on CPUs without dedicated softmax units.
* Dispatch uses ``index_select`` + boolean masks per expert. No global
``argsort`` of the routing pairs and no ``bincount`` table. This keeps
the path ``torch.compile``-friendly even when expert counts vary.
* All experts share an :class:`SwiGLUMLP` topology so weights can be packed
ternary identically to the rest of the model.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import SwiGLUMLP
class NoAuxMoEGate(nn.Module):
    """Top-k softmax router with optional bias-only correction (no aux loss).

    DeepSeek-V3-style aux-loss-free balancing: the correction bias steers
    *which* experts are selected but never the combine weights, so the
    load-balancing signal cannot distort the mixture itself.
    """
    __constants__ = ["n_routed_experts", "num_experts_per_tok"]
    def __init__(self, hidden_size: int, n_routed_experts: int,
                 num_experts_per_tok: int = 2):
        super().__init__()
        self.n_routed_experts = int(n_routed_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size))
        nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
        # Buffer (not a Parameter): bias correction updated by training scripts.
        self.register_buffer("e_score_correction_bias",
                             torch.zeros(self.n_routed_experts))
    def forward(self, x: torch.Tensor):
        """Route tokens to top-k experts.

        Args:
            x: ``[N, D]`` activations in the model's compute dtype (routing
               is stable enough in bf16/fp32; no fp32 promotion).

        Returns:
            ``(indices, weights)``, each ``[N, k]``. Weights are the
            *unbiased* softmax probabilities of the chosen experts,
            renormalized to sum to 1 per token.
        """
        probs = F.softmax(F.linear(x, self.weight), dim=-1)
        # Bias affects selection only. The previous draft added it to the
        # logits before the softmax, which also skewed the combine weights
        # and therefore the layer's output whenever the bias was nonzero.
        _, indices = torch.topk(probs + self.e_score_correction_bias,
                                self.num_experts_per_tok, dim=-1)
        weights = probs.gather(-1, indices)
        weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        return indices, weights
class MoELayer(nn.Module):
    """Sparse MoE block with grouped expert dispatch."""
    def __init__(self, hidden_size: int, moe_intermediate_size: int,
                 n_routed_experts: int = 16, n_shared_experts: int = 1,
                 num_experts_per_tok: int = 2, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.n_routed_experts = int(n_routed_experts)
        self.n_shared_experts = int(n_shared_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
                                 self.num_experts_per_tok)
        # All experts share the SwiGLUMLP topology so ternary packing works
        # identically across routed and shared paths.
        def build_expert(width: int) -> SwiGLUMLP:
            return SwiGLUMLP(self.hidden_size, width, use_ternary=use_ternary)
        self.experts = nn.ModuleList(
            build_expert(moe_intermediate_size)
            for _ in range(self.n_routed_experts)
        )
        if self.n_shared_experts > 0:
            self.shared_experts = build_expert(
                max(1, moe_intermediate_size * self.n_shared_experts))
        else:
            self.shared_experts = None
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply routed experts (plus the optional shared expert) to ``x``."""
        shape_in = x.shape
        tokens = x.reshape(-1, self.hidden_size)
        expert_idx, expert_w = self.gate(tokens)  # both [N, k]
        mixed = torch.zeros_like(tokens)
        # Grouped dispatch: one boolean mask per expert. Avoids a global
        # argsort / bincount pass and stays torch.compile-friendly.
        for eid, expert in enumerate(self.experts):
            hit = expert_idx.eq(eid)
            if not hit.any():
                continue
            # Token rows and top-k slots routed to this expert.
            rows, slots = hit.nonzero(as_tuple=True)
            gain = expert_w[rows, slots].to(mixed.dtype).unsqueeze(-1)
            mixed.index_add_(0, rows, expert(tokens.index_select(0, rows)) * gain)
        if self.shared_experts is not None:
            mixed = mixed + self.shared_experts(tokens)
        return mixed.reshape(shape_in)
__all__ = ["NoAuxMoEGate", "MoELayer", "SwiGLUMLP"]