YongganFu commited on
Commit
5df8662
·
verified ·
1 Parent(s): 5cddf22

Upload model

Browse files
config.json CHANGED
@@ -22,6 +22,7 @@
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
 
25
  "enforce_mask": false,
26
  "eos_token_id": 2,
27
  "global_loss_avg": false,
 
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
25
+ "enable_self_spec": false,
26
  "enforce_mask": false,
27
  "eos_token_id": 2,
28
  "global_loss_avg": false,
configuration_ministral_dlm.py CHANGED
@@ -112,6 +112,9 @@ class MinistralDLMConfig(PretrainedConfig):
112
  Adaptive permutation ratio for each block.
113
  ada_perm_ratio_global (`float`, *optional*):
114
  Adaptive permutation ratio for global.
 
 
 
115
  """
116
 
117
  model_type = "ministral_dlm"
@@ -181,6 +184,7 @@ class MinistralDLMConfig(PretrainedConfig):
181
  ada_perm_ratio_per_block=None,
182
  ada_perm_ratio_global=None,
183
  ada_dlm_loss_ratio=None,
 
184
  **kwargs,
185
  ):
186
  self.vocab_size = vocab_size
@@ -234,6 +238,7 @@ class MinistralDLMConfig(PretrainedConfig):
234
  self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
235
  self.ada_perm_ratio_global = ada_perm_ratio_global
236
  self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
 
237
  super().__init__(
238
  pad_token_id=pad_token_id,
239
  bos_token_id=bos_token_id,
 
112
  Adaptive permutation ratio for each block.
113
  ada_perm_ratio_global (`float`, *optional*):
114
  Adaptive permutation ratio for global.
115
+ enable_self_spec (`bool`, *optional*, defaults to `False`):
116
+ Force MinistralFlexAttention for all paradigms (including bidirectional/autoregressive).
117
+ Required for self speculative generation; leave False for standard eval to use faster SDPA kernels.
118
  """
119
 
120
  model_type = "ministral_dlm"
 
184
  ada_perm_ratio_per_block=None,
185
  ada_perm_ratio_global=None,
186
  ada_dlm_loss_ratio=None,
187
+ enable_self_spec=False,
188
  **kwargs,
189
  ):
190
  self.vocab_size = vocab_size
 
238
  self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
239
  self.ada_perm_ratio_global = ada_perm_ratio_global
240
  self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
241
+ self.enable_self_spec = enable_self_spec
242
  super().__init__(
243
  pad_token_id=pad_token_id,
244
  bos_token_id=bos_token_id,
modeling_ministral_dlm.py CHANGED
@@ -90,7 +90,8 @@ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block
90
  class MinistralFlexAttention(Ministral3Attention):
91
  def __init__(self, *args, **kwargs):
92
  super().__init__(*args, **kwargs)
93
-
 
94
  self.block_size_orig = self.config.block_size
95
 
96
  if self.config.dlm_paradigm == 'bidirectional':
@@ -151,31 +152,15 @@ class MinistralFlexAttention(Ministral3Attention):
151
  self.mode = mode
152
  self.block_size = block_size
153
 
154
- def compute_block_mask(self, mode, q_len, block_size=None):
155
 
156
  def bidirectional_mask(b, h, q, kv):
157
  return (q >= kv) | (q < kv)
158
 
159
  def autoregressive_mask(b, h, q, kv):
160
  return (q >= kv)
161
-
162
  def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
163
- """
164
- Constructs the specialized block diffusion attention mask for training
165
- composed of three masks:
166
- - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
167
- - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
168
- - **Block Causal Mask (M_BC)**: Attention to update x0
169
- Args:
170
- b, h: Batch and head indices (ignored for mask logic).
171
- q_idx, kv_idx: Query and Key indices.
172
- seq_len: Total sequence length.
173
- block_size: Defines the block structure.
174
- Returns:
175
- A boolean attention mask.
176
- """
177
-
178
- # Indicate whether token belongs to xt or x0
179
  x0_flag_q = (q_idx >= n)
180
  x0_flag_kv = (kv_idx >= n)
181
 
@@ -238,15 +223,23 @@ class MinistralFlexAttention(Ministral3Attention):
238
  attn_mask = autoregressive_mask
239
  elif mode == 'block_diff':
240
  assert block_size is not None
241
- attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
242
  elif mode == 'sbd_block_diff':
243
  assert block_size is not None
244
- attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, q_len//2)
245
  else:
246
  raise ValueError(f"Unknown attention mode: {mode}")
247
 
 
 
 
 
 
 
 
 
248
  block_mask = create_block_mask(
249
- attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
250
  )
251
 
252
  return block_mask
@@ -298,9 +291,9 @@ class MinistralFlexAttention(Ministral3Attention):
298
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
299
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
300
 
301
- tidar_inference_mode = getattr(self.config, "tidar_inference_mode", None)
302
- if tidar_inference_mode is not None:
303
- if tidar_inference_mode == "quadratic":
304
  block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
305
  if block_length is None:
306
  raise ValueError("SBD quadratic decoding requires block_length in config.")
@@ -360,7 +353,7 @@ class MinistralFlexAttention(Ministral3Attention):
360
  attn_output = self.o_proj(attn_output)
361
  return attn_output, None
362
 
363
- elif tidar_inference_mode == "default":
364
  block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
365
  if block_length is None:
366
  raise ValueError("SBD default decoding requires block_length in config.")
@@ -449,11 +442,12 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
449
  diffusion_config = copy.deepcopy(config)
450
  diffusion_config.diffusion_lm = True
451
 
 
 
452
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
453
  diffusion_config.attn_class = MinistralFlexAttention
454
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
455
- diffusion_config.attn_class = Ministral3Attention
456
-
457
  if config.dlm_paradigm == 'autoregressive':
458
  diffusion_config.diffusion_lm = False
459
  else:
@@ -907,7 +901,6 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
907
  past_key_values: Optional[Cache] = None,
908
  use_cache: bool = False,
909
  ):
