YongganFu commited on
Commit
15597b8
·
verified ·
1 Parent(s): f9e0c41

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +3 -2
  2. modeling_ministral.py +6 -5
chat_utils.py CHANGED
@@ -133,7 +133,7 @@ def generate_with_prefix_cache_block_diff(
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
- output = model(prompt, use_cache=True)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
@@ -230,7 +230,8 @@ def generate_with_prefix_cache_block_diff(
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
- use_cache=True
 
234
  )
235
  past_key_values = output.past_key_values
236
 
 
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
+ output = model(prompt, use_cache=True, use_causal_mask=causal_context)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
 
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
+ use_cache=True,
234
+ use_causal_mask=causal_context
235
  )
236
  past_key_values = output.past_key_values
237
 
modeling_ministral.py CHANGED
@@ -11,7 +11,7 @@ from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
@@ -27,6 +27,7 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
 
30
 
31
  def rotate_half(x):
32
  """Rotates half the hidden dims of the input."""
@@ -417,10 +418,7 @@ class Ministral3Model(Ministral3PreTrainedModel):
417
  if position_ids is None:
418
  position_ids = cache_position.unsqueeze(0)
419
 
420
- if self.training:
421
- causal_mask = None
422
-
423
- else:
424
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
  causal_mask = mask_function(
426
  config=self.config,
@@ -430,6 +428,9 @@ class Ministral3Model(Ministral3PreTrainedModel):
430
  past_key_values=past_key_values,
431
  position_ids=position_ids,
432
  )
 
 
 
433
 
434
  hidden_states = inputs_embeds
435
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
 
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
32
  def rotate_half(x):
33
  """Rotates half the hidden dims of the input."""
 
418
  if position_ids is None:
419
  position_ids = cache_position.unsqueeze(0)
420
 
421
+ if kwargs.get("use_causal_mask", False):
 
 
 
422
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
  causal_mask = mask_function(
424
  config=self.config,
 
428
  past_key_values=past_key_values,
429
  position_ids=position_ids,
430
  )
431
+
432
+ else:
433
+ causal_mask = None
434
 
435
  hidden_states = inputs_embeds
436
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)