File size: 4,289 Bytes
11c11f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Sparse Mixture-of-Experts for Chimera (CPU-first).

Key design choices:
* Routing is computed in the model's compute dtype (no fp32 promotion):
  an earlier draft cast every router input to fp32, which doubled memory
  bandwidth without a matching benefit on CPUs that lack dedicated softmax
  units.
* Dispatch uses ``index_select`` + boolean masks per expert.  There is no
  global ``argsort`` of the routing pairs and no ``bincount`` table, so the
  path needs no sort-based bookkeeping when per-expert token counts vary.
  The per-expert ``any``/``nonzero`` calls are still data-dependent, so
  ``torch.compile`` may graph-break inside the dispatch loop.
* All experts share the :class:`SwiGLUMLP` topology so their weights can be
  packed as ternary exactly like the rest of the model.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import SwiGLUMLP


class NoAuxMoEGate(nn.Module):
    """Top-k softmax router with aux-loss-free bias correction.

    ``e_score_correction_bias`` is added to the router logits only to decide
    *which* experts are selected.  The returned gating weights are taken from
    the *unbiased* softmax probabilities and renormalized over the selected
    experts, so adjusting the bias (e.g. to balance expert load) changes the
    routing decision without rescaling the experts' contributions.
    """

    __constants__ = ["n_routed_experts", "num_experts_per_tok"]

    def __init__(self, hidden_size: int, n_routed_experts: int,
                 num_experts_per_tok: int = 2):
        """Create the router.

        Args:
            hidden_size: Feature size of the incoming token vectors.
            n_routed_experts: Number of experts to score.
            num_experts_per_tok: Number of experts selected per token.

        Raises:
            ValueError: If ``num_experts_per_tok`` is negative or larger than
                ``n_routed_experts`` (``torch.topk`` would fail at forward).
        """
        super().__init__()
        self.n_routed_experts = int(n_routed_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        if not 0 <= self.num_experts_per_tok <= self.n_routed_experts:
            raise ValueError(
                f"num_experts_per_tok ({self.num_experts_per_tok}) must be in "
                f"[0, n_routed_experts={self.n_routed_experts}]"
            )
        self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size))
        nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
        # Buffer (not a Parameter): bias correction updated by training scripts.
        self.register_buffer("e_score_correction_bias",
                             torch.zeros(self.n_routed_experts))

    def forward(self, x: torch.Tensor):
        """Select the top-k experts for each token.

        Args:
            x: ``[N, hidden_size]`` token features in any floating dtype.

        Returns:
            ``(indices, weights)``, both ``[N, num_experts_per_tok]``: the
            selected expert ids (ordered by biased score, highest first) and
            their gating weights, renormalized to sum to 1 over each row.
        """
        logits = F.linear(x, self.weight)
        probs = F.softmax(logits, dim=-1)
        # The bias only steers selection.  Softmax is monotonic, so ranking the
        # biased logits picks the same experts as ranking softmax(logits + bias).
        _, indices = torch.topk(logits + self.e_score_correction_bias,
                                self.num_experts_per_tok, dim=-1)
        # Gating weights come from the unbiased affinities of the chosen experts.
        weights = probs.gather(-1, indices)
        weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        return indices, weights


class MoELayer(nn.Module):
    """Sparse MoE block: top-k routed experts plus optional shared experts.

    Each token is sent to ``num_experts_per_tok`` of the ``n_routed_experts``
    :class:`SwiGLUMLP` experts chosen by :class:`NoAuxMoEGate`, and the expert
    outputs are mixed with the gate's weights.  When ``n_shared_experts > 0``,
    one dense :class:`SwiGLUMLP` is also applied to every token and its output
    is added to the routed result.
    """

    def __init__(self, hidden_size: int, moe_intermediate_size: int,
                 n_routed_experts: int = 16, n_shared_experts: int = 1,
                 num_experts_per_tok: int = 2, use_ternary: bool = True):
        """Build the router, the routed experts and the optional shared expert.

        Args:
            hidden_size: Feature size of the input/output tokens.
            moe_intermediate_size: Intermediate width of each routed expert.
            n_routed_experts: Number of routed experts.
            n_shared_experts: Number of shared experts (0 disables them).
            num_experts_per_tok: Experts selected per token by the gate.
            use_ternary: Forwarded to every :class:`SwiGLUMLP`.
        """
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.n_routed_experts = int(n_routed_experts)
        self.n_shared_experts = int(n_shared_experts)
        self.num_experts_per_tok = int(num_experts_per_tok)
        self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
                                 self.num_experts_per_tok)
        self.experts = nn.ModuleList([
            SwiGLUMLP(self.hidden_size, moe_intermediate_size, use_ternary=use_ternary)
            for _ in range(self.n_routed_experts)
        ])
        if self.n_shared_experts > 0:
            # The shared experts are fused into a single MLP whose intermediate
            # width is the per-expert width times the shared-expert count.
            shared_inter = max(1, moe_intermediate_size * self.n_shared_experts)
            self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter,
                                            use_ternary=use_ternary)
        else:
            self.shared_experts = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.  All leading
                dimensions are flattened into a single token axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        flat = x.reshape(-1, self.hidden_size)

        topk_idx, topk_w = self.gate(flat)                           # [N, k]
        out = torch.zeros_like(flat)

        # Per-expert dispatch via boolean masks (no global argsort/bincount).
        # NOTE: ``match.any()`` and ``nonzero`` are data-dependent, so
        # torch.compile may graph-break here.  In eager mode, the early
        # ``continue`` avoids running an idle expert on an empty batch.
        for e, expert in enumerate(self.experts):
            match = topk_idx == e
            if not match.any():
                continue
            # Token rows routed to expert ``e`` and the top-k slot it occupies.
            tok_pos, slot_pos = match.nonzero(as_tuple=True)
            w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype)
            y = expert(flat.index_select(0, tok_pos))
            # index_add_ requires source and destination dtypes to match; cast
            # in case an expert returns a dtype different from its input.
            out.index_add_(0, tok_pos, (y * w).to(out.dtype))

        if self.shared_experts is not None:
            out = out + self.shared_experts(flat)

        return out.reshape(orig_shape)


# Public API of this module; ``SwiGLUMLP`` is re-exported from ``.layers``.
__all__ = [
    "NoAuxMoEGate",
    "MoELayer",
    "SwiGLUMLP",
]