# chomera / chimera/moe.py
# Uploaded by Lgr54HFi using huggingface_hub (commit 11c11f8, verified).
"""
Sparse Mixture-of-Experts for Chimera (CPU-first).
Key design choices:
* Routing is computed in the model's compute dtype (no fp32 promotion):
the original draft cast every router input to fp32 which doubled memory
bandwidth for nothing on CPUs without dedicated softmax units.
* Dispatch uses ``index_select`` + boolean masks per expert, with no global
``argsort`` of the routing pairs and no ``bincount`` table. Note that the
per-expert ``nonzero`` still yields data-dependent shapes, so
``torch.compile`` may graph-break (or require dynamic-shape support) here.
* All experts share an :class:`SwiGLUMLP` topology so weights can be packed
ternary identically to the rest of the model.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import SwiGLUMLP
class NoAuxMoEGate(nn.Module):
    """Top-k softmax router with bias-based (aux-loss-free) load correction.

    Tokens are scored against every routed expert with a bias-free linear
    projection. ``e_score_correction_bias`` is added to those scores *only to
    decide which experts are selected*; the gating weights of the selected
    experts come from the uncorrected scores, renormalised with a softmax over
    the ``k`` selected experts. With the default all-zero bias this is a plain
    top-k softmax router.

    Args:
        hidden_size: Size ``D`` of the last dimension of the router input.
        n_routed_experts: Number of routed experts ``E`` (must be >= 1).
        num_experts_per_tok: Experts selected per token ``k``
            (must satisfy ``1 <= k <= E``).

    Raises:
        ValueError: If ``n_routed_experts`` or ``num_experts_per_tok`` is out
            of range.
    """

    __constants__ = ["n_routed_experts", "num_experts_per_tok"]

    def __init__(self, hidden_size: int, n_routed_experts: int,
                 num_experts_per_tok: int = 2):
        super().__init__()
        self.n_routed_experts = int(n_routed_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        # Fail at construction time rather than inside torch.topk on the
        # first forward pass.
        if self.n_routed_experts < 1:
            raise ValueError(
                f"n_routed_experts must be >= 1, got {self.n_routed_experts}")
        if not 1 <= self.num_experts_per_tok <= self.n_routed_experts:
            raise ValueError(
                "num_experts_per_tok must be in [1, n_routed_experts="
                f"{self.n_routed_experts}], got {self.num_experts_per_tok}")
        self.weight = nn.Parameter(
            torch.empty(self.n_routed_experts, hidden_size))
        nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
        # Buffer (not a Parameter): it receives no gradient; the bias
        # correction is updated by training scripts.
        self.register_buffer("e_score_correction_bias",
                             torch.zeros(self.n_routed_experts))

    def forward(self, x: torch.Tensor):
        """Select experts and gating weights for each row of ``x``.

        Args:
            x: Router input whose last dimension is ``hidden_size``; any
                leading dimensions are preserved. Routing runs in ``x``'s
                dtype (no fp32 promotion).

        Returns:
            ``(indices, weights)``, both shaped ``x.shape[:-1] + (k,)``.
            ``indices`` are expert ids ordered by descending corrected score;
            ``weights`` are the matching gate values and sum to 1 along the
            last dimension.
        """
        logits = F.linear(x, self.weight)
        # The correction bias only steers *selection*. Ranking logits + bias
        # is equivalent to ranking softmax(logits + bias), since softmax is
        # monotonic.
        _, indices = torch.topk(logits + self.e_score_correction_bias,
                                self.num_experts_per_tok, dim=-1)
        # Gate values come from the uncorrected logits. A softmax over the
        # selected logits equals the full softmax renormalised over the
        # top-k, and it cannot produce an all-zero denominator.
        weights = F.softmax(logits.gather(-1, indices), dim=-1)
        return indices, weights
class MoELayer(nn.Module):
    """Sparse Mixture-of-Experts block with grouped per-expert dispatch.

    Each token is scored by a :class:`NoAuxMoEGate` and sent to its top
    ``num_experts_per_tok`` routed :class:`SwiGLUMLP` experts. The expert
    outputs are summed with the gate weights. If ``n_shared_experts > 0``,
    one extra :class:`SwiGLUMLP` with intermediate width
    ``max(1, moe_intermediate_size * n_shared_experts)`` is applied to every
    token and added to the routed output.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        moe_intermediate_size: Intermediate width of each routed expert.
        n_routed_experts: Number of routed experts.
        n_shared_experts: Width multiplier for the shared expert; ``0``
            disables it.
        num_experts_per_tok: Routed experts selected per token.
        use_ternary: Forwarded unchanged to every :class:`SwiGLUMLP`.
    """

    def __init__(self, hidden_size: int, moe_intermediate_size: int,
                 n_routed_experts: int = 16, n_shared_experts: int = 1,
                 num_experts_per_tok: int = 2, use_ternary: bool = True):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.n_routed_experts = int(n_routed_experts)
        self.n_shared_experts = int(n_shared_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
                                 self.num_experts_per_tok)
        self.experts = nn.ModuleList([
            SwiGLUMLP(self.hidden_size, moe_intermediate_size,
                      use_ternary=use_ternary)
            for _ in range(self.n_routed_experts)
        ])
        if self.n_shared_experts > 0:
            # Shared experts are realised as one MLP whose width scales with
            # n_shared_experts.
            shared_inter = max(1, moe_intermediate_size * self.n_shared_experts)
            self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter,
                                            use_ternary=use_ternary)
        else:
            self.shared_experts = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block token-wise.

        Args:
            x: Tensor whose last dimension is ``hidden_size``. All leading
                dimensions are flattened into one token axis and restored on
                output.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is 0-d or its last dimension is not
                ``hidden_size``.
        """
        # reshape(-1, D) alone would silently re-chunk a mis-sized input
        # whose element count happens to be divisible by D.
        if x.dim() == 0 or x.size(-1) != self.hidden_size:
            raise ValueError(
                f"expected last dimension {self.hidden_size}, "
                f"got input of shape {tuple(x.shape)}")
        orig_shape = x.shape
        flat = x.reshape(-1, self.hidden_size)            # [N, D]
        topk_idx, topk_w = self.gate(flat)                # [N, k] each
        out = torch.zeros_like(flat)
        # Per-expert dispatch via boolean masks: no global argsort/bincount.
        # NOTE: nonzero() has a data-dependent output shape, so torch.compile
        # may not capture this loop as a single graph.
        for e, expert in enumerate(self.experts):
            # One nonzero() both finds the (token, slot) pairs routed to this
            # expert and tells us whether it is idle; no separate any() pass.
            tok_pos, slot_pos = (topk_idx == e).nonzero(as_tuple=True)
            if tok_pos.numel() == 0:
                continue
            w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype)
            y = expert(flat.index_select(0, tok_pos))
            # Top-k ids are distinct per row, so each token occurs at most
            # once here. The .to() is a no-op when the expert already returns
            # out.dtype, and it keeps index_add_ from failing if it does not.
            out.index_add_(0, tok_pos, (y * w).to(out.dtype))
        if self.shared_experts is not None:
            out = out + self.shared_experts(flat)
        return out.reshape(orig_shape)
# Public API of this module; SwiGLUMLP is re-exported from ``.layers``.
__all__ = [
    "NoAuxMoEGate",
    "MoELayer",
    "SwiGLUMLP",
]