YongganFu commited on
Commit
4bc9a52
·
verified ·
1 Parent(s): 190e665

Patch modeling_ministral.py: filter create_causal_mask kwargs by signature (handles input_embeds rename + cache_position removal in 5.9.0)

Browse files
Files changed (1) hide show
  1. modeling_ministral.py +16 -16
modeling_ministral.py CHANGED
@@ -420,23 +420,23 @@ class Ministral3Model(Ministral3PreTrainedModel):
420
 
421
  if kwargs.get("use_causal_mask", False):
422
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
- # `create_causal_mask` renamed the embeds kwarg from `input_embeds` (transformers <= 4.x)
424
- # to `inputs_embeds` (transformers >= 5.0). Detect which the installed version uses.
 
 
425
  import inspect
426
- mask_input_kw = (
427
- "inputs_embeds"
428
- if "inputs_embeds" in inspect.signature(mask_function).parameters
429
- else "input_embeds"
430
- )
431
- causal_mask = mask_function(
432
- config=self.config,
433
- attention_mask=attention_mask,
434
- cache_position=cache_position,
435
- past_key_values=past_key_values,
436
- position_ids=position_ids,
437
- **{mask_input_kw: inputs_embeds},
438
- )
439
-
440
  else:
441
  causal_mask = None
442
 
 
420
 
421
  if kwargs.get("use_causal_mask", False):
422
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
+ # Build candidate kwargs and filter against the function's signature
424
+ # for cross-transformers-version compatibility:
425
+ # - `input_embeds` (<= 4.x) was renamed to `inputs_embeds` (>= 5.0)
426
+ # - `cache_position` was removed from the signature in 5.9.0
427
  import inspect
428
+ sig_params = inspect.signature(mask_function).parameters
429
+ embeds_kw = "inputs_embeds" if "inputs_embeds" in sig_params else "input_embeds"
430
+ candidate = {
431
+ "config": self.config,
432
+ "attention_mask": attention_mask,
433
+ "cache_position": cache_position,
434
+ "past_key_values": past_key_values,
435
+ "position_ids": position_ids,
436
+ embeds_kw: inputs_embeds,
437
+ }
438
+ causal_mask = mask_function(**{k: v for k, v in candidate.items() if k in sig_params})
439
+
 
 
440
  else:
441
  causal_mask = None
442