trias702 commited on
Commit
418d9e4
·
1 Parent(s): 618a1c8

Added custom MinistralDiffOutputWithPast return type and skip_loss functionality

Browse files
Files changed (1) hide show
  1. modeling_ministral_dlm.py +17 -2
modeling_ministral_dlm.py CHANGED
@@ -1,4 +1,5 @@
1
  import copy
 
2
  from typing import Callable, Optional, Tuple, Union
3
  import random
4
  import os
@@ -10,6 +11,7 @@ import torch
10
  import torch.nn.functional as F
11
  from torch import nn
12
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
 
13
 
14
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
15
 
@@ -29,6 +31,17 @@ from .chat_utils import generate_with_prefix_cache_block_diff
29
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
30
  from .configuration_ministral_dlm import MinistralDLMConfig
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  # @torch.compile(dynamic=True, mode="reduce-overhead")
33
  # @torch.compile(mode="default")
34
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
@@ -479,6 +492,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
479
  loss_mask: Optional[torch.Tensor] = None,
480
  ce_loss_weight: float = 1.0,
481
  output_last_hidden_states_only: bool = False,
 
482
  **kwargs,
483
  ) -> CausalLMOutputWithPast:
484
 
@@ -565,7 +579,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
565
  logits = logits[:, :input_ids_len]
566
 
567
  loss = None
568
- if labels is not None:
569
  if self.config.dlm_paradigm == 'autoregressive':
570
  shift_logits = logits[..., :-1, :].contiguous()
571
  shift_labels = labels[..., 1:].contiguous()
@@ -702,9 +716,10 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
702
  else:
703
  loss = (loss, num_mask_tokens)
704
 
705
- return CausalLMOutputWithPast(
706
  loss=loss if not is_teacher else logits,
707
  logits=logits,
 
708
  past_key_values=enc_out.past_key_values,
709
  hidden_states=None,
710
  attentions=None,
 
1
  import copy
2
+ from dataclasses import dataclass
3
  from typing import Callable, Optional, Tuple, Union
4
  import random
5
  import os
 
11
  import torch.nn.functional as F
12
  from torch import nn
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
 
16
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
 
 
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
+
35
+ @dataclass
36
+ class MinistralDiffOutputWithPast(ModelOutput):
37
+ loss: torch.FloatTensor | None = None
38
+ logits: torch.FloatTensor | None = None
39
+ causal_logits: torch.FloatTensor | None = None
40
+ past_key_values: Cache | None = None
41
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
42
+ attentions: tuple[torch.FloatTensor, ...] | None = None
43
+
44
+
45
  # @torch.compile(dynamic=True, mode="reduce-overhead")
46
  # @torch.compile(mode="default")
47
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
 
492
  loss_mask: Optional[torch.Tensor] = None,
493
  ce_loss_weight: float = 1.0,
494
  output_last_hidden_states_only: bool = False,
495
+ skip_loss: bool = False,
496
  **kwargs,
497
  ) -> CausalLMOutputWithPast:
498
 
 
579
  logits = logits[:, :input_ids_len]
580
 
581
  loss = None
582
+ if labels is not None and not skip_loss:
583
  if self.config.dlm_paradigm == 'autoregressive':
584
  shift_logits = logits[..., :-1, :].contiguous()
585
  shift_labels = labels[..., 1:].contiguous()
 
716
  else:
717
  loss = (loss, num_mask_tokens)
718
 
719
+ return MinistralDiffOutputWithPast(
720
  loss=loss if not is_teacher else logits,
721
  logits=logits,
722
+ causal_logits=causal_logits,
723
  past_key_values=enc_out.past_key_values,
724
  hidden_states=None,
725
  attentions=None,