joerowell committed
Commit 825ca3a · verified · 1 Parent(s): 94107a2

Sync bundled HF code with upstream Laguna PR (v5 schema)

Files changed (1):
  1. modeling_laguna.py +224 -177

modeling_laguna.py CHANGED
@@ -1,5 +1,4 @@
1
- # ruff: noqa
2
- # Copyright 2025 Poolside and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -13,37 +12,34 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- from typing import Optional
17
  from collections.abc import Callable
 
18
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch import nn
 
22
  from transformers import initialization as init
23
- from transformers.utils import auto_docstring, can_return_tuple, is_grouped_mm_available
24
- from transformers.generation import GenerationMixin
25
  from transformers.activations import ACT2FN
26
  from transformers.cache_utils import Cache, DynamicCache
27
- from transformers.integrations import (
28
- use_kernelized_func,
29
- use_kernel_func_from_hub,
30
- use_kernel_forward_from_hub,
31
- )
32
- from transformers.masking_utils import create_causal_mask
33
- from transformers.utils.generic import OutputRecorder, TransformersKwargs, maybe_autocast, check_model_inputs
34
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
- from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
37
- from transformers.processing_utils import Unpack
38
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
40
-
 
 
 
41
  from .configuration_laguna import LagunaConfig
42
 
43
 
44
  @use_kernel_forward_from_hub("RMSNorm")
45
  class LagunaRMSNorm(nn.Module):
46
- def __init__(self, hidden_size, eps=1e-6):
47
  """
48
  LagunaRMSNorm is equivalent to T5LayerNorm
49
  """
@@ -51,7 +47,7 @@ class LagunaRMSNorm(nn.Module):
51
  self.weight = nn.Parameter(torch.ones(hidden_size))
52
  self.variance_epsilon = eps
53
 
54
- def forward(self, hidden_states):
55
  input_dtype = hidden_states.dtype
56
  hidden_states = hidden_states.to(torch.float32)
57
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -65,27 +61,35 @@ class LagunaRMSNorm(nn.Module):
65
  class LagunaRotaryEmbedding(nn.Module):
66
  inv_freq: torch.Tensor # fix linting for `register_buffer`
67
 
68
- def __init__(self, config: LagunaConfig, device=None):
69
  super().__init__()
70
  self.max_seq_len_cached = config.max_position_embeddings
71
  self.original_max_seq_len = config.max_position_embeddings
72
 
73
  self.config = config
74
 
75
- self.rope_type = self.config.rope_parameters["rope_type"]
76
- rope_init_fn: Callable = self.compute_default_rope_parameters
77
- if self.rope_type != "default":
78
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
79
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
80
 
81
- self.register_buffer("inv_freq", inv_freq, persistent=False)
82
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
 
 
 
 
 
83
 
84
  @staticmethod
85
  def compute_default_rope_parameters(
86
  config: LagunaConfig | None = None,
87
  device: Optional["torch.device"] = None,
88
  seq_len: int | None = None,
 
89
  ) -> tuple["torch.Tensor", float]:
90
  """
91
  Computes the inverse frequencies according to the original RoPE implementation
@@ -96,14 +100,18 @@ class LagunaRotaryEmbedding(nn.Module):
96
  The device to use for initialization of the inverse frequencies.
97
  seq_len (`int`, *optional*):
98
  The current sequence length. Unused for this type of RoPE.
99
-
100
- Returns
101
- -------
 
102
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
103
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
104
  """
105
- base = config.rope_parameters["rope_theta"]
106
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
 
 
 
107
 
108
  attention_factor = 1.0 # Unused in this type of RoPE
109
 
@@ -115,16 +123,19 @@ class LagunaRotaryEmbedding(nn.Module):
115
 
116
  @torch.no_grad()
117
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
118
- def forward(self, x, position_ids):
119
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 
 
 
120
  position_ids_expanded = position_ids[:, None, :].float()
121
 
122
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
123
  with maybe_autocast(device_type=device_type, enabled=False): # Force float32
124
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
125
  emb = torch.cat((freqs, freqs), dim=-1)
126
- cos = emb.cos() * self.attention_scaling
127
- sin = emb.sin() * self.attention_scaling
128
 
129
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
130
 
@@ -146,71 +157,97 @@ class LagunaMLP(nn.Module):
146
 
147
 
148
  class LagunaTopKRouter(nn.Module):
149
- """Laguna MoE router using sigmoid scoring (not softmax)."""
150
-
151
  def __init__(self, config):
152
  super().__init__()
153
  self.top_k = config.num_experts_per_tok
154
  self.num_experts = config.num_experts
155
- self.norm_topk_prob = config.norm_topk_prob
156
  self.hidden_dim = config.hidden_size
157
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
 
 
158
 
159
- def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
160
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
161
- router_logits = F.linear(hidden_states, self.weight)
162
- # Laguna-specific: sigmoid routing in float32 for precision
163
- routing_weights = torch.sigmoid(router_logits.float())
164
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
165
- if self.norm_topk_prob:
166
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
 
 
 
 
 
167
  routing_weights = routing_weights.to(hidden_states.dtype)
 
168
  return router_logits, routing_weights, selected_experts
169
 
170
 
171
- class LagunaSparseMoeBlock(nn.Module):
172
- """Laguna MoE block using sigmoid router, per-expert MLPs, and a shared expert."""
 
173
 
174
  def __init__(self, config):
175
  super().__init__()
176
  self.num_experts = config.num_experts
177
- self.top_k = config.num_experts_per_tok
 
178
  self.gate = LagunaTopKRouter(config)
179
- self.experts = nn.ModuleList(
180
- [LagunaMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
181
- )
182
- self.shared_expert = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
183
- self.shared_expert_gate = (
184
- nn.Linear(config.hidden_size, 1, bias=False) if getattr(config, "moe_shared_gate", False) else None
185
- )
186
 
187
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
188
  batch_size, sequence_length, hidden_dim = hidden_states.shape
189
  hidden_states = hidden_states.view(-1, hidden_dim)
 
190
 
191
- shared_expert_output = self.shared_expert(hidden_states)
192
- if self.shared_expert_gate is not None:
193
- shared_expert_output = shared_expert_output * torch.sigmoid(self.shared_expert_gate(hidden_states))
194
-
195
- # Routed experts
196
  _, routing_weights, selected_experts = self.gate(hidden_states)
197
- final_hidden_states = torch.zeros_like(hidden_states)
198
-
199
- expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
200
- expert_mask = expert_mask.permute(2, 1, 0)
201
-
202
- for expert_idx in range(self.num_experts):
203
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
204
- if token_idx.shape[0] == 0:
205
- continue
206
- current_state = hidden_states[token_idx]
207
- current_hidden_states = self.experts[expert_idx](current_state)
208
- current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None]
209
- final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
210
 
211
- final_hidden_states = final_hidden_states + shared_expert_output
212
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
213
- return final_hidden_states
214
 
215
 
216
  def rotate_half(x):
@@ -220,10 +257,12 @@ def rotate_half(x):
220
  return torch.cat((-x2, x1), dim=-1)
221
 
222
 
223
- @use_kernel_func_from_hub("rotary_pos_emb")
224
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
225
  """Applies Rotary Position Embedding to the query and key tensors.
226
 
 
 
227
  Args:
228
  q (`torch.Tensor`): The query tensor.
229
  k (`torch.Tensor`): The key tensor.
@@ -236,15 +275,24 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
236
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
237
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
238
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
239
-
240
- Returns
241
- -------
242
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
243
  """
244
  cos = cos.unsqueeze(unsqueeze_dim)
245
  sin = sin.unsqueeze(unsqueeze_dim)
246
- q_embed = (q * cos) + (rotate_half(q) * sin)
247
- k_embed = (k * cos) + (rotate_half(k) * sin)
248
  return q_embed, k_embed
249
 
250
 
@@ -275,8 +323,7 @@ def eager_attention_forward(
275
 
276
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
277
  if attention_mask is not None:
278
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
279
- attn_weights = attn_weights + causal_mask
280
 
281
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
282
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -286,33 +333,39 @@ def eager_attention_forward(
286
  return attn_output, attn_weights
287
 
288
 
289
- # Laguna attention is identical to Qwen2MoE attention except:
290
- # - No QKV bias
291
- # - Explicit head_dim from config
292
- # - Output gating: attn_output = attn_output * softplus(g_proj(hidden_states))
293
- # - No sliding window (full attention only)
294
  @use_kernelized_func(apply_rotary_pos_emb)
295
  class LagunaAttention(nn.Module):
296
- def __init__(self, config: LagunaConfig, layer_idx: int):
 
 
297
  super().__init__()
 
 
298
  self.config = config
299
  self.layer_idx = layer_idx
300
- self.head_dim = config.head_dim
301
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
302
  self.scaling = self.head_dim**-0.5
303
  self.attention_dropout = config.attention_dropout
304
  self.is_causal = True
305
 
306
- # Laguna: no QKV bias, explicit head_dim
307
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
308
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
309
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False)
310
- self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False)
311
- # Laguna-specific: gating projection
312
- self.g_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False)
313
- # QK normalization (RMSNorm applied per-head after reshape, before RoPE)
314
- self.q_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
315
- self.k_norm = LagunaRMSNorm(config.head_dim, eps=config.rms_norm_eps)
316
 
317
  def forward(
318
  self,
@@ -320,36 +373,28 @@ class LagunaAttention(nn.Module):
320
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
321
  attention_mask: torch.Tensor | None,
322
  past_key_values: Cache | None = None,
323
- cache_position: torch.LongTensor | None = None,
324
  **kwargs: Unpack[FlashAttentionKwargs],
325
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
326
  input_shape = hidden_states.shape[:-1]
327
  hidden_shape = (*input_shape, -1, self.head_dim)
328
 
329
- query_states = self.q_proj(hidden_states)
330
- key_states = self.k_proj(hidden_states)
331
- value_states = self.v_proj(hidden_states)
332
-
333
- query_states = query_states.view(hidden_shape).transpose(1, 2)
334
- key_states = key_states.view(hidden_shape).transpose(1, 2)
335
- value_states = value_states.view(hidden_shape).transpose(1, 2)
336
 
337
- # QK normalization (applied per-head before RoPE)
338
- query_states = self.q_norm(query_states)
339
- key_states = self.k_norm(key_states)
340
 
341
  cos, sin = position_embeddings
342
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
343
 
344
  if past_key_values is not None:
345
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
346
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
347
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
348
-
349
- attention_interface: Callable = eager_attention_forward
350
- if self.config._attn_implementation != "eager":
351
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
352
 
 
 
 
353
  attn_output, attn_weights = attention_interface(
354
  self,
355
  query_states,
@@ -358,37 +403,30 @@ class LagunaAttention(nn.Module):
358
  attention_mask,
359
  dropout=0.0 if not self.training else self.attention_dropout,
360
  scaling=self.scaling,
 
361
  **kwargs,
362
  )
363
 
364
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
365
 
366
- # Laguna-specific: apply gating BEFORE o_proj
367
- # gate values are computed from original hidden_states, applied in attention dimension
368
  gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
369
- attn_output = attn_output * gate
370
 
371
  attn_output = self.o_proj(attn_output)
372
-
373
  return attn_output, attn_weights
374
 
375
 
376
  class LagunaDecoderLayer(GradientCheckpointingLayer):
377
- """Laguna decoder layer with gated attention and sigmoid-routed MoE."""
378
-
379
  def __init__(self, config: LagunaConfig, layer_idx: int):
380
  super().__init__()
381
- self.self_attn = LagunaAttention(config, layer_idx)
382
- # Use MoE or dense MLP based on layer configuration
383
- if (layer_idx not in config.mlp_only_layers) and (
384
- config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
385
- ):
386
  self.mlp = LagunaSparseMoeBlock(config)
387
  else:
388
  self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
389
  self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
390
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
391
- self.hidden_size = config.hidden_size
392
 
393
  def forward(
394
  self,
@@ -397,7 +435,6 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
397
  position_ids: torch.LongTensor | None = None,
398
  past_key_values: Cache | None = None,
399
  use_cache: bool | None = False,
400
- cache_position: torch.LongTensor | None = None,
401
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
402
  **kwargs: Unpack[TransformersKwargs],
403
  ) -> torch.Tensor:
@@ -410,7 +447,6 @@ class LagunaDecoderLayer(GradientCheckpointingLayer):
410
  position_ids=position_ids,
411
  past_key_values=past_key_values,
412
  use_cache=use_cache,
413
- cache_position=cache_position,
414
  position_embeddings=position_embeddings,
415
  **kwargs,
416
  )
@@ -434,9 +470,8 @@ class LagunaPreTrainedModel(PreTrainedModel):
434
  _supports_flash_attn = True
435
  _supports_sdpa = True
436
  _supports_flex_attn = True
437
- _can_compile_fullgraph = (
438
- is_grouped_mm_available()
439
- ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
440
  _supports_attention_backend = True
441
  _can_record_outputs = {
442
  "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
@@ -448,10 +483,24 @@ class LagunaPreTrainedModel(PreTrainedModel):
448
  def _init_weights(self, module):
449
  super()._init_weights(module)
450
  std = self.config.initializer_range
451
- if isinstance(module, LagunaTopKRouter):
 
 
 
452
  init.normal_(module.weight, mean=0.0, std=std)
453
 
454
 
 
455
  class LagunaModel(LagunaPreTrainedModel):
456
  def __init__(self, config: LagunaConfig):
457
  super().__init__(config)
@@ -469,7 +518,8 @@ class LagunaModel(LagunaPreTrainedModel):
469
  # Initialize weights and apply final processing
470
  self.post_init()
471
 
472
- @check_model_inputs
 
473
  def forward(
474
  self,
475
  input_ids: torch.LongTensor | None = None,
@@ -478,49 +528,50 @@ class LagunaModel(LagunaPreTrainedModel):
478
  past_key_values: Cache | None = None,
479
  inputs_embeds: torch.FloatTensor | None = None,
480
  use_cache: bool | None = None,
481
- cache_position: torch.LongTensor | None = None,
482
  **kwargs: Unpack[TransformersKwargs],
483
- ):
484
  if (input_ids is None) ^ (inputs_embeds is not None):
485
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
486
 
487
- if use_cache and past_key_values is None:
488
- past_key_values = DynamicCache(config=self.config)
489
-
490
  if inputs_embeds is None:
491
  inputs_embeds = self.embed_tokens(input_ids)
492
 
493
- if cache_position is None:
494
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
495
- cache_position = torch.arange(
496
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
497
- )
498
 
499
  if position_ids is None:
500
- position_ids = cache_position.unsqueeze(0)
501
-
502
- # Laguna uses full attention only (no sliding window)
503
- causal_mask = create_causal_mask(
504
- config=self.config,
505
- input_embeds=inputs_embeds,
506
- attention_mask=attention_mask,
507
- cache_position=cache_position,
508
- past_key_values=past_key_values,
509
- position_ids=position_ids,
510
- )
511
 
512
  hidden_states = inputs_embeds
513
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
 
514
 
515
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
516
  hidden_states = decoder_layer(
517
  hidden_states,
518
- attention_mask=causal_mask,
 
519
  position_ids=position_ids,
520
  past_key_values=past_key_values,
521
- use_cache=use_cache,
522
- cache_position=cache_position,
523
- position_embeddings=position_embeddings,
524
  **kwargs,
525
  )
526
 
@@ -528,7 +579,7 @@ class LagunaModel(LagunaPreTrainedModel):
528
 
529
  return MoeModelOutputWithPast(
530
  last_hidden_state=hidden_states,
531
- past_key_values=past_key_values,
532
  )
533
 
534
 
@@ -558,8 +609,7 @@ def load_balancing_loss_func(
558
  The attention_mask used in forward function
559
  shape [batch_size X sequence_length] if not None.
560
 
561
- Returns
562
- -------
563
  The auxiliary loss.
564
  """
565
  if gate_logits is None or not isinstance(gate_logits, tuple):
@@ -618,7 +668,7 @@ def load_balancing_loss_func(
618
  @auto_docstring
619
  class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
620
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
621
- _tp_plan = {"lm_head": "colwise_rep"}
622
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
623
 
624
  def __init__(self, config):
@@ -645,17 +695,15 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
645
  labels: torch.LongTensor | None = None,
646
  use_cache: bool | None = None,
647
  output_router_logits: bool | None = None,
648
- cache_position: torch.LongTensor | None = None,
649
  logits_to_keep: int | torch.Tensor = 0,
650
  **kwargs: Unpack[TransformersKwargs],
651
  ) -> MoeCausalLMOutputWithPast:
652
  r"""
653
- Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
654
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
655
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
656
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
657
  """
658
- # TODO (Joe) add example here after we got rid of the stale mistral example
659
 
660
  output_router_logits = (
661
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -670,7 +718,6 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
670
  inputs_embeds=inputs_embeds,
671
  use_cache=use_cache,
672
  output_router_logits=output_router_logits,
673
- cache_position=cache_position,
674
  **kwargs,
675
  )
676
 
@@ -691,8 +738,8 @@ class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
691
  self.num_experts_per_tok,
692
  attention_mask,
693
  )
694
- if labels is not None and isinstance(aux_loss, torch.Tensor):
695
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
696
 
697
  return MoeCausalLMOutputWithPast(
698
  loss=loss,
 
1
+ # Copyright 2026 Poolside and the HuggingFace Inc. team. All rights reserved.
 
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  from collections.abc import Callable
16
+ from typing import Optional
17
 
18
  import torch
19
  import torch.nn.functional as F
20
  from torch import nn
21
+
22
  from transformers import initialization as init
 
 
23
  from transformers.activations import ACT2FN
24
  from transformers.cache_utils import Cache, DynamicCache
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
27
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
28
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 
 
 
 
29
  from transformers.modeling_layers import GradientCheckpointingLayer
30
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 
31
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
33
+ from transformers.processing_utils import Unpack
34
+ from transformers.utils import auto_docstring, can_return_tuple
35
+ from transformers.utils.generic import TransformersKwargs, maybe_autocast
36
+ from transformers.utils.output_capturing import OutputRecorder, capture_outputs
37
  from .configuration_laguna import LagunaConfig
38
 
39
 
40
  @use_kernel_forward_from_hub("RMSNorm")
41
  class LagunaRMSNorm(nn.Module):
42
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
43
  """
44
  LagunaRMSNorm is equivalent to T5LayerNorm
45
  """
 
47
  self.weight = nn.Parameter(torch.ones(hidden_size))
48
  self.variance_epsilon = eps
49
 
50
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
51
  input_dtype = hidden_states.dtype
52
  hidden_states = hidden_states.to(torch.float32)
53
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
 
61
  class LagunaRotaryEmbedding(nn.Module):
62
  inv_freq: torch.Tensor # fix linting for `register_buffer`
63
 
64
+ def __init__(self, config: LagunaConfig, device=None, layer_type=None):
65
  super().__init__()
66
  self.max_seq_len_cached = config.max_position_embeddings
67
  self.original_max_seq_len = config.max_position_embeddings
68
 
69
  self.config = config
70
 
71
+ self.layer_types = list(set(config.layer_types))
72
+ self.rope_type = {}
73
+ for layer_type in self.layer_types:
74
+ rope_params = self.config.rope_parameters[layer_type]
75
+ if rope_params is None:
76
+ continue
77
 
78
+ self.rope_type[layer_type] = rope_params["rope_type"]
79
+ rope_init_fn: Callable = self.compute_default_rope_parameters
80
+ if self.rope_type[layer_type] != "default":
81
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
82
+ curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
83
+ self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
84
+ self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
85
+ setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
86
 
87
  @staticmethod
88
  def compute_default_rope_parameters(
89
  config: LagunaConfig | None = None,
90
  device: Optional["torch.device"] = None,
91
  seq_len: int | None = None,
92
+ layer_type: str | None = None,
93
  ) -> tuple["torch.Tensor", float]:
94
  """
95
  Computes the inverse frequencies according to the original RoPE implementation
 
100
  The device to use for initialization of the inverse frequencies.
101
  seq_len (`int`, *optional*):
102
  The current sequence length. Unused for this type of RoPE.
103
+ layer_type (`str`, *optional*):
104
+ The current layer type if the model has different RoPE parameters per type.
105
+ Should not be used unless `config.layer_types` is not None.
106
+ Returns:
107
  Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
108
  post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
109
  """
110
+ base = config.rope_parameters[layer_type]["rope_theta"]
111
+ # key difference to gemma3: partial rope
112
+ partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0)
113
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
114
+ dim = int(head_dim * partial_rotary_factor)
115
 
116
  attention_factor = 1.0 # Unused in this type of RoPE
117
 
 
123
 
124
  @torch.no_grad()
125
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
126
+ def forward(self, x, position_ids, layer_type=None):
127
+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
128
+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
129
+
130
+ inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
131
  position_ids_expanded = position_ids[:, None, :].float()
132
 
133
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
134
  with maybe_autocast(device_type=device_type, enabled=False): # Force float32
135
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
136
  emb = torch.cat((freqs, freqs), dim=-1)
137
+ cos = emb.cos() * attention_scaling
138
+ sin = emb.sin() * attention_scaling
139
 
140
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
141
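For reference, a minimal sketch of how the per-layer-type parameters above turn into cos/sin tables when a partial rotary factor is configured; `rope_theta`, `head_dim`, and `partial_rotary_factor` are assumed toy values, and the formula mirrors `compute_default_rope_parameters`:

import torch

rope_theta = 10000.0
head_dim = 128
partial_rotary_factor = 0.5                      # assumed: only half of each head is rotated
dim = int(head_dim * partial_rotary_factor)      # rotary_dim = 64

# inv_freq[i] = 1 / theta^(2i / dim), as in compute_default_rope_parameters
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

position_ids = torch.arange(16, dtype=torch.float32)     # one sequence of 16 positions
freqs = torch.outer(position_ids, inv_freq)              # (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                  # (seq_len, dim), i.e. 64 not 128
cos, sin = emb.cos(), emb.sin()                          # q/k channels beyond dim pass through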
 
 
157
 
158
 
159
  class LagunaTopKRouter(nn.Module):
 
 
160
  def __init__(self, config):
161
  super().__init__()
162
  self.top_k = config.num_experts_per_tok
163
  self.num_experts = config.num_experts
 
164
  self.hidden_dim = config.hidden_size
165
  self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
166
+ self.e_score_correction_bias = nn.Parameter(torch.zeros(config.num_experts), requires_grad=False)
167
+ self.router_logit_softcapping = config.moe_router_logit_softcapping
168
 
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
173
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
174
+ router_logits = F.linear(hidden_states, self.weight).float()
175
+ # Optional logits softcapping
176
+ if self.router_logit_softcapping > 0.0:
177
+ router_logits = torch.tanh(router_logits / self.router_logit_softcapping) * self.router_logit_softcapping
178
+ # Sigmoid instead of softmax normalization
179
+ routing_scores = torch.sigmoid(router_logits)
180
+
181
+ scores_for_selection = routing_scores + self.e_score_correction_bias.to(routing_scores.dtype)
182
+ _, selected_experts = torch.topk(scores_for_selection, self.top_k, dim=-1)
183
+ routing_weights = routing_scores.gather(-1, selected_experts)
184
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
185
  routing_weights = routing_weights.to(hidden_states.dtype)
186
+
187
  return router_logits, routing_weights, selected_experts
188
 
189
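For reference, a minimal sketch of the routing math above on toy shapes; `num_experts`, `top_k`, the softcap value, and the random weights are assumptions for illustration only:

import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim, softcap = 8, 2, 16, 30.0
hidden = torch.randn(4, hidden_dim)                      # 4 tokens
weight = torch.randn(num_experts, hidden_dim)            # router weight
bias = torch.zeros(num_experts)                          # e_score_correction_bias

logits = F.linear(hidden, weight).float()
logits = torch.tanh(logits / softcap) * softcap          # optional softcapping
scores = torch.sigmoid(logits)                           # sigmoid, not softmax

_, selected = torch.topk(scores + bias, top_k, dim=-1)   # bias only influences selection
weights = scores.gather(-1, selected)                    # weights come from the raw scores
weights = weights / weights.sum(dim=-1, keepdim=True)    # renormalize over the kept experts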
 
190
+ @use_experts_implementation
191
+ class LagunaExperts(nn.Module):
192
+ """Collection of expert weights stored as 3D tensors."""
193
 
194
  def __init__(self, config):
195
  super().__init__()
196
  self.num_experts = config.num_experts
197
+ self.hidden_dim = config.hidden_size
198
+ self.intermediate_dim = config.moe_intermediate_size
199
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
200
+ self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(
204
+ self,
205
+ hidden_states: torch.Tensor,
206
+ top_k_index: torch.Tensor,
207
+ top_k_weights: torch.Tensor,
208
+ ) -> torch.Tensor:
209
+ final_hidden_states = torch.zeros_like(hidden_states)
210
+ with torch.no_grad():
211
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
212
+ expert_mask = expert_mask.permute(2, 1, 0)
213
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
214
+
215
+ for expert_idx in expert_hit:
216
+ expert_idx = expert_idx[0]
217
+ if expert_idx == self.num_experts:
218
+ continue
219
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
220
+ current_state = hidden_states[token_idx]
221
+ gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
222
+ current_hidden_states = self.act_fn(gate) * up
223
+ current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
224
+ current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
225
+ final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
226
+
227
+ return final_hidden_states
228
+
229
+
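For reference, a small worked example of the `expert_mask` bookkeeping used by the dispatch loop above, with made-up routing indices:

import torch
import torch.nn.functional as F

# 3 tokens, each routed to top_k=2 of 4 experts (made-up indices)
top_k_index = torch.tensor([[0, 2], [2, 3], [0, 3]])

expert_mask = F.one_hot(top_k_index, num_classes=4).permute(2, 1, 0)   # (experts, top_k, tokens)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
print(expert_hit.squeeze(-1).tolist())        # [0, 2, 3] -> expert 1 is never visited

top_k_pos, token_idx = torch.where(expert_mask[0])
print(token_idx.tolist(), top_k_pos.tolist()) # tokens 0 and 2 picked expert 0, both in slot 0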
230
+ class LagunaSparseMoeBlock(nn.Module):
231
+ def __init__(self, config: LagunaConfig):
232
+ super().__init__()
233
+ self.experts = LagunaExperts(config)
234
  self.gate = LagunaTopKRouter(config)
235
+ self.shared_experts = LagunaMLP(config, intermediate_size=config.shared_expert_intermediate_size)
236
+ self.routed_scaling_factor = config.moe_routed_scaling_factor
 
 
 
 
 
237
 
238
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
239
  batch_size, sequence_length, hidden_dim = hidden_states.shape
240
  hidden_states = hidden_states.view(-1, hidden_dim)
241
+ shared_output = self.shared_experts(hidden_states)
242
243
  _, routing_weights, selected_experts = self.gate(hidden_states)
244
+ hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
245
+ # Additional scaling
246
+ hidden_states = hidden_states * self.routed_scaling_factor
247
+ hidden_states = hidden_states + shared_output
248
 
249
+ hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
250
+ return hidden_states
 
251
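For reference, the combination step above reduces to the following on toy tensors; `moe_routed_scaling_factor` is an assumed value:

import torch

routed = torch.randn(4, 16)          # LagunaExperts output for 4 tokens
shared = torch.randn(4, 16)          # always-on shared expert output
routed_scaling_factor = 2.5          # assumed config.moe_routed_scaling_factor

out = routed * routed_scaling_factor + shared   # routed path is scaled, shared path is not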
 
252
 
253
  def rotate_half(x):
 
257
  return torch.cat((-x2, x1), dim=-1)
258
 
259
 
260
+ # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
261
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
262
  """Applies Rotary Position Embedding to the query and key tensors.
263
 
264
+ Removes the interleaving of cos and sin from GLM
265
+
266
  Args:
267
  q (`torch.Tensor`): The query tensor.
268
  k (`torch.Tensor`): The key tensor.
 
275
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
276
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
277
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
278
+ Returns:
 
 
279
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
280
  """
281
  cos = cos.unsqueeze(unsqueeze_dim)
282
  sin = sin.unsqueeze(unsqueeze_dim)
283
+
284
+ # Keep half or full tensor for later concatenation
285
+ rotary_dim = cos.shape[-1]
286
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
287
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
288
+
289
+ # Apply rotary embeddings on the first half or full tensor
290
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
291
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
292
+
293
+ # Concatenate back to full shape
294
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
295
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
296
  return q_embed, k_embed
297
 
298
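For reference, a minimal check (toy shapes, identity rotation) showing that only the first `rotary_dim` channels are rotated while the remainder passes through; it assumes `apply_rotary_pos_emb` as defined above is in scope, and the identity cos/sin stand in for the real tables from LagunaRotaryEmbedding:

import torch

q = torch.randn(1, 2, 5, 8)      # (batch, heads, seq, head_dim=8)
k = torch.randn(1, 2, 5, 8)
cos = torch.ones(1, 5, 4)        # rotary_dim=4 < head_dim; cos=1, sin=0 is the identity rotation
sin = torch.zeros(1, 5, 4)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert q_embed.shape == q.shape and k_embed.shape == k.shape
assert torch.allclose(q_embed, q) and torch.allclose(k_embed, k)   # nothing moves under identity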
 
 
323
 
324
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
325
  if attention_mask is not None:
326
+ attn_weights = attn_weights + attention_mask
 
327
 
328
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
329
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
333
  return attn_output, attn_weights
334
 
335
 
 
 
 
 
 
336
  @use_kernelized_func(apply_rotary_pos_emb)
337
  class LagunaAttention(nn.Module):
338
+ """Afmoe-style SWA/GQA attention with Laguna-specific gating and per-layer head count."""
339
+
340
+ def __init__(self, config: LagunaConfig, layer_idx: int, num_heads: int):
341
  super().__init__()
342
+ # Number of heads is controlled via `config.num_attention_heads_per_layer` which is passed from the parent for the specific layer
343
+ self.num_heads = num_heads
344
  self.config = config
345
  self.layer_idx = layer_idx
346
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
347
+ self.num_key_value_groups = self.num_heads // config.num_key_value_heads
348
  self.scaling = self.head_dim**-0.5
349
  self.attention_dropout = config.attention_dropout
350
  self.is_causal = True
351
 
352
+ # Per-layer head count: rebuild q_proj and o_proj using self.num_heads (parent uses config.num_attention_heads).
353
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
354
+ self.k_proj = nn.Linear(
355
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
356
+ )
357
+ self.v_proj = nn.Linear(
358
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
359
+ )
360
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
361
+ # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
362
+ # We only add Laguna-specific attributes
363
+ self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
364
+ self.sliding_window = config.sliding_window if self.is_local_attention else None
365
+
366
+ self.q_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
367
+ self.k_norm = LagunaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
368
+ self.g_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)
369
 
370
  def forward(
371
  self,
 
373
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
374
  attention_mask: torch.Tensor | None,
375
  past_key_values: Cache | None = None,
 
376
  **kwargs: Unpack[FlashAttentionKwargs],
377
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
378
  input_shape = hidden_states.shape[:-1]
379
  hidden_shape = (*input_shape, -1, self.head_dim)
380
 
381
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
382
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
383
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
 
 
 
 
384
 
385
+ query_states = self.q_norm(query_states).transpose(1, 2)
386
+ key_states = self.k_norm(key_states).transpose(1, 2)
387
+ value_states = value_states.transpose(1, 2)
388
 
389
  cos, sin = position_embeddings
390
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
391
 
392
  if past_key_values is not None:
393
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
 
 
 
 
 
394
 
395
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
396
+ self.config._attn_implementation, eager_attention_forward
397
+ )
398
  attn_output, attn_weights = attention_interface(
399
  self,
400
  query_states,
 
403
  attention_mask,
404
  dropout=0.0 if not self.training else self.attention_dropout,
405
  scaling=self.scaling,
406
+ sliding_window=self.sliding_window,
407
  **kwargs,
408
  )
409
 
410
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
411
 
 
 
412
  gate = F.softplus(self.g_proj(hidden_states).float()).to(attn_output.dtype)
413
+ attn_output = (attn_output.view(*input_shape, -1, self.head_dim) * gate.unsqueeze(-1)).view(*input_shape, -1)
414
 
415
  attn_output = self.o_proj(attn_output)
 
416
  return attn_output, attn_weights
417
 
418
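For reference, a minimal sketch of the per-head output gating above on toy shapes; `hidden_size`, the head counts, and the random projections are assumptions:

import torch
import torch.nn.functional as F

batch, seq, num_heads, head_dim, hidden_size = 2, 3, 4, 8, 16
hidden_states = torch.randn(batch, seq, hidden_size)
attn_output = torch.randn(batch, seq, num_heads * head_dim)
g_proj = torch.nn.Linear(hidden_size, num_heads, bias=False)   # one gate scalar per head

gate = F.softplus(g_proj(hidden_states).float()).to(attn_output.dtype)   # (batch, seq, heads)
gated = (attn_output.view(batch, seq, num_heads, head_dim) * gate.unsqueeze(-1)).view(batch, seq, -1)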
 
419
  class LagunaDecoderLayer(GradientCheckpointingLayer):
 
 
420
  def __init__(self, config: LagunaConfig, layer_idx: int):
421
  super().__init__()
422
+ self.hidden_size = config.hidden_size
423
+ self.self_attn = LagunaAttention(config, layer_idx, config.num_attention_heads_per_layer[layer_idx])
424
+ if config.mlp_layer_types[layer_idx] == "sparse":
 
 
425
  self.mlp = LagunaSparseMoeBlock(config)
426
  else:
427
  self.mlp = LagunaMLP(config, intermediate_size=config.intermediate_size)
428
  self.input_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
429
  self.post_attention_layernorm = LagunaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
430
 
431
  def forward(
432
  self,
 
435
  position_ids: torch.LongTensor | None = None,
436
  past_key_values: Cache | None = None,
437
  use_cache: bool | None = False,
 
438
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
439
  **kwargs: Unpack[TransformersKwargs],
440
  ) -> torch.Tensor:
 
447
  position_ids=position_ids,
448
  past_key_values=past_key_values,
449
  use_cache=use_cache,
 
450
  position_embeddings=position_embeddings,
451
  **kwargs,
452
  )
 
470
  _supports_flash_attn = True
471
  _supports_sdpa = True
472
  _supports_flex_attn = True
473
+
474
+ _can_compile_fullgraph = True
 
475
  _supports_attention_backend = True
476
  _can_record_outputs = {
477
  "router_logits": OutputRecorder(LagunaTopKRouter, index=0),
 
483
  def _init_weights(self, module):
484
  super()._init_weights(module)
485
  std = self.config.initializer_range
486
+ if isinstance(module, LagunaExperts):
487
+ init.normal_(module.gate_up_proj, mean=0.0, std=std)
488
+ init.normal_(module.down_proj, mean=0.0, std=std)
489
+ elif isinstance(module, LagunaTopKRouter):
490
  init.normal_(module.weight, mean=0.0, std=std)
491
+ if isinstance(module, LagunaTopKRouter):
492
+ torch.nn.init.zeros_(module.e_score_correction_bias)
493
+ elif isinstance(module, LagunaRotaryEmbedding):
494
+ for layer_type in module.layer_types:
495
+ rope_init_fn = module.compute_default_rope_parameters
496
+ if module.rope_type[layer_type] != "default":
497
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
498
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
499
+ init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
500
+ init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
501
 
502
 
503
+ @auto_docstring
504
  class LagunaModel(LagunaPreTrainedModel):
505
  def __init__(self, config: LagunaConfig):
506
  super().__init__(config)
 
518
  # Initialize weights and apply final processing
519
  self.post_init()
520
 
521
+ @capture_outputs
522
+ @auto_docstring
523
  def forward(
524
  self,
525
  input_ids: torch.LongTensor | None = None,
 
528
  past_key_values: Cache | None = None,
529
  inputs_embeds: torch.FloatTensor | None = None,
530
  use_cache: bool | None = None,
 
531
  **kwargs: Unpack[TransformersKwargs],
532
+ ) -> MoeModelOutputWithPast:
533
  if (input_ids is None) ^ (inputs_embeds is not None):
534
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
535
 
 
 
 
536
  if inputs_embeds is None:
537
  inputs_embeds = self.embed_tokens(input_ids)
538
 
539
+ if use_cache and past_key_values is None:
540
+ past_key_values = DynamicCache(config=self.config)
 
 
 
541
 
542
  if position_ids is None:
543
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
544
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
545
+ position_ids = position_ids.unsqueeze(0)
546
+
547
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
548
+ mask_kwargs = {
549
+ "config": self.config,
550
+ "inputs_embeds": inputs_embeds,
551
+ "attention_mask": attention_mask,
552
+ "past_key_values": past_key_values,
553
+ "position_ids": position_ids,
554
+ }
555
+ mask_creation_functions = {
556
+ "full_attention": lambda: create_causal_mask(**mask_kwargs),
557
+ "sliding_attention": lambda: create_sliding_window_causal_mask(**mask_kwargs),
558
+ }
559
+ causal_mask_mapping = {}
560
+ for layer_type in set(self.config.layer_types):
561
+ causal_mask_mapping[layer_type] = mask_creation_functions[layer_type]()
562
 
563
  hidden_states = inputs_embeds
564
+ position_embeddings = {}
565
+ for layer_type in set(self.config.layer_types):
566
+ position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
567
 
568
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
569
  hidden_states = decoder_layer(
570
  hidden_states,
571
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
572
+ position_embeddings=position_embeddings[self.config.layer_types[i]],
573
  position_ids=position_ids,
574
  past_key_values=past_key_values,
 
 
 
575
  **kwargs,
576
  )
577
 
 
579
 
580
  return MoeModelOutputWithPast(
581
  last_hidden_state=hidden_states,
582
+ past_key_values=past_key_values if use_cache else None,
583
  )
584
 
585
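For reference, a minimal sketch of the per-layer-type dispatch pattern used in the forward above; the `layer_types` list and the placeholder mask/rope values are assumptions:

layer_types = ["full_attention", "sliding_attention", "full_attention"]   # assumed toy config

# One mask and one cos/sin pair per *type*, built once and reused by every matching layer
causal_mask_mapping = {t: f"<{t} mask>" for t in set(layer_types)}
position_embeddings = {t: f"<{t} cos/sin>" for t in set(layer_types)}

for i, layer_type in enumerate(layer_types):
    mask = causal_mask_mapping[layer_type]      # each decoder layer picks its own mask...
    rope = position_embeddings[layer_type]      # ...and the rotary tables that match its type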
 
 
609
  The attention_mask used in forward function
610
  shape [batch_size X sequence_length] if not None.
611
 
612
+ Returns:
 
613
  The auxiliary loss.
614
  """
615
  if gate_logits is None or not isinstance(gate_logits, tuple):
 
668
  @auto_docstring
669
  class LagunaForCausalLM(LagunaPreTrainedModel, GenerationMixin):
670
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
671
+ _tp_plan = {"lm_head": "colwise_gather_output"}
672
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
673
 
674
  def __init__(self, config):
 
695
  labels: torch.LongTensor | None = None,
696
  use_cache: bool | None = None,
697
  output_router_logits: bool | None = None,
 
698
  logits_to_keep: int | torch.Tensor = 0,
699
  **kwargs: Unpack[TransformersKwargs],
700
  ) -> MoeCausalLMOutputWithPast:
701
  r"""
702
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
703
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
704
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
705
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
706
  """
 
707
 
708
  output_router_logits = (
709
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
718
  inputs_embeds=inputs_embeds,
719
  use_cache=use_cache,
720
  output_router_logits=output_router_logits,
 
721
  **kwargs,
722
  )
723
 
 
738
  self.num_experts_per_tok,
739
  attention_mask,
740
  )
741
+ if labels is not None:
742
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure it resides on the same device as loss
743
 
744
  return MoeCausalLMOutputWithPast(
745
  loss=loss,