910
- """SBD quadratic inference (injected by build_hf_tidar_repo)."""
911
  enc_config = self.encoder.config
912
  enc_config.use_sbd_objective = True
913
  enc_config.block_length = block_length
@@ -918,7 +911,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
918
  if use_cache and past_key_values is None:
919
  past_key_values = DynamicCache()
920
 
921
- enc_config.tidar_inference_mode = "default"
922
  input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
923
  outputs = self.encoder(
924
  input_ids=input_ids,
@@ -937,7 +930,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
937
 
938
  return logits, past_key_values
939
  else:
940
- enc_config.tidar_inference_mode = "quadratic"
941
 
942
  draft_len = block_length * (block_length + 1)
943
  draft_input_ids = torch.cat(
@@ -994,23 +987,22 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
994
  return logits, past_key_values
995
 
996
  @torch.no_grad()
997
- def tidar_generate(
998
  self,
999
  prompt_ids: torch.Tensor,
1000
  max_new_tokens: int = 128,
1001
  steps: int = 128,
1002
  block_length: int = 16,
1003
- threshold: Optional[float] = None,
1004
  temperature: float = 0.0,
1005
  mask_token_id: Optional[int] = None,
1006
  eos_token_id: Optional[int] = None,
1007
  ):
1008
- """TiDAR quadratic speculative decoding (injected by build_hf_tidar_repo)."""
1009
  self.config.use_sbd_objective = True
1010
  self.config.dlm_paradigm = "sbd"
1011
 
1012
  if prompt_ids.shape[0] != 1:
1013
- raise ValueError("TiDAR quadratic decoding currently requires batch_size == 1")
1014
 
1015
  token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1016
  if eos_token_id is None:
@@ -1064,12 +1056,12 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1064
  )
1065
 
1066
  useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1067
- if threshold is None:
1068
  useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1069
  else:
1070
- if not (0.0 <= threshold <= 1.0):
1071
- raise ValueError("threshold must be between 0 and 1")
1072
- mix_logits = useful_token_logits[:, :, 0] * threshold + useful_token_logits[:, :, 1] * (1 - threshold)
1073
  useful_token_logits[:, :, 0] = mix_logits
1074
  useful_token_logits[:, :, 1] = mix_logits
1075
 
 
90
  class MinistralFlexAttention(Ministral3Attention):
91
  def __init__(self, *args, **kwargs):
92
  super().__init__(*args, **kwargs)
93
+
94
+ self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
95
  self.block_size_orig = self.config.block_size
96
 
97
  if self.config.dlm_paradigm == 'bidirectional':
 
152
  self.mode = mode
153
  self.block_size = block_size
154
 
155
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
156
 
157
  def bidirectional_mask(b, h, q, kv):
158
  return (q >= kv) | (q < kv)
159
 
160
  def autoregressive_mask(b, h, q, kv):
161
  return (q >= kv)
162
+
163
  def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  x0_flag_q = (q_idx >= n)
165
  x0_flag_kv = (kv_idx >= n)
166
 
 
223
  attn_mask = autoregressive_mask
224
  elif mode == 'block_diff':
225
  assert block_size is not None
226
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
227
  elif mode == 'sbd_block_diff':
228
  assert block_size is not None
229
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
230
  else:
231
  raise ValueError(f"Unknown attention mode: {mode}")
232
 
233
+ if q_len is not None:
234
+ Q_LEN = q_len
235
+ else:
236
+ if mode in ['block_diff', 'sbd_block_diff']:
237
+ Q_LEN = self.max_seq_length * 2
238
+ else:
239
+ Q_LEN = self.max_seq_length
240
+
241
  block_mask = create_block_mask(
242
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
243
  )
244
 
245
  return block_mask
 
291
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
292
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
293
 
294
+ self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
295
+ if self_spec_inference_mode is not None:
296
+ if self_spec_inference_mode == "quadratic":
297
  block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
298
  if block_length is None:
299
  raise ValueError("SBD quadratic decoding requires block_length in config.")
 
353
  attn_output = self.o_proj(attn_output)
354
  return attn_output, None
355
 
356
+ elif self_spec_inference_mode == "default":
357
  block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
358
  if block_length is None:
359
  raise ValueError("SBD default decoding requires block_length in config.")
 
442
  diffusion_config = copy.deepcopy(config)
443
  diffusion_config.diffusion_lm = True
444
 
445
+ use_flex = getattr(config, 'enable_self_spec', False)
446
+
447
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
448
  diffusion_config.attn_class = MinistralFlexAttention
449
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
450
+ diffusion_config.attn_class = MinistralFlexAttention if use_flex else Ministral3Attention
 
451
  if config.dlm_paradigm == 'autoregressive':
452
  diffusion_config.diffusion_lm = False
453
  else:
 
901
  past_key_values: Optional[Cache] = None,
902
  use_cache: bool = False,
903
  ):
 
