trias702 commited on
Commit
456e96b
·
1 Parent(s): ff07748

Trying to force transformers to use the older causal mask

Browse files
Files changed (1) hide show
  1. modeling_ministral.py +4 -2
modeling_ministral.py CHANGED
@@ -11,7 +11,7 @@ from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
@@ -27,6 +27,7 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
 
30
 
31
  def rotate_half(x):
32
  """Rotates half the hidden dims of the input."""
@@ -419,7 +420,8 @@ class Ministral3Model(Ministral3PreTrainedModel):
419
 
420
  if self.training:
421
  causal_mask = None
422
- elif kwargs.get("use_causal_mask", False):
 
423
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
424
  causal_mask = mask_function(
425
  config=self.config,
 
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
 
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
+ ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
32
  def rotate_half(x):
33
  """Rotates half the hidden dims of the input."""
 
420
 
421
  if self.training:
422
  causal_mask = None
423
+ #elif kwargs.get("use_causal_mask", False):
424
+ else:
425
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
426
  causal_mask = mask_function(
427
  config=self.config,