trias702 committed on
Commit
ff07748
·
1 Parent(s): 4cabc4d

Made some potential fixes for DSA; still need to test whether they work

Browse files
Files changed (2) hide show
  1. chat_utils.py +3 -2
  2. modeling_ministral.py +1 -2
chat_utils.py CHANGED
@@ -133,7 +133,7 @@ def generate_with_prefix_cache_block_diff(
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
- output = model(prompt, use_cache=True)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
@@ -230,7 +230,8 @@ def generate_with_prefix_cache_block_diff(
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
- use_cache=True
 
234
  )
235
  past_key_values = output.past_key_values
236
 
 
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
+ output = model(prompt, use_cache=True, use_causal_mask=causal_context)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
 
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
+ use_cache=True,
234
+ use_causal_mask=causal_context
235
  )
236
  past_key_values = output.past_key_values
237
 
modeling_ministral.py CHANGED
@@ -419,8 +419,7 @@ class Ministral3Model(Ministral3PreTrainedModel):
419
 
420
  if self.training:
421
  causal_mask = None
422
-
423
- else:
424
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
  causal_mask = mask_function(
426
  config=self.config,
 
419
 
420
  if self.training:
421
  causal_mask = None
422
+ elif kwargs.get("use_causal_mask", False):
 
423
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
424
  causal_mask = mask_function(
425
  config=self.config,