YongganFu committed on
Commit
190e665
·
verified ·
1 Parent(s): 36e52f3

Patch modeling_ministral.py: handle both `input_embeds` and `inputs_embeds` kwargs

Browse files
Files changed (1) hide show
  1. modeling_ministral.py +9 -1
modeling_ministral.py CHANGED
@@ -420,13 +420,21 @@ class Ministral3Model(Ministral3PreTrainedModel):
420
 
421
  if kwargs.get("use_causal_mask", False):
422
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
 
 
 
 
 
 
 
 
423
  causal_mask = mask_function(
424
  config=self.config,
425
- input_embeds=inputs_embeds,
426
  attention_mask=attention_mask,
427
  cache_position=cache_position,
428
  past_key_values=past_key_values,
429
  position_ids=position_ids,
 
430
  )
431
 
432
  else:
 
420
 
421
  if kwargs.get("use_causal_mask", False):
422
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
+ # `create_causal_mask` renamed the embeds kwarg from `input_embeds` (transformers <= 4.x)
424
+ # to `inputs_embeds` (transformers >= 5.0). Detect which the installed version uses.
425
+ import inspect
426
+ mask_input_kw = (
427
+ "inputs_embeds"
428
+ if "inputs_embeds" in inspect.signature(mask_function).parameters
429
+ else "input_embeds"
430
+ )
431
  causal_mask = mask_function(
432
  config=self.config,
 
433
  attention_mask=attention_mask,
434
  cache_position=cache_position,
435
  past_key_values=past_key_values,
436
  position_ids=position_ids,
437
+ **{mask_input_kw: inputs_embeds},
438
  )
439
 
440
  else: