YongganFu committed on
Commit
f9e0c41
·
verified ·
1 Parent(s): 262d402

Upload model

Browse files
chat_utils.py CHANGED
@@ -133,7 +133,7 @@ def generate_with_prefix_cache_block_diff(
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
- output = model(prompt, use_cache=True, use_causal_mask=causal_context)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
@@ -230,8 +230,7 @@ def generate_with_prefix_cache_block_diff(
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
- use_cache=True,
234
- use_causal_mask=causal_context
235
  )
236
  past_key_values = output.past_key_values
237
 
 
133
  layer.self_attn.diffusion_lm=False
134
 
135
  # Compute KV cache for the prompt initially
136
+ output = model(prompt, use_cache=True)
137
  past_key_values = output.past_key_values
138
 
139
  if causal_context:
 
230
  output = model(
231
  x_accum[:, block_slice],
232
  past_key_values=past_key_values,
233
+ use_cache=True
 
234
  )
235
  past_key_values = output.past_key_values
236
 
config.json CHANGED
@@ -69,7 +69,6 @@
69
  "type": "yarn"
70
  },
71
  "rope_theta": 1000000.0,
72
- "seq_length": 8192,
73
  "sliding_window": null,
74
  "tie_word_embeddings": false,
75
  "tok_mask_half_life_ratio": null,
 
69
  "type": "yarn"
70
  },
71
  "rope_theta": 1000000.0,
 
72
  "sliding_window": null,
73
  "tie_word_embeddings": false,
74
  "tok_mask_half_life_ratio": null,
configuration_ministral_dlm.py CHANGED
@@ -70,8 +70,6 @@ class MinistralDLMConfig(PretrainedConfig):
70
  Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
  sliding_window (`int`, *optional*, defaults to None):
72
  Sliding window attention size.
73
- seq_length (`int`, *optional*, defaults to 8192):
74
- Sequence length for training.
75
  mask_token_id (`int`, *optional*, defaults to -1):
76
  Token ID for masking in diffusion.
77
  dlm_type (`str`, *optional*, defaults to 'llada'):
@@ -161,7 +159,6 @@ class MinistralDLMConfig(PretrainedConfig):
161
  mlp_bias=False,
162
  sliding_window=None,
163
  attn_implementation="sdpa",
164
- seq_length=8192,
165
  mask_token_id=-1,
166
  dlm_type='llada',
167
  random_length_prob=None,
@@ -214,7 +211,6 @@ class MinistralDLMConfig(PretrainedConfig):
214
  rope_config_validation(self)
215
 
216
  self.attn_implementation = attn_implementation
217
- self.seq_length = seq_length
218
 
219
  self.mask_token_id = mask_token_id
220
  self.dlm_type = dlm_type
 
70
  Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
  sliding_window (`int`, *optional*, defaults to None):
72
  Sliding window attention size.
 
 
73
  mask_token_id (`int`, *optional*, defaults to -1):
74
  Token ID for masking in diffusion.
75
  dlm_type (`str`, *optional*, defaults to 'llada'):
 
159
  mlp_bias=False,
160
  sliding_window=None,
161
  attn_implementation="sdpa",
 
162
  mask_token_id=-1,
163
  dlm_type='llada',
164
  random_length_prob=None,
 
211
  rope_config_validation(self)
212
 
213
  self.attn_implementation = attn_implementation
 
214
 
215
  self.mask_token_id = mask_token_id
216
  self.dlm_type = dlm_type
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:581e534c77fd49b8ab1d234f65bf08b88b03bd4dd10e397285a3441260957a8d
3
  size 16979144720
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73af2cd1c982f85bac01c7da43765deb3f2deced76eb93dbd2a6a968ff531349
3
  size 16979144720
modeling_ministral.py CHANGED
@@ -11,7 +11,7 @@ from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
@@ -27,7 +27,6 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
- #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
32
  def rotate_half(x):
33
  """Rotates half the hidden dims of the input."""
@@ -418,9 +417,10 @@ class Ministral3Model(Ministral3PreTrainedModel):
418
  if position_ids is None:
419
  position_ids = cache_position.unsqueeze(0)
420
 
421
- #if self.training:
422
- # causal_mask = None
423
- if kwargs.get("use_causal_mask", False):
 
424
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
  causal_mask = mask_function(
426
  config=self.config,
@@ -430,8 +430,6 @@ class Ministral3Model(Ministral3PreTrainedModel):
430
  past_key_values=past_key_values,
431
  position_ids=position_ids,
432
  )
433
- else:
434
- causal_mask = None
435
 
436
  hidden_states = inputs_embeds
437
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
11
  from transformers.generation import GenerationMixin
