abhgarg commited on
Commit
96d0421
·
verified ·
1 Parent(s): 36a1c93

Clean up rope params; ensure transformers 4.55/5.0 compatibility

Browse files

- Remove duplicate top-level rope_scaling block and stray rope_theta from config.json
- Remove duplicate 'type' key from rope_parameters
- For 3B-Base/8B-Base: set max_position_embeddings=4096 and factor=0.25 to match training
- Mirror rope_theta and rope_scaling from rope_parameters in MinistralDLMConfig for v4.55 yarn
- Drop unused sdpa_mask_older_torch import (removed in transformers v5.0)
- Bump transformers_version to 5.0.0

config.json CHANGED
@@ -31,7 +31,7 @@
31
  "initializer_range": 0.02,
32
  "intermediate_size": 9216,
33
  "mask_token_id": 100,
34
- "max_position_embeddings": 262144,
35
  "mlp_bias": false,
36
  "model_type": "ministral_dlm",
37
  "multi_sampling": null,
@@ -47,33 +47,19 @@
47
  "rope_parameters": {
48
  "beta_fast": 32.0,
49
  "beta_slow": 1.0,
50
- "factor": 16.0,
51
  "llama_4_scaling_beta": 0.1,
52
  "mscale": 1.0,
53
  "mscale_all_dim": 1.0,
54
  "original_max_position_embeddings": 16384,
55
  "rope_theta": 1000000.0,
56
- "rope_type": "yarn",
57
- "type": "yarn"
58
  },
59
- "rope_scaling": {
60
- "beta_fast": 32.0,
61
- "beta_slow": 1.0,
62
- "factor": 16.0,
63
- "llama_4_scaling_beta": 0.1,
64
- "mscale": 1.0,
65
- "mscale_all_dim": 1.0,
66
- "original_max_position_embeddings": 16384,
67
- "rope_theta": 1000000.0,
68
- "rope_type": "yarn",
69
- "type": "yarn"
70
- },
71
- "rope_theta": 1000000.0,
72
  "sliding_window": null,
73
  "tie_word_embeddings": false,
74
  "tok_mask_half_life_ratio": null,
75
  "torch_dtype": "bfloat16",
76
- "transformers_version": "4.55.4",
77
  "use_cache": false,
78
  "vocab_size": 131072
79
  }
 
31
  "initializer_range": 0.02,
32
  "intermediate_size": 9216,
33
  "mask_token_id": 100,
34
+ "max_position_embeddings": 4096,
35
  "mlp_bias": false,
36
  "model_type": "ministral_dlm",
37
  "multi_sampling": null,
 
47
  "rope_parameters": {
48
  "beta_fast": 32.0,
49
  "beta_slow": 1.0,
50
+ "factor": 0.25,
51
  "llama_4_scaling_beta": 0.1,
52
  "mscale": 1.0,
53
  "mscale_all_dim": 1.0,
54
  "original_max_position_embeddings": 16384,
55
  "rope_theta": 1000000.0,
56
+ "rope_type": "yarn"
 
57
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  "sliding_window": null,
59
  "tie_word_embeddings": false,
60
  "tok_mask_half_life_ratio": null,
61
  "torch_dtype": "bfloat16",
62
+ "transformers_version": "5.0.0",
63
  "use_cache": false,
64
  "vocab_size": 131072
65
  }
configuration_ministral_dlm.py CHANGED
@@ -153,7 +153,6 @@ class MinistralDLMConfig(PretrainedConfig):
153
  tie_word_embeddings=False,
154
  rope_theta=1000000.0,
155
  rope_parameters=None,
156
- rope_scaling=None,
157
  attention_bias=False,
158
  attention_dropout=0.0,
159
  mlp_bias=False,
@@ -200,9 +199,11 @@ class MinistralDLMConfig(PretrainedConfig):
200
  self.initializer_range = initializer_range
201
  self.rms_norm_eps = rms_norm_eps
202
  self.use_cache = use_cache
203
- self.rope_theta = rope_theta
204
  self.rope_parameters = rope_parameters
205
- self.rope_scaling = rope_scaling
 
 
 
206
  self.attention_bias = attention_bias
207
  self.attention_dropout = attention_dropout
208
  self.mlp_bias = mlp_bias
 
153
  tie_word_embeddings=False,
154
  rope_theta=1000000.0,
155
  rope_parameters=None,
 
156
  attention_bias=False,
157
  attention_dropout=0.0,
158
  mlp_bias=False,
 
199
  self.initializer_range = initializer_range
200
  self.rms_norm_eps = rms_norm_eps
201
  self.use_cache = use_cache
 
202
  self.rope_parameters = rope_parameters
203
+ # `rope_theta` is read at the top level by transformers v4.55's yarn impl; mirror from rope_parameters when present.
204
+ self.rope_theta = (rope_parameters or {}).get("rope_theta", rope_theta)
205
+ # v4.55 reads rope params from `rope_scaling`; in v5.0 `rope_scaling` is a property alias for rope_parameters.
206
+ self.rope_scaling = rope_parameters
207
  self.attention_bias = attention_bias
208
  self.attention_dropout = attention_dropout
209
  self.mlp_bias = mlp_bias
modeling_ministral.py CHANGED
@@ -11,7 +11,7 @@ from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
 
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,