abhgarg committed on
Commit
ddd0759
·
verified ·
1 Parent(s): 0eeb742

Clean up rope params; ensure transformers 4.55/5.0 compatibility

Browse files

- Remove duplicate top-level rope_scaling block and stray rope_theta from config.json
- Remove duplicate 'type' key from rope_parameters
- For 3B-Base/8B-Base: set max_position_embeddings=4096 and factor=0.25 to match training
- Mirror rope_theta and rope_scaling from rope_parameters in MinistralDLMConfig for v4.55 yarn
- Drop unused sdpa_mask_older_torch import (removed in transformers v5.0)
- Bump transformers_version to 5.0.0
- In linear_spec_generate_mp, guard direct past_kv.key_cache / value_cache access behind a hasattr(past_kv, 'layers') check so v5.0's DynamicCache API works too

config.json CHANGED
@@ -54,27 +54,13 @@
54
  "mscale_all_dim": 1.0,
55
  "original_max_position_embeddings": 16384,
56
  "rope_theta": 1000000000.0,
57
- "rope_type": "yarn",
58
- "type": "yarn"
59
  },
60
- "rope_scaling": {
61
- "beta_fast": 32.0,
62
- "beta_slow": 1.0,
63
- "factor": 16.0,
64
- "llama_4_scaling_beta": 0.1,
65
- "mscale": 1.0,
66
- "mscale_all_dim": 1.0,
67
- "original_max_position_embeddings": 16384,
68
- "rope_theta": 1000000.0,
69
- "rope_type": "yarn",
70
- "type": "yarn"
71
- },
72
- "rope_theta": 1000000000.0,
73
  "sliding_window": null,
74
  "tie_word_embeddings": false,
75
  "tok_mask_half_life_ratio": null,
76
  "torch_dtype": "bfloat16",
77
- "transformers_version": "4.55.4",
78
  "use_cache": false,
79
  "vocab_size": 131072
80
  }
 
54
  "mscale_all_dim": 1.0,
55
  "original_max_position_embeddings": 16384,
56
  "rope_theta": 1000000000.0,
57
+ "rope_type": "yarn"
 
58
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  "sliding_window": null,
60
  "tie_word_embeddings": false,
61
  "tok_mask_half_life_ratio": null,
62
  "torch_dtype": "bfloat16",
63
+ "transformers_version": "5.0.0",
64
  "use_cache": false,
65
  "vocab_size": 131072
66
  }
configuration_ministral_dlm.py CHANGED
@@ -156,7 +156,6 @@ class MinistralDLMConfig(PretrainedConfig):
156
  tie_word_embeddings=False,
157
  rope_theta=1000000.0,
158
  rope_parameters=None,
159
- rope_scaling=None,
160
  attention_bias=False,
161
  attention_dropout=0.0,
162
  mlp_bias=False,
@@ -204,9 +203,11 @@ class MinistralDLMConfig(PretrainedConfig):
204
  self.initializer_range = initializer_range
205
  self.rms_norm_eps = rms_norm_eps
206
  self.use_cache = use_cache
207
- self.rope_theta = rope_theta
208
  self.rope_parameters = rope_parameters
209
- self.rope_scaling = rope_scaling
 
 
 
210
  self.attention_bias = attention_bias
211
  self.attention_dropout = attention_dropout
212
  self.mlp_bias = mlp_bias
 
156
  tie_word_embeddings=False,
157
  rope_theta=1000000.0,
158
  rope_parameters=None,
 
159
  attention_bias=False,
160
  attention_dropout=0.0,
161
  mlp_bias=False,
 
203
  self.initializer_range = initializer_range
204
  self.rms_norm_eps = rms_norm_eps
205
  self.use_cache = use_cache
 
206
  self.rope_parameters = rope_parameters
207
+ # `rope_theta` is read at the top level by transformers v4.55's yarn impl; mirror from rope_parameters when present.
208
+ self.rope_theta = (rope_parameters or {}).get("rope_theta", rope_theta)
209
+ # v4.55 reads rope params from `rope_scaling`; in v5.0 `rope_scaling` is a property alias for rope_parameters.
210
+ self.rope_scaling = rope_parameters
211
  self.attention_bias = attention_bias
212
  self.attention_dropout = attention_dropout
213
  self.mlp_bias = mlp_bias
modeling_ministral.py CHANGED
@@ -11,7 +11,7 @@ from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
 
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
modeling_ministral_dlm.py CHANGED
@@ -1489,9 +1489,15 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1489
  layer.self_attn.diffusion_lm = val
1490
 
1491
  def _crop_cache(kv, length):
 
1492
  for li in range(len(kv)):
1493
- kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1494
- kv.value_cache[li] = kv.value_cache[li][:, :, :length]
 
 
 
 
 
1495
  kv._seen_tokens = length
1496
 
1497
  # ----- tree verify helpers (inlined) -----
@@ -1546,9 +1552,14 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1546
  candidate_blocks[pi, p] = combo[ci]
1547
 
1548
  # Expand KV cache batch dimension (shared, no copy)
1549
- for li in range(len(past_kv.key_cache)):
1550
- past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1551
- past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
 
 
 
 
 
1552
 
1553
  # Batched causal verify — uses flash attention + GQA
1554
  _set_dlm(False)
@@ -1583,9 +1594,14 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1583
  accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1584
 
1585
  # Extract winning path's KV cache slice
1586
- for li in range(len(past_kv.key_cache)):
1587
- past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1588
- past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
 
 
 
 
 
1589
  _crop_cache(past_kv, cache_len + best_acc)
1590
 
1591
  return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
 
1489
  layer.self_attn.diffusion_lm = val
1490
 
1491
  def _crop_cache(kv, length):
1492
+ # transformers 4.55 exposes .key_cache/.value_cache lists; 5.0 moved them under .layers[i].keys/.values.
1493
  for li in range(len(kv)):
1494
+ if hasattr(kv, 'layers'):
1495
+ layer = kv.layers[li]
1496
+ layer.keys = layer.keys[:, :, :length]
1497
+ layer.values = layer.values[:, :, :length]
1498
+ else:
1499
+ kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1500
+ kv.value_cache[li] = kv.value_cache[li][:, :, :length]
1501
  kv._seen_tokens = length
1502
 
1503
  # ----- tree verify helpers (inlined) -----
 
1552
  candidate_blocks[pi, p] = combo[ci]
1553
 
1554
  # Expand KV cache batch dimension (shared, no copy)
1555
+ for li in range(len(past_kv)):
1556
+ if hasattr(past_kv, 'layers'):
1557
+ layer = past_kv.layers[li]
1558
+ layer.keys = layer.keys.expand(num_paths, -1, -1, -1)
1559
+ layer.values = layer.values.expand(num_paths, -1, -1, -1)
1560
+ else:
1561
+ past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1562
+ past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
1563
 
1564
  # Batched causal verify — uses flash attention + GQA
1565
  _set_dlm(False)
 
1594
  accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1595
 
1596
  # Extract winning path's KV cache slice
1597
+ for li in range(len(past_kv)):
1598
+ if hasattr(past_kv, 'layers'):
1599
+ layer = past_kv.layers[li]
1600
+ layer.keys = layer.keys[best_pidx:best_pidx+1].contiguous()
1601
+ layer.values = layer.values[best_pidx:best_pidx+1].contiguous()
1602
+ else:
1603
+ past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1604
+ past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
1605
  _crop_cache(past_kv, cache_len + best_acc)
1606
 
1607
  return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]