# DECO-1.2B / modeling_blockffn.py
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
import math
import torch
from torch import nn
import tree
from abc import ABC, abstractmethod
from fmoe.linear import MOELinear
from fmoe.functions import prepare_forward, MOEScatter, MOEGather
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.generic import check_model_inputs
from .configuration_blockffn import BlockFFNConfig
logger = logging.get_logger(__name__)
@use_kernel_forward_from_hub("RMSNorm")
class BlockFFNRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class BlockFFNRotaryEmbedding(nn.Module):
def __init__(self, config: BlockFFNConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
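# A minimal shape sketch for apply_rotary_pos_emb (illustrative comment only; the tensor
# names below are hypothetical):
#   q, k             : [batch, num_heads, seq_len, head_dim] (after the transpose in attention)
#   cos, sin         : [batch, seq_len, head_dim], unsqueezed here to [batch, 1, seq_len, head_dim]
#   q_embed, k_embed : same shapes as q and k.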
class SimpleLayerNorm(nn.Module):
def __init__(self, dim_norm: int, fixed: bool = False, init_var: float = 1.0):
super().__init__()
self.dim_norm = dim_norm
self.fixed = fixed
if self.fixed:
self.weight = init_var
else:
self.weight = torch.nn.Parameter(torch.full((self.dim_norm,), init_var))
@torch.compile
def forward(self, x: torch.Tensor):
return x * self.weight
class BlockFFNMLP(nn.Module):
    def __init__(self, config: BlockFFNConfig, intermediate_size: Optional[int] = None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class BlockFFNRouter(nn.Module):
def __init__(self, config: BlockFFNConfig):
super().__init__()
self.config = config
self.num_experts = self.config.num_experts
if self.config.moe_router_dtype == "fp32":
self.router_dtype = torch.float32
elif self.config.moe_router_dtype == "fp64":
self.router_dtype = torch.float64
elif self.config.moe_router_dtype == "bf16":
self.router_dtype = torch.bfloat16
else:
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype)
)
def forward(self, x: torch.Tensor):
return nn.functional.linear(x.to(self.router_dtype), self.weight)
class NormSiLU(nn.Module):
def __init__(self, config: BlockFFNConfig):
super().__init__()
self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size
self.activate_fn_type = config.expert_act_func
assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean", "silu"]
self.rms_norm = None
if self.activate_fn_type not in ["norm_silu_norms", "silu"]:
self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon)
self.silu = torch.nn.SiLU()
@torch.compile
def forward(self, hidden: torch.Tensor) -> torch.Tensor:
assert hidden.ndim == 2
if self.activate_fn_type not in ["norm_silu_nomean", "silu"]:
hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True)
if self.activate_fn_type not in ["norm_silu_norms", "silu"]:
return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size)))
else:
return self.silu(hidden)
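# Rough sketch of the default "norm_silu" path in NormSiLU (comment only, not normative):
#   hidden: [tokens, num_experts * moe_ffn_hidden_size]
#   1) subtract the per-token mean, 2) RMS-normalize each expert chunk of size moe_ffn_hidden_size,
#   3) apply SiLU; the output then has shape [tokens, num_experts, moe_ffn_hidden_size].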
class BlockFFNLayer(nn.Module):
def __init__(self, config: BlockFFNConfig):
        super().__init__()
self.config = config
self.num_experts, self.dim_expert, self.hidden_size = \
config.num_experts, config.moe_ffn_hidden_size, config.hidden_size
self.dim_shared_expert = config.moe_shared_expert_intermediate_size
self.router_norm_type = config.router_norm_type
self.moe_router = BlockFFNRouter(self.config)
assert config.router_act_func == "relu"
self.router_act = nn.ReLU()
if config.router_norm_type == "simple":
self.router_norm = SimpleLayerNorm(
dim_norm=(1 if self.config.router_norm_scalar else config.num_experts),
fixed=config.router_norm_fixed,
init_var=config.router_norm_init_var,
)
elif config.router_norm_type == "rms":
self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon)
else:
raise NotImplementedError
self.expert_gated = not config.expert_not_gated
if self.expert_gated:
self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias)
assert config.expert_act_norm_type == "normal"
self.expert_act = NormSiLU(self.config)
self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias)
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
if self.use_shared_expert:
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
self.enable_expert_bias = config.moe_router_enable_expert_bias
if self.enable_expert_bias:
self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
self.expert_bias_apply_method = config.moe_expert_bias_apply_method
def apply_expert_bias(self, router_scores: torch.Tensor) -> torch.Tensor:
if self.expert_bias_apply_method == "base":
scores_for_routing = router_scores + self.expert_bias
elif self.expert_bias_apply_method == "rms":
variance = router_scores.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
scores_for_routing = router_scores + self.expert_bias.unsqueeze(0) * torch.sqrt(variance)
else:
raise NotImplementedError(f"invalid apply method: {self.expert_bias_apply_method}")
return scores_for_routing
def forward(self, hidden_states: torch.Tensor):
ori_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
seq_len = hidden_states.shape[0]
# router module forward
raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts]
if self.enable_expert_bias:
scores_for_routing = self.apply_expert_bias(raw_router_score)
router_score = self.router_act(raw_router_score) * torch.gt(scores_for_routing, 0).type_as(raw_router_score)
else:
router_score = self.router_act(raw_router_score)
router_score = self.router_norm(router_score)
# expert module forward
x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert]
if self.expert_gated:
x_gate = self.expert_gate_proj(hidden_states)
x_gate = self.expert_act(x_gate)
if x_gate.ndim == 3:
x_in = x_in.view(seq_len, self.num_experts, self.dim_expert)
x_in = x_in * x_gate
else:
x_in = self.expert_act(x_in)
if x_in.ndim == 3:
scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1)
else:
scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1)
output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert))
if self.use_shared_expert:
output = output + self.shared_experts(hidden_states)
return output.view(*ori_shape)
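# Rough dataflow of BlockFFNLayer.forward, written as a comment sketch (not normative):
#   score = router_norm(relu(moe_router(h)))                          # [tokens, num_experts]
#   if gated:   act = expert_up_proj(h) * expert_act(expert_gate_proj(h))
#   else:       act = expert_act(expert_up_proj(h))
#   out   = expert_down_proj(flatten(act * score[..., None]))  [+ shared_experts(h)]
# Every expert is computed densely; sparsity comes from ReLU zeroing out router scores.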
class BaseRouter(ABC, nn.Module):
"""Base Router class"""
def __init__(self, config: BlockFFNConfig) -> None:
super().__init__()
self.config = config
self.num_experts = self.config.num_experts
if self.config.moe_router_dtype == "fp32":
self.router_dtype = torch.float32
elif self.config.moe_router_dtype == "fp64":
self.router_dtype = torch.float64
elif self.config.moe_router_dtype == "bf16":
self.router_dtype = torch.bfloat16
else:
raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.")
self.weight = torch.nn.Parameter(
torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype)
)
def gating(self, input: torch.Tensor):
return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype))
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
probabilities and mapping.
"""
raise NotImplementedError("Routing function not implemented.")
@abstractmethod
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
raise NotImplementedError("Forward function not implemented.")
class TopKRouter(BaseRouter):
"""Route each token to the top-k experts."""
def __init__(self, config: BlockFFNConfig) -> None:
super().__init__(config)
self.config = config
self.topk = self.config.moe_router_topk
self.score_function = self.config.moe_router_score_function
self.use_pre_softmax = self.config.moe_router_pre_softmax
self.scaling_factor = self.config.moe_router_topk_scaling_factor
self.enable_expert_bias = self.config.moe_router_enable_expert_bias
if self.enable_expert_bias:
self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
else:
self.expert_bias = None
def _maintain_float32_expert_bias(self):
"""
Maintain the expert bias in float32.
        When the model is cast to bf16/fp16, the expert bias would otherwise be converted to lower precision.
        We keep it in float32 to avoid routing errors when the expert_bias is updated.
"""
if hasattr(self, 'expert_bias') and self.expert_bias is not None:
if self.expert_bias.dtype != torch.float32:
self.expert_bias.data = self.expert_bias.data.to(torch.float32)
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
        Returns:
            probs (torch.Tensor): Scores of the selected experts, with shape [num_tokens, topk].
            top_indices (torch.Tensor): Indices of the selected experts, with shape [num_tokens, topk].
        """
logits = logits.view(-1, self.num_experts)
if self.score_function == "softmax":
if self.use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = torch.topk(scores, k=self.topk, dim=1)
else:
scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif self.score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
if self.expert_bias is not None:
scores_for_routing = scores + self.expert_bias
_, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = torch.topk(scores, k=self.topk, dim=1)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {self.score_function}")
if self.scaling_factor:
probs = probs * self.scaling_factor
return probs, top_indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
self._maintain_float32_expert_bias()
logits = self.gating(input)
top_scores, top_indices = self.routing(logits)
return top_scores, top_indices
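# Illustrative TopKRouter output shapes (comment-only example; the numbers are hypothetical):
# with num_experts=8 and moe_router_topk=2, forward() on an input of [num_tokens, hidden_size]
# returns top_scores and top_indices, both of shape [num_tokens, 2].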
class ReMoERouter(BaseRouter):
def __init__(self, config: BlockFFNConfig) -> None:
super().__init__(config)
self.config = config
self.router_act = torch.nn.ReLU()
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
Returns:
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment,
with shape [num_tokens, num_experts].
"""
logits = logits.view(-1, self.num_experts)
router_score = self.router_act(logits)
routing_map = router_score > 0
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
sorted_map = sorted_probs <= 0
sorted_indices = torch.where(sorted_map, -1, sorted_indices)
max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1)
assert torch.all(sorted_map[:, max_valid_num:])
sorted_probs = sorted_probs[:, :max_valid_num]
sorted_indices = sorted_indices[:, :max_valid_num]
assert torch.sum(routing_map) == torch.sum(sorted_indices != -1)
return sorted_probs, sorted_indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
logits = self.gating(input)
top_scores, top_indices = self.routing(logits)
return top_scores, top_indices
class TopPRouter(BaseRouter):
def __init__(self, config: BlockFFNConfig) -> None:
super().__init__(config)
self.config = config
self.top_p = config.moe_router_topp
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
Returns:
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment,
with shape [num_tokens, num_experts].
"""
logits = logits.view(-1, self.num_experts)
router_score = torch.abs(logits)
router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20)
sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs > self.top_p
threshold_indices = mask.long().argmax(dim=-1)
threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
mask = mask & ~threshold_mask
sorted_indices = torch.where(mask, -1, sorted_indices)
sorted_probs = torch.where(mask, 0.0, sorted_probs)
max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1)
assert torch.all(mask[:, max_valid_num:])
sorted_indices = sorted_indices[:, :max_valid_num]
sorted_probs = sorted_probs[:, :max_valid_num]
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
return sorted_probs, sorted_indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
logits = self.gating(input)
top_scores, top_indices = self.routing(logits)
return top_scores, top_indices
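# How the three routers differ (summary comment, based on the implementations above):
#   TopKRouter : a fixed number (moe_router_topk) of experts per token.
#   ReMoERouter: every expert with a positive ReLU score; the per-token width varies, and
#                columns are padded with index -1 up to the batch-wide maximum.
#   TopPRouter : the smallest prefix of experts (ranked by normalized |logit|) whose cumulative
#                mass exceeds moe_router_topp, likewise padded with -1.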
class FastTopKCalculator:
def __init__(self, num_experts: int):
self.num_experts = num_experts
def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module):
(
pos,
local_expert_count,
global_expert_count,
fwd_expert_count,
fwd_batch_size,
) = prepare_forward(topk_indices, self.num_experts, 1)
topk = 1
if len(topk_indices.shape) == 2:
topk = topk_indices.shape[1]
def scatter_func(tensor):
return MOEScatter.apply(
tensor,
torch.div(pos, topk, rounding_mode='floor'),
local_expert_count,
global_expert_count,
fwd_batch_size,
1,
)
x = tree.map_structure(scatter_func, hidden_states)
x = experts(x, fwd_expert_count, topk_indices=topk_indices)
out_batch_size = tree.flatten(hidden_states)[0].shape[0]
if len(topk_indices.shape) == 2:
out_batch_size *= topk_indices.shape[1]
def gather_func(tensor):
return MOEGather.apply(
tensor,
pos,
local_expert_count,
global_expert_count,
out_batch_size,
1,
)
outp = tree.map_structure(gather_func, x)
return outp
def forward(self, hidden_states, topk_indices, topk_weights, experts):
assert topk_indices.shape == topk_weights.shape
top_k = topk_indices.shape[-1]
dim3 = hidden_states.ndim == 3
if dim3:
batch_size, seq_len, dim = hidden_states.shape
hidden_states = hidden_states.view(batch_size * seq_len, dim)
else:
assert hidden_states.ndim == 2
            batch_size = -1
            seq_len, dim = hidden_states.shape
fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts)
def view_func(tensor):
n_dim = tensor.shape[-1]
tensor = tensor.view(-1, top_k, n_dim)
return tensor
moe_output = tree.map_structure(view_func, fwd)
topk_weights = topk_weights.unsqueeze(1)
def bmm_func(tensor):
n_dim = tensor.shape[-1]
tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim)
return tensor
moe_output = tree.map_structure(bmm_func, moe_output)
if dim3:
moe_output = moe_output.view(batch_size, seq_len, -1)
return moe_output
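# FastTopKCalculator dispatch flow, sketched as comments (it relies on the fastmoe (fmoe)
# kernels imported above; this is not a normative description of their internals):
#   1) prepare_forward counts tokens per expert and computes scatter positions,
#   2) MOEScatter groups the selected token copies by expert,
#   3) the expert module runs batched per-expert GEMMs (MOELinear),
#   4) MOEGather restores the original token order,
#   5) a bmm with topk_weights mixes the top-k expert outputs back into one vector per token.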
class MoELinearExperts(nn.Module):
def __init__(
self,
dim_in: int,
dim_out: int,
num_experts: int,
ffn_bias: bool,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
self.bias = None
if ffn_bias:
self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out))
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor):
x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias)
return x
class MoEGatedExperts(nn.Module):
def __init__(
self,
dim_in: int,
dim_ff: int,
is_gated: bool,
act_name: str,
num_experts: int,
ffn_bias: bool = False,
):
super().__init__()
self.is_gated = is_gated
self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts
if self.is_gated:
self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias)
self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias)
self.act_fn = ACT2FN[act_name]
def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor:
if self.is_gated:
gate_score = self.gate_proj(x, fwd_expert_count)
up_proj = self.up_proj(x, fwd_expert_count)
x = up_proj * self.act_fn(gate_score)
else:
up_score = self.up_proj(x, fwd_expert_count)
x = self.act_fn(up_score)
x = self.down_proj(x, fwd_expert_count)
return x
class VanillaMoELayer(nn.Module):
def __init__(self, config: BlockFFNConfig):
        super().__init__()
self.config = config
# Initialize router
if config.router_type == "topk":
self.router = TopKRouter(config=self.config)
elif config.router_type == "remoe":
self.router = ReMoERouter(config=self.config)
elif config.router_type == "topp":
self.router = TopPRouter(config=self.config)
else:
raise NotImplementedError(f"Router type {config.router_type} not implemented.")
self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts)
# Initialize experts
self.experts = MoEGatedExperts(
dim_in=self.config.hidden_size,
dim_ff=self.config.moe_ffn_hidden_size,
is_gated=not self.config.expert_not_gated,
act_name="silu",
num_experts=self.config.num_experts,
)
self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size
self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0
if self.use_shared_expert:
self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert)
def forward(self, hidden_states: torch.Tensor):
top_scores, top_indices = self.router(hidden_states)
y = self.mix_calculator.forward(
hidden_states=hidden_states,
topk_indices=top_indices.contiguous(),
topk_weights=top_scores.type_as(hidden_states),
experts=self.experts,
)
        if self.use_shared_expert:
y = y + self.shared_experts(hidden_states)
return y
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class BlockFFNAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: BlockFFNConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class BlockFFNDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx)
if is_moe_layer:
if config.use_blockffn:
self.mlp = BlockFFNLayer(config)
elif config.router_type in ["topk", "remoe", "topp"]:
self.mlp = VanillaMoELayer(config)
else:
raise NotImplementedError
else:
self.mlp = BlockFFNMLP(config)
self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
if self.config.use_mup:
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
else:
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.config.use_mup:
hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers))
else:
hidden_states = residual + hidden_states
return hidden_states
@auto_docstring
class BlockFFNPreTrainedModel(PreTrainedModel):
config: BlockFFNConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BlockFFNDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": BlockFFNDecoderLayer,
"attentions": BlockFFNAttention,
}
@auto_docstring
class BlockFFNModel(BlockFFNPreTrainedModel):
def __init__(self, config: BlockFFNConfig):
super().__init__(config)
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # moe_layer_freq may arrive as a string containing a Python list expression; eval() turns it
        # into a per-layer 0/1 list indicating which layers are MoE layers.
        self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq
assert len(self.moe_layer_freq) == config.num_layers
self.layers = nn.ModuleList(
[BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)]
)
self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon)
self.rotary_emb = BlockFFNRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if self.config.use_mup:
inputs_embeds = inputs_embeds * self.config.mup_emb_scale
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers[: self.config.num_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: BlockFFNConfig):
super().__init__(config)
self.config = config
self.model = BlockFFNModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if self.config.use_mup:
hidden_states = hidden_states / self.config.mup_width_scale
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"BlockFFNForCausalLM",
"BlockFFNModel",
"BlockFFNPreTrainedModel",
]
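

# A minimal usage sketch (assumptions: "path/to/DECO-1.2B" is a placeholder for a local directory
# or hub repo that ships this file together with configuration_blockffn.py and registers the
# classes via auto_map; loading remote code requires trust_remote_code=True).
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "path/to/DECO-1.2B"  # hypothetical path, replace with the released checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    inputs = tokenizer("Hello, BlockFFN!", return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))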