12
  # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
  from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
  from transformers.modeling_layers import (
17
  GenericForQuestionAnswering,
 
27
  # from transformers.utils.generic import maybe_autocast
28
  from .configuration_ministral_dlm import MinistralDLMConfig
29
 
 
30
 
31
  def rotate_half(x):
32
  """Rotates half the hidden dims of the input."""
 
417
  if position_ids is None:
418
  position_ids = cache_position.unsqueeze(0)
419
 
420
+ if self.training:
421
+ causal_mask = None
422
+
423
+ else:
424
  mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
  causal_mask = mask_function(
426
  config=self.config,
 
430
  past_key_values=past_key_values,
431
  position_ids=position_ids,
432
  )
 
 
433
 
434
  hidden_states = inputs_embeds
435
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
modeling_ministral_dlm.py CHANGED
@@ -54,7 +54,6 @@ class MinistralFlexAttention(Ministral3Attention):
54
  def __init__(self, *args, **kwargs):
55
  super().__init__(*args, **kwargs)
56
 
57
- self.max_seq_length = self.config.seq_length
58
  self.block_size_orig = self.config.block_size
59
 
60
  if self.config.dlm_paradigm == 'bidirectional':
@@ -62,9 +61,9 @@ class MinistralFlexAttention(Ministral3Attention):
62
  elif self.config.dlm_paradigm == 'autoregressive':
63
  self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
64
  elif self.config.dlm_paradigm == 'block_diff':
65
- self.block_diff_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size_orig)
66
  elif self.config.dlm_paradigm == 'sbd_block_diff':
67
- self.sbd_block_diff_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size_orig)
68
  else:
69
  raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
70
 
@@ -79,7 +78,7 @@ class MinistralFlexAttention(Ministral3Attention):
79
  self.mode = mode
80
  self.block_size = block_size
81
 
82
- def compute_block_mask(self, mode, q_len=None, block_size=None):
83
 
84
  def bidirectional_mask(b, h, q, kv):
85
  return (q >= kv) | (q < kv)
@@ -166,23 +165,15 @@ class MinistralFlexAttention(Ministral3Attention):
166
  attn_mask = autoregressive_mask
167
  elif mode == 'block_diff':
168
  assert block_size is not None
169
- attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
170
  elif mode == 'sbd_block_diff':
171
  assert block_size is not None
172
- attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
173
  else:
174
  raise ValueError(f"Unknown attention mode: {mode}")
175
 
176
- if q_len is not None:
177
- Q_LEN = q_len
178
- else:
179
- if mode in ['block_diff', 'sbd_block_diff']:
180
- Q_LEN = self.max_seq_length * 2
181
- else:
182
- Q_LEN = self.max_seq_length
183
-
184
  block_mask = create_block_mask(
185
- attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
186
  )
187
 
188
  return block_mask
@@ -238,24 +229,24 @@ class MinistralFlexAttention(Ministral3Attention):
238
  value_states = repeat_kv(value_states, self.num_key_value_groups)
239
 
240
  if self.mode == 'bidirectional':
241
- if q_len != self.bidirectional_mask.shape[-2]:
242
  block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
243
  else:
244
  block_mask = self.bidirectional_mask
245
 
246
  elif self.mode == 'autoregressive':
247
- if q_len != self.autoregressive_mask.shape[-2]:
248
  block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
249
  else:
250
  block_mask = self.autoregressive_mask
251
 
252
  elif self.mode == 'block_diff':
253
- if self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
254
  block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
255
  else:
256
  block_mask = self.block_diff_mask
257
  elif self.mode == 'sbd_block_diff':
258
- if self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
259
  block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
260
  else:
261
  block_mask = self.sbd_block_diff_mask
 
54
  def __init__(self, *args, **kwargs):
55
  super().__init__(*args, **kwargs)
56
 
 
57
  self.block_size_orig = self.config.block_size
58
 
59
  if self.config.dlm_paradigm == 'bidirectional':
 
61
  elif self.config.dlm_paradigm == 'autoregressive':
62
  self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
63
  elif self.config.dlm_paradigm == 'block_diff':
64
+ self.block_diff_mask = None
65
  elif self.config.dlm_paradigm == 'sbd_block_diff':
66
+ self.sbd_block_diff_mask = None
67
  else:
68
  raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
69
 
 
78
  self.mode = mode
79
  self.block_size = block_size
80
 
81
+ def compute_block_mask(self, mode, q_len, block_size=None):
82
 
83
  def bidirectional_mask(b, h, q, kv):
84
  return (q >= kv) | (q < kv)
 
165
  attn_mask = autoregressive_mask
166
  elif mode == 'block_diff':
167
  assert block_size is not None
168
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
169
  elif mode == 'sbd_block_diff':
170
  assert block_size is not None
171
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, q_len//2)
172
  else:
173
  raise ValueError(f"Unknown attention mode: {mode}")
174
 
 
 
 
 
 
 
 
 
175
  block_mask = create_block_mask(
176
+ attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
177
  )
178
 
179
  return block_mask
 
229
  value_states = repeat_kv(value_states, self.num_key_value_groups)
230
 
231
  if self.mode == 'bidirectional':
232
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
233
  block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
234
  else:
235
  block_mask = self.bidirectional_mask
236
 
237
  elif self.mode == 'autoregressive':
238
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
239
  block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
240
  else:
241
  block_mask = self.autoregressive_mask
242
 
243
  elif self.mode == 'block_diff':
244
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
245
  block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
246
  else:
247
  block_mask = self.block_diff_mask
248
  elif self.mode == 'sbd_block_diff':
249
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
250
  block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
251
  else:
252
  block_mask = self.sbd_block_diff_mask