# coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Optional, Union import math import torch from torch import nn import tree from abc import ABC, abstractmethod from fmoe.linear import MOELinear from fmoe.functions import prepare_forward, MOEScatter, MOEGather from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.integrations import use_kernel_forward_from_hub from transformers.masking_utils import create_causal_mask from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from transformers.utils.generic import check_model_inputs from .configuration_blockffn import BlockFFNConfig logger = logging.get_logger(__name__) @use_kernel_forward_from_hub("RMSNorm") class BlockFFNRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class BlockFFNRotaryEmbedding(nn.Module): def __init__(self, config: BlockFFNConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class SimpleLayerNorm(nn.Module): def __init__(self, dim_norm: int, fixed: bool = False, init_var: float = 1.0): super().__init__() self.dim_norm = dim_norm self.fixed = fixed if self.fixed: self.weight = init_var else: self.weight = torch.nn.Parameter(torch.full((self.dim_norm,), init_var)) @torch.compile def forward(self, x: torch.Tensor): return x * self.weight class BlockFFNMLP(nn.Module): def __init__(self, config: BlockFFNConfig, intermediate_size: int = None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class BlockFFNRouter(nn.Module): def __init__(self, config: BlockFFNConfig): super().__init__() self.config = config self.num_experts = self.config.num_experts if self.config.moe_router_dtype == "fp32": self.router_dtype = torch.float32 elif self.config.moe_router_dtype == "fp64": self.router_dtype = torch.float64 elif self.config.moe_router_dtype == "bf16": self.router_dtype = torch.bfloat16 else: raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.") self.weight = torch.nn.Parameter( torch.empty((self.config.num_experts, self.config.hidden_size), dtype=self.router_dtype) ) def forward(self, x: torch.Tensor): return nn.functional.linear(x.to(self.router_dtype), self.weight) class NormSiLU(nn.Module): def __init__(self, config: BlockFFNConfig): super().__init__() self.num_blocks, self.block_size = config.num_experts, config.moe_ffn_hidden_size self.activate_fn_type = config.expert_act_func assert self.activate_fn_type in ["norm_silu", "norm_silu_norms", "norm_silu_nomean", "silu"] self.rms_norm = None if self.activate_fn_type not in ["norm_silu_norms", "silu"]: self.rms_norm = BlockFFNRMSNorm(config.moe_ffn_hidden_size, eps=config.norm_epsilon) self.silu = torch.nn.SiLU() @torch.compile def forward(self, hidden: torch.Tensor) -> torch.Tensor: assert hidden.ndim == 2 if self.activate_fn_type not in ["norm_silu_nomean", "silu"]: hidden = hidden - torch.mean(hidden, dim=-1, keepdim=True) if self.activate_fn_type not in ["norm_silu_norms", "silu"]: return self.silu(self.rms_norm(hidden.view(hidden.shape[0], self.num_blocks, self.block_size))) else: return self.silu(hidden) class BlockFFNLayer(nn.Module): def __init__(self, config: BlockFFNConfig): super(BlockFFNLayer, self).__init__() self.config = config self.num_experts, self.dim_expert, self.hidden_size = \ config.num_experts, config.moe_ffn_hidden_size, config.hidden_size self.dim_shared_expert = config.moe_shared_expert_intermediate_size self.router_norm_type = config.router_norm_type self.moe_router = BlockFFNRouter(self.config) assert config.router_act_func == "relu" self.router_act = nn.ReLU() if config.router_norm_type == "simple": self.router_norm = SimpleLayerNorm( dim_norm=(1 if self.config.router_norm_scalar else config.num_experts), fixed=config.router_norm_fixed, init_var=config.router_norm_init_var, ) elif config.router_norm_type == "rms": self.router_norm = BlockFFNRMSNorm(self.config.num_experts, eps=config.norm_epsilon) else: raise NotImplementedError self.expert_gated = not config.expert_not_gated if self.expert_gated: self.expert_gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias) self.expert_up_proj = nn.Linear(self.hidden_size, self.num_experts * self.dim_expert, bias=config.mlp_bias) assert config.expert_act_norm_type == "normal" self.expert_act = NormSiLU(self.config) self.expert_down_proj = nn.Linear(self.num_experts * self.dim_expert, self.hidden_size, bias=config.mlp_bias) self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0 if self.use_shared_expert: self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert) def forward(self, hidden_states: torch.Tensor): ori_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) seq_len = hidden_states.shape[0] # router module forward raw_router_score = self.moe_router(hidden_states) # [seq_len, num_experts] router_score = self.router_act(raw_router_score) router_score = self.router_norm(router_score) # expert module forward x_in = self.expert_up_proj(hidden_states) # [seq_len, num_experts * dim_expert] if self.expert_gated: x_gate = self.expert_gate_proj(hidden_states) x_gate = self.expert_act(x_gate) if x_gate.ndim == 3: x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) x_in = x_in * x_gate else: x_in = self.expert_act(x_in) if x_in.ndim == 3: scored_x_in = x_in * router_score.type_as(hidden_states).unsqueeze(-1) else: scored_x_in = x_in.view(seq_len, self.num_experts, self.dim_expert) * router_score.type_as(hidden_states).unsqueeze(-1) output = self.expert_down_proj(scored_x_in.view(seq_len, self.num_experts * self.dim_expert)) if self.use_shared_expert: output = output + self.shared_experts(hidden_states) return output.view(*ori_shape) class BaseRouter(ABC, nn.Module): """Base Router class""" def __init__(self, config: BlockFFNConfig) -> None: super().__init__() self.config = config self.num_experts = self.config.num_experts if self.config.moe_router_dtype == "fp32": self.router_dtype = torch.float32 elif self.config.moe_router_dtype == "fp64": self.router_dtype = torch.float64 elif self.config.moe_router_dtype == "bf16": self.router_dtype = torch.bfloat16 else: raise NotImplementedError(f"{self.config.moe_router_dtype} is not supported.") self.weight = torch.nn.Parameter( torch.empty((self.num_experts, self.config.hidden_size), dtype=self.router_dtype) ) def gating(self, input: torch.Tensor): return torch.nn.functional.linear(input.to(self.router_dtype), self.weight.to(self.router_dtype)) @abstractmethod def routing(self, logits: torch.Tensor): """Routing function. Args: logits (torch.Tensor): Logits tensor. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment probabilities and mapping. """ raise NotImplementedError("Routing function not implemented.") @abstractmethod def forward(self, input: torch.Tensor): """ Forward pass of the router. Args: input (torch.Tensor): Input tensor. """ raise NotImplementedError("Forward function not implemented.") class TopKRouter(BaseRouter): """Route each token to the top-k experts.""" def __init__(self, config: BlockFFNConfig) -> None: super().__init__(config) self.config = config self.topk = self.config.moe_router_topk self.score_function = self.config.moe_router_score_function self.use_pre_softmax = self.config.moe_router_pre_softmax self.scaling_factor = self.config.moe_router_topk_scaling_factor self.enable_expert_bias = self.config.moe_router_enable_expert_bias if self.enable_expert_bias: self.expert_bias = torch.nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32)) else: self.expert_bias = None def _maintain_float32_expert_bias(self): """ Maintain the expert bias in float32. When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module. We keep it in float32 to avoid routing errors when updating the expert_bias. """ if hasattr(self, 'expert_bias') and self.expert_bias is not None: if self.expert_bias.dtype != torch.float32: self.expert_bias.data = self.expert_bias.data.to(torch.float32) def routing(self, logits: torch.Tensor): """Top-k routing function Args: logits (torch.Tensor): Logits tensor after gating. Returns: probs (torch.Tensor): The probabilities of token to experts assignment. routing_map (torch.Tensor): The mapping of token to experts assignment, with shape [num_tokens, num_experts]. """ logits = logits.view(-1, self.num_experts) if self.score_function == "softmax": if self.use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) probs, top_indices = torch.topk(scores, k=self.topk, dim=1) else: scores, top_indices = torch.topk(logits, k=self.topk, dim=1) probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) elif self.score_function == "sigmoid": scores = torch.sigmoid(logits.float()).type_as(logits) if self.expert_bias is not None: scores_for_routing = scores + self.expert_bias _, top_indices = torch.topk(scores_for_routing, k=self.topk, dim=1) scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) else: scores, top_indices = torch.topk(scores, k=self.topk, dim=1) probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores else: raise ValueError(f"Invalid score_function: {self.score_function}") if self.scaling_factor: probs = probs * self.scaling_factor return probs, top_indices def forward(self, input: torch.Tensor): """ Forward pass of the router. Args: input (torch.Tensor): Input tensor. """ self._maintain_float32_expert_bias() logits = self.gating(input) top_scores, top_indices = self.routing(logits) return top_scores, top_indices class ReMoERouter(BaseRouter): def __init__(self, config: BlockFFNConfig) -> None: super().__init__(config) self.config = config self.router_act = torch.nn.ReLU() def routing(self, logits: torch.Tensor): """Top-k routing function Args: logits (torch.Tensor): Logits tensor after gating. Returns: probs (torch.Tensor): The probabilities of token to experts assignment. routing_map (torch.Tensor): The mapping of token to experts assignment, with shape [num_tokens, num_experts]. """ logits = logits.view(-1, self.num_experts) router_score = self.router_act(logits) routing_map = router_score > 0 sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1) sorted_map = sorted_probs <= 0 sorted_indices = torch.where(sorted_map, -1, sorted_indices) max_valid_num = max(sorted_probs.size(-1) - torch.min(torch.sum(sorted_map, dim=-1)).item(), 1) assert torch.all(sorted_map[:, max_valid_num:]) sorted_probs = sorted_probs[:, :max_valid_num] sorted_indices = sorted_indices[:, :max_valid_num] assert torch.sum(routing_map) == torch.sum(sorted_indices != -1) return sorted_probs, sorted_indices def forward(self, input: torch.Tensor): """ Forward pass of the router. Args: input (torch.Tensor): Input tensor. """ logits = self.gating(input) top_scores, top_indices = self.routing(logits) return top_scores, top_indices class TopPRouter(BaseRouter): def __init__(self, config: BlockFFNConfig) -> None: super().__init__(config) self.config = config self.top_p = config.moe_router_topp def routing(self, logits: torch.Tensor): """Top-k routing function Args: logits (torch.Tensor): Logits tensor after gating. Returns: probs (torch.Tensor): The probabilities of token to experts assignment. routing_map (torch.Tensor): The mapping of token to experts assignment, with shape [num_tokens, num_experts]. """ logits = logits.view(-1, self.num_experts) router_score = torch.abs(logits) router_score = router_score / (router_score.sum(dim=-1, keepdim=True) + 1e-20) sorted_probs, sorted_indices = torch.sort(router_score, descending=True, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) mask = cumulative_probs > self.top_p threshold_indices = mask.long().argmax(dim=-1) threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool() mask = mask & ~threshold_mask sorted_indices = torch.where(mask, -1, sorted_indices) sorted_probs = torch.where(mask, 0.0, sorted_probs) max_valid_num = max(mask.size(-1) - torch.min(torch.sum(mask, dim=-1)).item(), 1) assert torch.all(mask[:, max_valid_num:]) sorted_indices = sorted_indices[:, :max_valid_num] sorted_probs = sorted_probs[:, :max_valid_num] sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True) return sorted_probs, sorted_indices def forward(self, input: torch.Tensor): """ Forward pass of the router. Args: input (torch.Tensor): Input tensor. """ logits = self.gating(input) top_scores, top_indices = self.routing(logits) return top_scores, top_indices class FastTopKCalculator: def __init__(self, num_experts: int): self.num_experts = num_experts def fmoe_sparse_topk_forward(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, experts: torch.nn.Module): ( pos, local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size, ) = prepare_forward(topk_indices, self.num_experts, 1) topk = 1 if len(topk_indices.shape) == 2: topk = topk_indices.shape[1] def scatter_func(tensor): return MOEScatter.apply( tensor, torch.div(pos, topk, rounding_mode='floor'), local_expert_count, global_expert_count, fwd_batch_size, 1, ) x = tree.map_structure(scatter_func, hidden_states) x = experts(x, fwd_expert_count, topk_indices=topk_indices) out_batch_size = tree.flatten(hidden_states)[0].shape[0] if len(topk_indices.shape) == 2: out_batch_size *= topk_indices.shape[1] def gather_func(tensor): return MOEGather.apply( tensor, pos, local_expert_count, global_expert_count, out_batch_size, 1, ) outp = tree.map_structure(gather_func, x) return outp def forward(self, hidden_states, topk_indices, topk_weights, experts): assert topk_indices.shape == topk_weights.shape top_k = topk_indices.shape[-1] dim3 = hidden_states.ndim == 3 if dim3: batch_size, seq_len, dim = hidden_states.shape hidden_states = hidden_states.view(batch_size * seq_len, dim) else: assert hidden_states.ndim == 2 batch_size, (seq_len, dim) = -1, hidden_states.shape fwd = self.fmoe_sparse_topk_forward(hidden_states, topk_indices, experts) def view_func(tensor): n_dim = tensor.shape[-1] tensor = tensor.view(-1, top_k, n_dim) return tensor moe_output = tree.map_structure(view_func, fwd) topk_weights = topk_weights.unsqueeze(1) def bmm_func(tensor): n_dim = tensor.shape[-1] tensor = torch.bmm(topk_weights, tensor).reshape(-1, n_dim) return tensor moe_output = tree.map_structure(bmm_func, moe_output) if dim3: moe_output = moe_output.view(batch_size, seq_len, -1) return moe_output class MoELinearExperts(nn.Module): def __init__( self, dim_in: int, dim_out: int, num_experts: int, ffn_bias: bool, ): super().__init__() self.dim_in = self.in_features = dim_in self.dim_out = self.out_features = dim_out self.weight = torch.nn.Parameter(torch.empty(num_experts, dim_out, dim_in)) self.bias = None if ffn_bias: self.bias = torch.nn.Parameter(torch.empty(num_experts, dim_out)) def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor): x = MOELinear.apply(x, fwd_expert_count, self.weight, self.bias) return x class MoEGatedExperts(nn.Module): def __init__( self, dim_in: int, dim_ff: int, is_gated: bool, act_name: str, num_experts: int, ffn_bias: bool = False, ): super().__init__() self.is_gated = is_gated self.dim_in, self.dim_ff, self.num_experts = dim_in, dim_ff, num_experts if self.is_gated: self.gate_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias) self.up_proj = MoELinearExperts(dim_in, dim_ff, num_experts, ffn_bias) self.down_proj = MoELinearExperts(dim_ff, dim_in, num_experts, ffn_bias) self.act_fn = ACT2FN[act_name] def forward(self, x: torch.Tensor, fwd_expert_count: torch.Tensor, **kwargs) -> torch.Tensor: if self.is_gated: gate_score = self.gate_proj(x, fwd_expert_count) up_proj = self.up_proj(x, fwd_expert_count) x = up_proj * self.act_fn(gate_score) else: up_score = self.up_proj(x, fwd_expert_count) x = self.act_fn(up_score) x = self.down_proj(x, fwd_expert_count) return x class VanillaMoELayer(nn.Module): def __init__(self, config: BlockFFNConfig): super(VanillaMoELayer, self).__init__() self.config = config # Initialize router if config.router_type == "topk": self.router = TopKRouter(config=self.config) elif config.router_type == "remoe": self.router = ReMoERouter(config=self.config) elif config.router_type == "topp": self.router = TopPRouter(config=self.config) else: raise NotImplementedError(f"Router type {config.router_type} not implemented.") self.mix_calculator = FastTopKCalculator(num_experts=self.config.num_experts) # Initialize experts self.experts = MoEGatedExperts( dim_in=self.config.hidden_size, dim_ff=self.config.moe_ffn_hidden_size, is_gated=not self.config.expert_not_gated, act_name="silu", num_experts=self.config.num_experts, ) self.dim_shared_expert = self.config.moe_shared_expert_intermediate_size self.use_shared_expert = self.dim_shared_expert is not None and self.dim_shared_expert > 0 if self.use_shared_expert: self.shared_experts = BlockFFNMLP(self.config, intermediate_size=self.dim_shared_expert) def forward(self, hidden_states: torch.Tensor): top_scores, top_indices = self.router(hidden_states) y = self.mix_calculator.forward( hidden_states=hidden_states, topk_indices=top_indices.contiguous(), topk_weights=top_scores.type_as(hidden_states), experts=self.experts, ) if self.shared_experts is not None: y = y + self.shared_experts(hidden_states) return y def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class BlockFFNAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: BlockFFNConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_query_groups self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_query_groups * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class BlockFFNDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlockFFNConfig, layer_idx: int, is_moe_layer: bool): super().__init__() self.config = config self.hidden_size = config.hidden_size self.self_attn = BlockFFNAttention(config=config, layer_idx=layer_idx) if is_moe_layer: if config.use_blockffn: self.mlp = BlockFFNLayer(config) elif config.router_type in ["topk", "remoe", "topp"]: self.mlp = VanillaMoELayer(config) else: raise NotImplementedError else: self.mlp = BlockFFNMLP(config) self.input_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) if self.config.use_mup: hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers)) else: hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if self.config.use_mup: hidden_states = residual + hidden_states * (self.config.mup_depth_scale / math.sqrt(self.config.num_layers)) else: hidden_states = residual + hidden_states return hidden_states @auto_docstring class BlockFFNPreTrainedModel(PreTrainedModel): config: BlockFFNConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["BlockFFNDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": BlockFFNDecoderLayer, "attentions": BlockFFNAttention, } @auto_docstring class BlockFFNModel(BlockFFNPreTrainedModel): def __init__(self, config: BlockFFNConfig): super().__init__(config) self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.moe_layer_freq = eval(config.moe_layer_freq) if isinstance(config.moe_layer_freq, str) else config.moe_layer_freq assert len(self.moe_layer_freq) == config.num_layers self.layers = nn.ModuleList( [BlockFFNDecoderLayer(config, layer_idx, bool(self.moe_layer_freq[layer_idx])) for layer_idx in range(config.num_layers)] ) self.norm = BlockFFNRMSNorm(config.hidden_size, eps=config.norm_epsilon) self.rotary_emb = BlockFFNRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @check_model_inputs @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) if self.config.use_mup: inputs_embeds = inputs_embeds * self.config.mup_emb_scale if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position: torch.Tensor = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers[: self.config.num_layers]: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) @auto_docstring class BlockFFNForCausalLM(BlockFFNPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: BlockFFNConfig): super().__init__(config) self.config = config self.model = BlockFFNModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state if self.config.use_mup: hidden_states = hidden_states / self.config.mup_width_scale # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = [ "BlockFFNForCausalLM", "BlockFFNModel", "BlockFFNPreTrainedModel", ]