akshitab commited on
Commit
3bffe48
·
verified ·
1 Parent(s): 67a4eed

Emo rename: upload modeling_emo.py

Browse files
Files changed (1) hide show
  1. modeling_emo.py +1205 -0
modeling_emo.py ADDED
@@ -0,0 +1,1205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/emo/modular_emo.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_emo.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.masking_utils import create_causal_mask
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import MoeModelOutputWithPast
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import ModelOutput, TransformersKwargs, auto_docstring
39
+ from transformers.utils.deprecation import deprecate_kwarg
40
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
41
+
42
+ from .configuration_emo import EmoConfig
43
+
44
+
45
+ @use_kernel_forward_from_hub("RMSNorm")
46
+ class EmoRMSNorm(nn.Module):
47
+ def __init__(self, hidden_size, eps=1e-6):
48
+ """
49
+ EmoRMSNorm is equivalent to T5LayerNorm
50
+ """
51
+ super().__init__()
52
+ self.weight = nn.Parameter(torch.ones(hidden_size))
53
+ self.variance_epsilon = eps
54
+
55
+ def forward(self, hidden_states):
56
+ input_dtype = hidden_states.dtype
57
+ hidden_states = hidden_states.to(torch.float32)
58
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
59
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
60
+ return (self.weight * hidden_states).to(input_dtype)
61
+
62
+ def extra_repr(self):
63
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
+
65
+
66
+ class EmoMLP(nn.Module):
67
+ def __init__(self, config):
68
+ super().__init__()
69
+ self.config = config
70
+ self.hidden_size = config.hidden_size
71
+ self.intermediate_size = config.intermediate_size
72
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
73
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
74
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
75
+ self.act_fn = ACT2FN[config.hidden_act]
76
+ # Some densefirst models were accidentally trained with bias=True on dense MLPs
77
+ # (OLMo Core's FeedForwardConfig defaults bias to True when not explicitly set).
78
+ # We support loading those weights here.
79
+ dense_mlp_bias = getattr(config, "dense_mlp_bias", False)
80
+ if dense_mlp_bias:
81
+ del self.gate_proj
82
+ del self.up_proj
83
+ del self.down_proj
84
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
85
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
86
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
87
+
88
+ def forward(self, x):
89
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
+ return down_proj
91
+
92
+
93
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
94
+ """
95
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
96
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
97
+ """
98
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
99
+ if n_rep == 1:
100
+ return hidden_states
101
+ hidden_states = hidden_states[:, :, None, :, :].expand(
102
+ batch, num_key_value_heads, n_rep, slen, head_dim
103
+ )
104
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
105
+
106
+
107
+ def eager_attention_forward(
108
+ module: nn.Module,
109
+ query: torch.Tensor,
110
+ key: torch.Tensor,
111
+ value: torch.Tensor,
112
+ attention_mask: Optional[torch.Tensor],
113
+ scaling: float,
114
+ dropout: float = 0.0,
115
+ **kwargs: Unpack[TransformersKwargs],
116
+ ):
117
+ key_states = repeat_kv(key, module.num_key_value_groups)
118
+ value_states = repeat_kv(value, module.num_key_value_groups)
119
+
120
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
121
+ if attention_mask is not None:
122
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
123
+ attn_weights = attn_weights + causal_mask
124
+
125
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
126
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
127
+ attn_output = torch.matmul(attn_weights, value_states)
128
+ attn_output = attn_output.transpose(1, 2).contiguous()
129
+
130
+ return attn_output, attn_weights
131
+
132
+
133
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
134
+ """Applies Rotary Position Embedding to the query and key tensors.
135
+
136
+ Args:
137
+ q (`torch.Tensor`): The query tensor.
138
+ k (`torch.Tensor`): The key tensor.
139
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
140
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
141
+ position_ids (`torch.Tensor`, *optional*):
142
+ Deprecated and unused.
143
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
144
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
145
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
146
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
147
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
148
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
149
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
150
+ Returns:
151
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
152
+ """
153
+ q_type, k_type = q.dtype, k.dtype
154
+ cos = cos.unsqueeze(unsqueeze_dim)
155
+ sin = sin.unsqueeze(unsqueeze_dim)
156
+ q_embed = (q * cos) + (rotate_half(q) * sin)
157
+ k_embed = (k * cos) + (rotate_half(k) * sin)
158
+ return q_embed.to(q_type), k_embed.to(k_type)
159
+
160
+
161
+ def rotate_half(x):
162
+ """Rotates half the hidden dims of the input."""
163
+ x1 = x[..., : x.shape[-1] // 2]
164
+ x2 = x[..., x.shape[-1] // 2 :]
165
+ return torch.cat((-x2, x1), dim=-1)
166
+
167
+
168
+ class EmoAttention(nn.Module):
169
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
170
+
171
+ def __init__(self, config: EmoConfig, layer_idx: Optional[int] = None):
172
+ super().__init__()
173
+ self.config = config
174
+ self.layer_idx = layer_idx
175
+ self.head_dim = getattr(
176
+ config, "head_dim", config.hidden_size // config.num_attention_heads
177
+ )
178
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
179
+ self.scaling = self.head_dim**-0.5
180
+ self.attention_dropout = config.attention_dropout
181
+ self.is_causal = True
182
+
183
+ self.q_proj = nn.Linear(
184
+ config.hidden_size,
185
+ config.num_attention_heads * self.head_dim,
186
+ bias=config.attention_bias,
187
+ )
188
+ self.k_proj = nn.Linear(
189
+ config.hidden_size,
190
+ config.num_key_value_heads * self.head_dim,
191
+ bias=config.attention_bias,
192
+ )
193
+ self.v_proj = nn.Linear(
194
+ config.hidden_size,
195
+ config.num_key_value_heads * self.head_dim,
196
+ bias=config.attention_bias,
197
+ )
198
+ self.o_proj = nn.Linear(
199
+ config.num_attention_heads * self.head_dim,
200
+ config.hidden_size,
201
+ bias=config.attention_bias,
202
+ )
203
+
204
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
205
+ def forward(
206
+ self,
207
+ hidden_states: torch.Tensor,
208
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
209
+ attention_mask: Optional[torch.Tensor],
210
+ past_key_values: Optional[Cache] = None,
211
+ cache_position: Optional[torch.LongTensor] = None,
212
+ **kwargs: Unpack[TransformersKwargs],
213
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
214
+ input_shape = hidden_states.shape[:-1]
215
+ hidden_shape = (*input_shape, -1, self.head_dim)
216
+
217
+ query_states = self.q_proj(hidden_states)
218
+ key_states = self.k_proj(hidden_states)
219
+ value_states = self.v_proj(hidden_states)
220
+
221
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
222
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
223
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
224
+
225
+ cos, sin = position_embeddings
226
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
227
+
228
+ if past_key_values is not None:
229
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
230
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
231
+ key_states, value_states = past_key_values.update(
232
+ key_states, value_states, self.layer_idx, cache_kwargs
233
+ )
234
+
235
+ attention_interface: Callable = eager_attention_forward
236
+ if self.config._attn_implementation != "eager":
237
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
238
+
239
+ attn_output, attn_weights = attention_interface(
240
+ self,
241
+ query_states,
242
+ key_states,
243
+ value_states,
244
+ attention_mask,
245
+ dropout=0.0 if not self.training else self.attention_dropout,
246
+ scaling=self.scaling,
247
+ **kwargs,
248
+ )
249
+
250
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
251
+ attn_output = self.o_proj(attn_output)
252
+ return attn_output, attn_weights
253
+
254
+
255
+ class EmoSparseMoeBlock(nn.Module):
256
+ def __init__(
257
+ self,
258
+ config,
259
+ num_experts: int,
260
+ num_shared_experts: int,
261
+ always_active_experts: Optional[list[int]] = None,
262
+ ):
263
+ super().__init__()
264
+ self.top_k = config.num_experts_per_tok
265
+ self.norm_topk_prob = config.norm_topk_prob
266
+
267
+ self.num_shared_experts = num_shared_experts
268
+ self.always_active_experts = always_active_experts
269
+ self.num_experts = num_experts
270
+ self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
271
+ # Expert MLPs should never use dense_mlp_bias (that's only for dense FFN layers)
272
+ import copy
273
+
274
+ expert_config = copy.copy(config)
275
+ expert_config.dense_mlp_bias = False
276
+ self.experts = nn.ModuleList([EmoMLP(expert_config) for _ in range(self.num_experts)])
277
+
278
+ def _get_top_k_with_always_active(
279
+ self, scores: torch.Tensor
280
+ ) -> tuple[torch.Tensor, torch.Tensor]:
281
+ """
282
+ Select top-k experts where always_active_experts are always included.
283
+ Softmax is computed over all experts, then always-active are masked out for topk selection.
284
+ """
285
+ always_active = self.always_active_experts
286
+ num_always_active = len(always_active)
287
+ routed_top_k = self.top_k - num_always_active
288
+
289
+ # Mask out always-active experts so they aren't selected by topk.
290
+ masked_scores = scores.clone()
291
+ masked_scores[:, always_active] = float("-inf")
292
+
293
+ # Select top-(top_k - num_always_active) from the remaining experts.
294
+ if routed_top_k == 1:
295
+ _, routed_indices = masked_scores.max(dim=-1, keepdim=True)
296
+ else:
297
+ _, routed_indices = torch.topk(masked_scores, routed_top_k, dim=-1)
298
+
299
+ # Gather actual weights from original (unmasked) scores.
300
+ routed_weights = scores.gather(-1, routed_indices)
301
+
302
+ # Build always-active indices and weights.
303
+ always_active_tensor = torch.tensor(
304
+ always_active, device=scores.device, dtype=routed_indices.dtype
305
+ )
306
+ always_active_indices = always_active_tensor.unsqueeze(0).expand(
307
+ scores.shape[0], num_always_active
308
+ )
309
+ always_active_weights = scores.gather(-1, always_active_indices)
310
+
311
+ # Concatenate: always-active first, then routed.
312
+ selected_experts = torch.cat([always_active_indices, routed_indices], dim=-1)
313
+ routing_weights = torch.cat([always_active_weights, routed_weights], dim=-1)
314
+
315
+ return routing_weights, selected_experts
316
+
317
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
318
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
319
+ hidden_states = hidden_states.view(-1, hidden_dim)
320
+ # router_logits: (batch * sequence_length, n_experts)
321
+ router_logits = self.gate(hidden_states)
322
+
323
+ if self.always_active_experts is not None and len(self.always_active_experts) > 0:
324
+ # Use masking approach: softmax over all experts, mask always-active for topk
325
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
326
+ routing_weights, selected_experts = self._get_top_k_with_always_active(routing_weights)
327
+ elif self.num_shared_experts > 0:
328
+ # Legacy path: shared experts are the last N experts
329
+ # split the router logits into shared and unshared experts
330
+ router_logits_standard = router_logits[
331
+ :, : -self.num_shared_experts
332
+ ] # (batch * sequence_length, n_experts - num_shared_experts)
333
+ router_logits_shared = router_logits[
334
+ :, -self.num_shared_experts :
335
+ ] # (batch * sequence_length, num_shared_experts)
336
+
337
+ # compute the routing weights for the standard experts and shared experts separately
338
+ routing_weights_standard = F.softmax(router_logits_standard, dim=1, dtype=torch.float)
339
+ routing_weights_shared = F.softmax(router_logits_shared, dim=1, dtype=torch.float)
340
+
341
+ # select the routing weights and experts for the standard experts and shared experts separately
342
+ routing_weights_standard, selected_experts_standard = torch.topk(
343
+ routing_weights_standard, self.top_k - self.num_shared_experts, dim=-1
344
+ )
345
+ routing_weights_shared, selected_experts_shared = torch.topk(
346
+ routing_weights_shared, self.num_shared_experts, dim=-1
347
+ )
348
+
349
+ # concatenate the routing weights and selected experts for the standard experts and shared experts
350
+ routing_weights = torch.cat([routing_weights_standard, routing_weights_shared], dim=1)
351
+ selected_experts = torch.cat(
352
+ [
353
+ selected_experts_standard,
354
+ selected_experts_shared + (self.num_experts - self.num_shared_experts),
355
+ ],
356
+ dim=1,
357
+ ) # we need to add the offset to the selected experts for the shared experts since they are at the end of the router logits
358
+
359
+ # make sure there are self.top_k experts selected in total
360
+ assert (
361
+ routing_weights.shape
362
+ == selected_experts.shape
363
+ == (batch_size * sequence_length, self.top_k)
364
+ ), f"routing_weights and selected_experts should have the same shape of (batch_size * sequence_length, self.top_k), but got {routing_weights.shape} and {selected_experts.shape}"
365
+ else:
366
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
367
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
368
+
369
+ if self.norm_topk_prob:
370
+ if self.num_shared_experts > 0 or (
371
+ self.always_active_experts is not None and len(self.always_active_experts) > 0
372
+ ):
373
+ raise NotImplementedError(
374
+ "norm_topk_prob is not implemented for the case where num_shared_experts > 0 or always_active_experts is set"
375
+ )
376
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
377
+
378
+ # we cast back to the input dtype
379
+ routing_weights = routing_weights.to(hidden_states.dtype)
380
+
381
+ final_hidden_states = torch.zeros(
382
+ (batch_size * sequence_length, hidden_dim),
383
+ dtype=hidden_states.dtype,
384
+ device=hidden_states.device,
385
+ )
386
+
387
+ # One hot encode the selected experts to create an expert mask
388
+ # this will be used to easily index which expert is going to be selected
389
+ expert_mask = torch.nn.functional.one_hot(
390
+ selected_experts, num_classes=self.num_experts
391
+ ).permute(2, 1, 0)
392
+
393
+ # Loop over all available experts in the model and perform the computation on each expert
394
+ for expert_idx in range(self.num_experts):
395
+ expert_layer = self.experts[expert_idx]
396
+ idx, top_x = torch.where(expert_mask[expert_idx])
397
+
398
+ # Index the correct hidden states and compute the expert hidden state for
399
+ # the current expert. We need to make sure to multiply the output hidden
400
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
401
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
402
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
403
+
404
+ # However `index_add_` only support torch tensors for indexing so we'll use
405
+ # the `top_x` tensor here.
406
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
407
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
408
+ return final_hidden_states, router_logits
409
+
410
+
411
+ class EmoDecoderLayer(GradientCheckpointingLayer):
412
+ def __init__(
413
+ self,
414
+ config: EmoConfig,
415
+ layer_idx: int,
416
+ num_experts: int,
417
+ num_shared_experts: int,
418
+ always_active_experts: Optional[list[int]] = None,
419
+ ):
420
+ super().__init__()
421
+ self.hidden_size = config.hidden_size
422
+ self.self_attn = EmoAttention(config=config, layer_idx=layer_idx)
423
+
424
+ self.num_experts = num_experts
425
+
426
+ if num_experts == 0:
427
+ # Dense layer: use MLP with dense_intermediate_size
428
+ dense_intermediate_size = getattr(config, "dense_intermediate_size", None)
429
+ if dense_intermediate_size is None:
430
+ raise ValueError(
431
+ "num_experts=0 (dense layer) but config.dense_intermediate_size is not set. "
432
+ "Please set dense_intermediate_size in the config."
433
+ )
434
+ import copy
435
+
436
+ dense_config = copy.copy(config)
437
+ dense_config.intermediate_size = dense_intermediate_size
438
+ dense_config.dense_mlp_bias = getattr(config, "dense_mlp_bias", False)
439
+ self.mlp = EmoMLP(dense_config)
440
+ else:
441
+ self.mlp = EmoSparseMoeBlock(
442
+ config, num_experts, num_shared_experts, always_active_experts
443
+ )
444
+
445
+ self.pre_attention_layernorm = EmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
446
+ self.pre_feedforward_layernorm = EmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
447
+
448
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
449
+ def forward(
450
+ self,
451
+ hidden_states: torch.Tensor,
452
+ attention_mask: Optional[torch.Tensor] = None,
453
+ position_ids: Optional[torch.LongTensor] = None,
454
+ past_key_values: Optional[Cache] = None,
455
+ cache_position: Optional[torch.LongTensor] = None,
456
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
457
+ **kwargs,
458
+ ) -> torch.FloatTensor:
459
+ """
460
+ Args:
461
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
462
+ attention_mask (`torch.FloatTensor`, *optional*):
463
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
464
+ query_sequence_length, key_sequence_length)` if default attention is used.
465
+ output_attentions (`bool`, *optional*):
466
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
467
+ returned tensors for more detail.
468
+ output_router_logits (`bool`, *optional*):
469
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
470
+ and should not be returned during inference.
471
+ use_cache (`bool`, *optional*):
472
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
473
+ (see `past_key_values`).
474
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
475
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
476
+ Indices depicting the position of the input sequence tokens in the sequence
477
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
478
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
479
+ with `head_dim` being the embedding dimension of each attention head.
480
+ kwargs (`dict`, *optional*):
481
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
482
+ into the model
483
+ """
484
+ residual = hidden_states
485
+ # apply norm before attention
486
+ hidden_states = self.pre_attention_layernorm(hidden_states)
487
+ # Self Attention
488
+ hidden_states, _ = self.self_attn(
489
+ hidden_states=hidden_states,
490
+ attention_mask=attention_mask,
491
+ position_ids=position_ids,
492
+ past_key_values=past_key_values,
493
+ cache_position=cache_position,
494
+ position_embeddings=position_embeddings,
495
+ **kwargs,
496
+ )
497
+ hidden_states = residual + hidden_states
498
+
499
+ # Fully Connected
500
+ residual = hidden_states
501
+ # apply norm before feedforward
502
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
503
+ mlp_output = self.mlp(hidden_states)
504
+ if isinstance(mlp_output, tuple):
505
+ hidden_states, _ = mlp_output
506
+ else:
507
+ hidden_states = mlp_output
508
+ hidden_states = residual + hidden_states
509
+ return hidden_states
510
+
511
+
512
+ @auto_docstring
513
+ class EmoPreTrainedModel(PreTrainedModel):
514
+ config: EmoConfig
515
+ base_model_prefix = "model"
516
+ supports_gradient_checkpointing = True
517
+ _no_split_modules = ["EmoDecoderLayer"]
518
+ _skip_keys_device_placement = ["past_key_values"]
519
+ _supports_flash_attn = True
520
+ _supports_sdpa = True
521
+ _supports_flex_attn = True
522
+ _can_compile_fullgraph = (
523
+ False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
524
+ )
525
+ _supports_attention_backend = True
526
+ _can_record_outputs = {
527
+ "router_logits": OutputRecorder(EmoSparseMoeBlock, index=1),
528
+ "hidden_states": EmoDecoderLayer,
529
+ "attentions": EmoAttention,
530
+ }
531
+ config_class = EmoConfig
532
+
533
+
534
+ class EmoRotaryEmbedding(nn.Module):
535
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
536
+
537
+ def __init__(self, config: EmoConfig, device=None):
538
+ super().__init__()
539
+ # BC: "rope_type" was originally "type"
540
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
541
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
542
+ else:
543
+ self.rope_type = "default"
544
+ self.max_seq_len_cached = config.max_position_embeddings
545
+ self.original_max_seq_len = config.max_position_embeddings
546
+
547
+ self.config = config
548
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
549
+
550
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
551
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
552
+ self.original_inv_freq = self.inv_freq
553
+
554
+ @torch.no_grad()
555
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
556
+ def forward(self, x, position_ids):
557
+ inv_freq_expanded = (
558
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
559
+ )
560
+ position_ids_expanded = position_ids[:, None, :].float()
561
+
562
+ device_type = (
563
+ x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
564
+ )
565
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
566
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
567
+ emb = torch.cat((freqs, freqs), dim=-1)
568
+ cos = emb.cos() * self.attention_scaling
569
+ sin = emb.sin() * self.attention_scaling
570
+ return cos, sin
571
+
572
+
573
+ @auto_docstring
574
+ class EmoModel(EmoPreTrainedModel):
575
+ def __init__(self, config):
576
+ super().__init__(config)
577
+ self.padding_idx = config.pad_token_id
578
+ self.vocab_size = config.vocab_size
579
+
580
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
581
+ self.norm = EmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
582
+ self.rotary_emb = EmoRotaryEmbedding(config=config)
583
+ self.gradient_checkpointing = False
584
+
585
+ # Check if per-layer expert counts are specified
586
+ num_experts_per_layer = getattr(config, "num_experts_per_layer", None)
587
+ num_shared_experts_per_layer = getattr(config, "num_shared_experts_per_layer", None)
588
+ always_active_experts_per_layer = getattr(config, "always_active_experts_per_layer", None)
589
+ always_active_experts = getattr(config, "always_active_experts", None)
590
+
591
+ # Resolve always_active_experts to a per-layer list
592
+ if always_active_experts_per_layer is None and always_active_experts is not None:
593
+ always_active_experts_per_layer = [always_active_experts] * config.num_hidden_layers
594
+
595
+ if num_experts_per_layer is not None:
596
+ # Use per-layer expert counts
597
+ assert (
598
+ len(num_experts_per_layer) == config.num_hidden_layers
599
+ ), f"num_experts_per_layer has length {len(num_experts_per_layer)} but model has {config.num_hidden_layers} layers"
600
+ if num_shared_experts_per_layer is None:
601
+ # Default: use config.num_shared_experts for all layers, but cap at layer's num_experts
602
+ num_shared_experts_per_layer = [
603
+ min(config.num_shared_experts, num_experts_per_layer[i])
604
+ for i in range(config.num_hidden_layers)
605
+ ]
606
+ self.layers = nn.ModuleList(
607
+ [
608
+ EmoDecoderLayer(
609
+ config,
610
+ layer_idx,
611
+ num_experts_per_layer[layer_idx],
612
+ num_shared_experts_per_layer[layer_idx],
613
+ always_active_experts=always_active_experts_per_layer[layer_idx]
614
+ if always_active_experts_per_layer is not None
615
+ else None,
616
+ )
617
+ for layer_idx in range(config.num_hidden_layers)
618
+ ]
619
+ )
620
+ else:
621
+ # Fall back to original behavior: all layers use config.num_experts
622
+ self.layers = nn.ModuleList(
623
+ [
624
+ EmoDecoderLayer(
625
+ config,
626
+ layer_idx,
627
+ config.num_experts,
628
+ config.num_shared_experts,
629
+ always_active_experts=always_active_experts_per_layer[layer_idx]
630
+ if always_active_experts_per_layer is not None
631
+ else None,
632
+ )
633
+ for layer_idx in range(config.num_hidden_layers)
634
+ ]
635
+ )
636
+
637
+ # Initialize weights and apply final processing
638
+ self.post_init()
639
+
640
+ @check_model_inputs
641
+ @auto_docstring
642
+ def forward(
643
+ self,
644
+ input_ids: Optional[torch.LongTensor] = None,
645
+ attention_mask: Optional[torch.Tensor] = None,
646
+ position_ids: Optional[torch.LongTensor] = None,
647
+ past_key_values: Optional[Cache] = None,
648
+ inputs_embeds: Optional[torch.FloatTensor] = None,
649
+ use_cache: Optional[bool] = None,
650
+ cache_position: Optional[torch.LongTensor] = None,
651
+ **kwargs: Unpack[TransformersKwargs],
652
+ ) -> MoeModelOutputWithPast:
653
+ if (input_ids is None) ^ (inputs_embeds is not None):
654
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
655
+
656
+ if use_cache and past_key_values is None:
657
+ past_key_values = DynamicCache(config=self.config)
658
+
659
+ if inputs_embeds is None:
660
+ inputs_embeds = self.embed_tokens(input_ids)
661
+
662
+ if cache_position is None:
663
+ past_seen_tokens = (
664
+ past_key_values.get_seq_length() if past_key_values is not None else 0
665
+ )
666
+ cache_position = torch.arange(
667
+ past_seen_tokens,
668
+ past_seen_tokens + inputs_embeds.shape[1],
669
+ device=inputs_embeds.device,
670
+ )
671
+ if position_ids is None:
672
+ position_ids = cache_position.unsqueeze(0)
673
+
674
+ causal_mask = create_causal_mask(
675
+ config=self.config,
676
+ input_embeds=inputs_embeds,
677
+ attention_mask=attention_mask,
678
+ cache_position=cache_position,
679
+ past_key_values=past_key_values,
680
+ position_ids=position_ids,
681
+ )
682
+
683
+ hidden_states = inputs_embeds
684
+
685
+ # create position embeddings to be shared across the decoder layers
686
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
687
+
688
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
689
+ hidden_states = decoder_layer(
690
+ hidden_states,
691
+ position_embeddings=position_embeddings,
692
+ attention_mask=causal_mask,
693
+ position_ids=position_ids,
694
+ past_key_values=past_key_values,
695
+ use_cache=use_cache,
696
+ cache_position=cache_position,
697
+ **kwargs,
698
+ )
699
+
700
+ hidden_states = self.norm(hidden_states)
701
+
702
+ return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
703
+ last_hidden_state=hidden_states,
704
+ past_key_values=past_key_values,
705
+ )
706
+
707
+
708
+ @dataclass
709
+ class MoeCausalLMOutputWithPast(ModelOutput):
710
+ """
711
+ Base class for causal language model (or autoregressive) with mixture of experts outputs.
712
+
713
+ Args:
714
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
715
+ Language modeling loss (for next-token prediction).
716
+
717
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
718
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
719
+
720
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
721
+ aux_loss for the sparse modules.
722
+
723
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
724
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
725
+
726
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
727
+ loss for Mixture of Experts models.
728
+
729
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
730
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
731
+
732
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
733
+ `past_key_values` input) to speed up sequential decoding.
734
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
735
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
736
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
737
+
738
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
739
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
740
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
741
+ sequence_length)`.
742
+
743
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
744
+ heads.
745
+ """
746
+
747
+ loss: Optional[torch.FloatTensor] = None
748
+ aux_loss: Optional[torch.FloatTensor] = None
749
+ lb_loss: Optional[torch.FloatTensor] = None
750
+ ce_loss: Optional[torch.FloatTensor] = None
751
+ logits: Optional[torch.FloatTensor] = None
752
+ past_key_values: Optional[Cache] = None
753
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
754
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
755
+ router_logits: Optional[tuple[torch.FloatTensor]] = None
756
+
757
+
758
+ def load_balancing_loss_func_olmoe(
759
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
760
+ num_experts: Optional[int] = None,
761
+ top_k=2,
762
+ attention_mask: Optional[torch.Tensor] = None,
763
+ labels: Optional[torch.Tensor] = None,
764
+ num_items_in_batch: Optional[
765
+ torch.Tensor
766
+ ] = None, # the number of tokens within a global batch (including across dp ranks)
767
+ ignore_index=-100,
768
+ num_shared_experts=0,
769
+ num_experts_per_layer: Optional[list[int]] = None,
770
+ num_shared_experts_per_layer: Optional[list[int]] = None,
771
+ always_active_experts: Optional[list[int]] = None,
772
+ always_active_experts_per_layer: Optional[list[list[int]]] = None,
773
+ ) -> Union[torch.Tensor, int]:
774
+ r"""
775
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
776
+
777
+ This version supports variable per-layer expert counts by computing the loss
778
+ per-layer individually and averaging across layers.
779
+
780
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
781
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
782
+ experts is too unbalanced.
783
+
784
+ Args:
785
+ gate_logits:
786
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
787
+ shape [batch_size X sequence_length, num_experts]. This has not been softmaxed yet.
788
+ Note: each layer may have a different num_experts if num_experts_per_layer is set.
789
+ num_experts:
790
+ Number of experts (used as fallback if num_experts_per_layer is None)
791
+ top_k:
792
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
793
+ parameter.
794
+ attention_mask (`torch.Tensor`, *optional*):
795
+ The attention_mask used in forward function
796
+ shape [batch_size X sequence_length] if not None.
797
+ num_experts_per_layer:
798
+ List of expert counts per layer. If None, uses num_experts for all layers.
799
+ num_shared_experts_per_layer:
800
+ List of shared expert counts per layer. If None, uses num_shared_experts for all layers.
801
+
802
+ Returns:
803
+ The auxiliary loss.
804
+ """
805
+ if gate_logits is None or not isinstance(gate_logits, tuple):
806
+ return 0
807
+
808
+ compute_device = gate_logits[0].device
809
+ num_hidden_layers = len(gate_logits)
810
+
811
+ # Resolve always_active_experts for the uniform path
812
+ if always_active_experts_per_layer is None and always_active_experts is not None:
813
+ always_active_experts_per_layer = [always_active_experts] * num_hidden_layers
814
+
815
+ # Check if we have variable expert counts
816
+ has_variable_experts = num_experts_per_layer is not None and len(set(num_experts_per_layer)) > 1
817
+
818
+ if not has_variable_experts:
819
+ # All layers have the same expert count - use the original stacking approach
820
+ concatenated_gate_logits = torch.stack(
821
+ [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
822
+ ) # shape: (num_hidden_layers, batch_size * sequence_length, num_experts)
823
+
824
+ # remove the shared experts from the gate logits since they are not used for routing in the loss function
825
+ if num_shared_experts > 0:
826
+ concatenated_gate_logits = concatenated_gate_logits[:, :, :-num_shared_experts]
827
+ # adjust the num_experts and top_k accordingly for the loss computation
828
+ num_experts = num_experts - num_shared_experts
829
+ top_k = top_k - num_shared_experts
830
+
831
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
832
+
833
+ # Exclude always-active experts from the LB loss by removing their
834
+ # columns entirely so that num_experts matches the last dimension.
835
+ if (
836
+ always_active_experts_per_layer is not None
837
+ and len(always_active_experts_per_layer[0]) > 0
838
+ ):
839
+ aa_experts = always_active_experts_per_layer[0] # uniform across layers in this path
840
+ routed_mask = torch.ones(num_experts, dtype=torch.bool, device=compute_device)
841
+ routed_mask[aa_experts] = False
842
+ routing_weights = routing_weights[:, :, routed_mask]
843
+ num_experts = num_experts - len(aa_experts)
844
+ top_k = top_k - len(aa_experts)
845
+
846
+ _, selected_experts = torch.topk(
847
+ routing_weights, top_k, dim=-1
848
+ ) # shape: (num_hidden_layers, batch_size * sequence_length, top_k)
849
+
850
+ expert_counts_onehot = torch.nn.functional.one_hot(
851
+ selected_experts, num_experts
852
+ ) # shape: (num_hidden_layers, batch_size * sequence_length, top_k, num_experts)
853
+
854
+ if attention_mask is None and labels is None:
855
+ # Compute the percentage of tokens routed to each experts
856
+ counts_per_expert = torch.mean(
857
+ expert_counts_onehot.float(), dim=(1, 2)
858
+ ) # shape: (num_hidden_layers, num_experts)
859
+
860
+ # Compute the average probability of routing to these experts
861
+ prob_per_expert = torch.mean(
862
+ routing_weights, dim=1
863
+ ) # shape: (num_hidden_layers, num_experts)
864
+ else:
865
+ # if there are labels, then we want to ignore the indices that are in the prompt as well (if there is any)
866
+ if labels is not None:
867
+ attention_mask = labels != ignore_index
868
+ batch_size, sequence_length = attention_mask.shape
869
+
870
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
871
+ expert_attention_mask = (
872
+ attention_mask[None, :, :, None, None]
873
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
874
+ .reshape(num_hidden_layers, -1, top_k, num_experts)
875
+ .to(compute_device)
876
+ )
877
+
878
+ # Compute the percentage of tokens routed to each experts
879
+ counts_per_expert = torch.sum(
880
+ expert_counts_onehot.float() * expert_attention_mask, dim=(1, 2)
881
+ )
882
+
883
+ # Compute the mask that masks all padding tokens as 0 with the same shape of frequency_per_expert
884
+ router_per_expert_attention_mask = (
885
+ attention_mask[None, :, :, None]
886
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
887
+ .reshape(num_hidden_layers, -1, num_experts)
888
+ .to(compute_device)
889
+ )
890
+
891
+ # average the probability across valid tokens
892
+ prob_per_expert = torch.sum(
893
+ routing_weights * router_per_expert_attention_mask, dim=1
894
+ ) / torch.sum(
895
+ attention_mask
896
+ ) # shape: (num_hidden_layers, num_experts)
897
+
898
+ overall_loss = torch.sum(counts_per_expert * prob_per_expert)
899
+
900
+ # Fallback when num_items_in_batch isn't provided (e.g., manual forward calls)
901
+ if num_items_in_batch is None:
902
+ if labels is not None:
903
+ num_items_in_batch = (labels != ignore_index).sum()
904
+ elif attention_mask is not None:
905
+ num_items_in_batch = attention_mask.sum()
906
+ else:
907
+ # fall back to total tokens in batch/seq from gate logits
908
+ num_items_in_batch = gate_logits[0].shape[0]
909
+
910
+ if torch.is_tensor(num_items_in_batch):
911
+ num_items_in_batch = num_items_in_batch.to(compute_device)
912
+
913
+ # we follow olmo-core and use counts for dot product instead of frequency, and divide by total number token across gradient accumulation steps
914
+ overall_loss = overall_loss / (num_items_in_batch * top_k)
915
+
916
+ overall_loss = (
917
+ overall_loss * num_experts / num_hidden_layers
918
+ ) # times num_experts according to lb equation, divide by num_hidden_layers to get average over layers
919
+
920
+ return overall_loss
921
+
922
+ else:
923
+ # Variable expert counts - compute loss per layer and average
924
+ if num_shared_experts_per_layer is None:
925
+ num_shared_experts_per_layer = [num_shared_experts] * num_hidden_layers
926
+
927
+ # Compute attention mask once
928
+ if labels is not None:
929
+ attention_mask = labels != ignore_index
930
+
931
+ if attention_mask is not None:
932
+ batch_size, sequence_length = attention_mask.shape
933
+
934
+ # Fallback when num_items_in_batch isn't provided
935
+ if num_items_in_batch is None:
936
+ if labels is not None:
937
+ num_items_in_batch = (labels != ignore_index).sum()
938
+ elif attention_mask is not None:
939
+ num_items_in_batch = attention_mask.sum()
940
+ else:
941
+ num_items_in_batch = gate_logits[0].shape[0]
942
+
943
+ if torch.is_tensor(num_items_in_batch):
944
+ num_items_in_batch = num_items_in_batch.to(compute_device)
945
+
946
+ layer_losses = []
947
+
948
+ for layer_idx, layer_gate in enumerate(gate_logits):
949
+ layer_gate = layer_gate.to(compute_device)
950
+ layer_num_experts = num_experts_per_layer[layer_idx]
951
+ layer_num_shared = num_shared_experts_per_layer[layer_idx]
952
+
953
+ # Remove shared experts from logits
954
+ if layer_num_shared > 0:
955
+ layer_gate = layer_gate[:, :-layer_num_shared]
956
+ effective_num_experts = layer_num_experts - layer_num_shared
957
+ effective_top_k = top_k - layer_num_shared
958
+ else:
959
+ effective_num_experts = layer_num_experts
960
+ effective_top_k = top_k
961
+
962
+ # Compute routing weights
963
+ routing_weights = torch.nn.functional.softmax(layer_gate, dim=-1)
964
+
965
+ # Exclude always-active experts from the LB loss by removing their columns
966
+ layer_aa = (
967
+ always_active_experts_per_layer[layer_idx]
968
+ if always_active_experts_per_layer is not None
969
+ else None
970
+ )
971
+ if layer_aa is not None and len(layer_aa) > 0:
972
+ routed_mask = torch.ones(
973
+ effective_num_experts, dtype=torch.bool, device=compute_device
974
+ )
975
+ routed_mask[layer_aa] = False
976
+ routing_weights = routing_weights[:, routed_mask]
977
+ effective_num_experts = effective_num_experts - len(layer_aa)
978
+ effective_top_k = effective_top_k - len(layer_aa)
979
+
980
+ _, selected_experts = torch.topk(
981
+ routing_weights, effective_top_k, dim=-1
982
+ ) # shape: (batch_size * sequence_length, top_k)
983
+
984
+ expert_counts_onehot = torch.nn.functional.one_hot(
985
+ selected_experts, effective_num_experts
986
+ ) # shape: (batch_size * sequence_length, top_k, num_experts)
987
+
988
+ if attention_mask is None:
989
+ counts_per_expert = torch.mean(
990
+ expert_counts_onehot.float(), dim=(0, 1)
991
+ ) # shape: (num_experts,)
992
+ prob_per_expert = torch.mean(routing_weights, dim=0) # shape: (num_experts,)
993
+ else:
994
+ # Reshape for masking
995
+ expert_attention_mask = (
996
+ attention_mask[:, :, None, None]
997
+ .expand((batch_size, sequence_length, effective_top_k, effective_num_experts))
998
+ .reshape(-1, effective_top_k, effective_num_experts)
999
+ .to(compute_device)
1000
+ )
1001
+
1002
+ counts_per_expert = torch.sum(
1003
+ expert_counts_onehot.float() * expert_attention_mask, dim=(0, 1)
1004
+ )
1005
+
1006
+ router_attention_mask = (
1007
+ attention_mask[:, :, None]
1008
+ .expand((batch_size, sequence_length, effective_num_experts))
1009
+ .reshape(-1, effective_num_experts)
1010
+ .to(compute_device)
1011
+ )
1012
+
1013
+ prob_per_expert = torch.sum(
1014
+ routing_weights * router_attention_mask, dim=0
1015
+ ) / torch.sum(attention_mask)
1016
+
1017
+ layer_loss = torch.sum(counts_per_expert * prob_per_expert)
1018
+ layer_loss = layer_loss / (num_items_in_batch * effective_top_k)
1019
+ layer_loss = layer_loss * effective_num_experts
1020
+
1021
+ layer_losses.append(layer_loss)
1022
+
1023
+ # Average across layers
1024
+ overall_loss = torch.stack(layer_losses).mean()
1025
+
1026
+ return overall_loss
1027
+
1028
+
1029
+ class EmoForCausalLM(EmoPreTrainedModel, GenerationMixin):
1030
+ _tied_weights_keys = ["lm_head.weight"]
1031
+
1032
+ def __init__(self, config):
1033
+ super().__init__(config)
1034
+ self.model = EmoModel(config)
1035
+ self.vocab_size = config.vocab_size
1036
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1037
+
1038
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1039
+ self.num_experts = config.num_experts
1040
+ self.num_experts_per_tok = config.num_experts_per_tok
1041
+ # Initialize weights and apply final processing
1042
+ self.post_init()
1043
+
1044
+ @auto_docstring
1045
+ def forward(
1046
+ self,
1047
+ input_ids: Optional[torch.LongTensor] = None,
1048
+ attention_mask: Optional[torch.Tensor] = None,
1049
+ position_ids: Optional[torch.LongTensor] = None,
1050
+ past_key_values: Optional[Cache] = None,
1051
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1052
+ labels: Optional[torch.LongTensor] = None,
1053
+ use_cache: Optional[bool] = None,
1054
+ output_attentions: Optional[bool] = None,
1055
+ output_hidden_states: Optional[bool] = None,
1056
+ output_router_logits: Optional[bool] = None,
1057
+ return_dict: Optional[bool] = None,
1058
+ cache_position: Optional[torch.LongTensor] = None,
1059
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1060
+ **kwargs,
1061
+ ) -> Union[tuple, MoeCausalLMOutputWithPast]:
1062
+ r"""
1063
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1064
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1065
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1066
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1067
+
1068
+ Example:
1069
+
1070
+ ```python
1071
+ >>> from transformers import AutoTokenizer, EmoForCausalLM
1072
+
1073
+ >>> model = EmoForCausalLM.from_pretrained("allenai/Emo-1B-7B-0924")
1074
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Emo-1B-7B-0924")
1075
+
1076
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1077
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1078
+
1079
+ >>> # Generate
1080
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1081
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1082
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
1083
+ ```
1084
+ """
1085
+ output_attentions = (
1086
+ output_attentions if output_attentions is not None else self.config.output_attentions
1087
+ )
1088
+ output_router_logits = (
1089
+ output_router_logits
1090
+ if output_router_logits is not None
1091
+ else self.config.output_router_logits
1092
+ )
1093
+ output_hidden_states = (
1094
+ output_hidden_states
1095
+ if output_hidden_states is not None
1096
+ else self.config.output_hidden_states
1097
+ )
1098
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1099
+
1100
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1101
+ outputs = self.model(
1102
+ input_ids=input_ids,
1103
+ attention_mask=attention_mask,
1104
+ position_ids=position_ids,
1105
+ past_key_values=past_key_values,
1106
+ inputs_embeds=inputs_embeds,
1107
+ use_cache=use_cache,
1108
+ output_attentions=output_attentions,
1109
+ output_hidden_states=output_hidden_states,
1110
+ output_router_logits=output_router_logits,
1111
+ return_dict=return_dict,
1112
+ cache_position=cache_position,
1113
+ )
1114
+
1115
+ hidden_states = outputs[0]
1116
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1117
+ slice_indices = (
1118
+ slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1119
+ )
1120
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1121
+
1122
+ loss = None
1123
+ ce_loss = None
1124
+ if labels is not None:
1125
+ ce_loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
1126
+ loss = ce_loss
1127
+
1128
+ lb_loss = None
1129
+
1130
+ if output_router_logits:
1131
+ # Get per-layer expert counts if available
1132
+ num_experts_per_layer = getattr(self.config, "num_experts_per_layer", None)
1133
+ num_shared_experts_per_layer = getattr(
1134
+ self.config, "num_shared_experts_per_layer", None
1135
+ )
1136
+
1137
+ # Filter out dense layers (num_experts == 0) since they produce no router_logits
1138
+ if num_experts_per_layer is not None:
1139
+ moe_mask = [i for i, n in enumerate(num_experts_per_layer) if n > 0]
1140
+ num_experts_per_layer = [num_experts_per_layer[i] for i in moe_mask]
1141
+ if num_shared_experts_per_layer is not None:
1142
+ num_shared_experts_per_layer = [
1143
+ num_shared_experts_per_layer[i] for i in moe_mask
1144
+ ]
1145
+
1146
+ # Resolve always_active_experts for LB loss
1147
+ always_active_experts_per_layer_for_loss = getattr(
1148
+ self.config, "always_active_experts_per_layer", None
1149
+ )
1150
+ always_active_experts_for_loss = getattr(self.config, "always_active_experts", None)
1151
+ # Filter out dense layers if needed
1152
+ if (
1153
+ num_experts_per_layer is not None
1154
+ and always_active_experts_per_layer_for_loss is not None
1155
+ ):
1156
+ always_active_experts_per_layer_for_loss = [
1157
+ always_active_experts_per_layer_for_loss[i] for i in moe_mask
1158
+ ]
1159
+
1160
+ lb_loss = load_balancing_loss_func_olmoe(
1161
+ outputs.router_logits if return_dict else outputs[-1],
1162
+ self.num_experts,
1163
+ self.num_experts_per_tok,
1164
+ attention_mask,
1165
+ labels,
1166
+ num_shared_experts=self.config.num_shared_experts,
1167
+ num_experts_per_layer=num_experts_per_layer,
1168
+ num_shared_experts_per_layer=num_shared_experts_per_layer,
1169
+ always_active_experts=always_active_experts_for_loss,
1170
+ always_active_experts_per_layer=always_active_experts_per_layer_for_loss,
1171
+ **kwargs,
1172
+ )
1173
+ if labels is not None:
1174
+ loss += self.router_aux_loss_coef * lb_loss.to(
1175
+ loss.device
1176
+ ) # make sure to reside in the same device
1177
+
1178
+ if not return_dict:
1179
+ output = (logits,) + outputs[1:]
1180
+ if output_router_logits:
1181
+ output = (lb_loss,) + output
1182
+ return (loss,) + output if loss is not None else output
1183
+
1184
+ return MoeCausalLMOutputWithPast(
1185
+ loss=loss,
1186
+ aux_loss=lb_loss,
1187
+ lb_loss=lb_loss.detach().clone()
1188
+ if lb_loss is not None
1189
+ else None, # for logging callback
1190
+ ce_loss=ce_loss.detach().clone()
1191
+ if ce_loss is not None
1192
+ else None, # for logging callback
1193
+ logits=logits,
1194
+ past_key_values=outputs.past_key_values,
1195
+ hidden_states=outputs.hidden_states,
1196
+ attentions=outputs.attentions,
1197
+ router_logits=outputs.router_logits,
1198
+ )
1199
+
1200
+
1201
+ __all__ = [
1202
+ "EmoForCausalLM",
1203
+ "EmoModel",
1204
+ "EmoPreTrainedModel",
1205
+ ]