904
  enc_config = self.encoder.config
905
  enc_config.use_sbd_objective = True
906
  enc_config.block_length = block_length
 
911
  if use_cache and past_key_values is None:
912
  past_key_values = DynamicCache()
913
 
914
+ enc_config.self_spec_inference_mode = "default"
915
  input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
916
  outputs = self.encoder(
917
  input_ids=input_ids,
 
930
 
931
  return logits, past_key_values
932
  else:
933
+ enc_config.self_spec_inference_mode = "quadratic"
934
 
935
  draft_len = block_length * (block_length + 1)
936
  draft_input_ids = torch.cat(
 
987
  return logits, past_key_values
988
 
989
  @torch.no_grad()
990
+ def self_spec_generate(
991
  self,
992
  prompt_ids: torch.Tensor,
993
  max_new_tokens: int = 128,
994
  steps: int = 128,
995
  block_length: int = 16,
996
+ ar_mix_weight: Optional[float] = None,
997
  temperature: float = 0.0,
998
  mask_token_id: Optional[int] = None,
999
  eos_token_id: Optional[int] = None,
1000
  ):
 
1001
  self.config.use_sbd_objective = True
1002
  self.config.dlm_paradigm = "sbd"
1003
 
1004
  if prompt_ids.shape[0] != 1:
1005
+ raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
1006
 
1007
  token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1008
  if eos_token_id is None:
 
1056
  )
1057
 
1058
  useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1059
+ if ar_mix_weight is None:
1060
  useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1061
  else:
1062
+ if not (0.0 <= ar_mix_weight <= 1.0):
1063
+ raise ValueError("ar_mix_weight must be between 0 and 1")
1064
+ mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
1065
  useful_token_logits[:, :, 0] = mix_logits
1066
  useful_token_logits[:, :, 1] = mix_logits
1067