trias702 commited on
Commit
33b2954
·
1 Parent(s): 456e96b

Overriding the old function doesn't work, reverting to old approach

Browse files
Files changed (1) hide show
  1. modeling_ministral.py +6 -5
modeling_ministral.py CHANGED
@@ -27,7 +27,7 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
- ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
32
  def rotate_half(x):
33
  """Rotates half the hidden dims of the input."""
@@ -418,10 +418,9 @@ class Ministral3Model(Ministral3PreTrainedModel):
418
  if position_ids is None:
419
  position_ids = cache_position.unsqueeze(0)
420
 
421
- if self.training:
422
- causal_mask = None
423
- #elif kwargs.get("use_causal_mask", False):
424
- else:
425
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
426
  causal_mask = mask_function(
427
  config=self.config,
@@ -431,6 +430,8 @@ class Ministral3Model(Ministral3PreTrainedModel):
431
  past_key_values=past_key_values,
432
  position_ids=position_ids,
433
  )
 
 
434
 
435
  hidden_states = inputs_embeds
436
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
32
  def rotate_half(x):
33
  """Rotates half the hidden dims of the input."""
 
418
  if position_ids is None:
419
  position_ids = cache_position.unsqueeze(0)
420
 
421
+ #if self.training:
422
+ # causal_mask = None
423
+ if kwargs.get("use_causal_mask", False):
 
424
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
  causal_mask = mask_function(
426
  config=self.config,
 
430
  past_key_values=past_key_values,
431
  position_ids=position_ids,
432
  )
433
+ else:
434
+ causal_mask = None
435
 
436
  hidden_states = inputs_embeds
437
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)