YongganFu commited on
Commit
36a1c93
·
verified ·
1 Parent(s): a11afdb

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +20 -0
  2. modeling_ministral_dlm.py +407 -29
chat_utils.py CHANGED
@@ -113,6 +113,7 @@ def generate_with_prefix_cache_block_diff(
113
  shift_logits=False,
114
  neg_entropy=False,
115
  causal_context=False,
 
116
  ):
117
  dream_style=shift_logits
118
  # Initialize the accumulator
@@ -221,6 +222,16 @@ def generate_with_prefix_cache_block_diff(
221
  cur[transfer_idx] = x0[transfer_idx]
222
  x_accum[:, block_slice] = cur
223
 
 
 
 
 
 
 
 
 
 
 
224
  if causal_context:
225
  for layer in model_module.encoder.layers:
226
  if hasattr(layer.self_attn, 'diffusion_lm'):
@@ -244,4 +255,13 @@ def generate_with_prefix_cache_block_diff(
244
  # refresh context-next logit for the next block
245
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
246
 
 
 
 
 
 
 
 
 
 
247
  return x_accum, nfe
 
113
  shift_logits=False,
114
  neg_entropy=False,
115
  causal_context=False,
116
+ eos_token_id=None,
117
  ):
118
  dream_style=shift_logits
119
  # Initialize the accumulator
 
222
  cur[transfer_idx] = x0[transfer_idx]
223
  x_accum[:, block_slice] = cur
224
 
225
+ if eos_token_id is not None:
226
+ block_tokens = x_accum[:, block_slice] # (B, Lb)
227
+ eos_mask = (block_tokens == eos_token_id) # (B, Lb)
228
+ any_eos = eos_mask.any(dim=1) # (B,)
229
+ if any_eos.any():
230
+ after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
231
+ mask_before = (block_tokens == mask_id) & ~after_eos
232
+ if (any_eos & ~mask_before.any(dim=1)).any():
233
+ break
234
+
235
  if causal_context:
236
  for layer in model_module.encoder.layers:
237
  if hasattr(layer.self_attn, 'diffusion_lm'):
 
255
  # refresh context-next logit for the next block
256
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
257
 
258
+ if eos_token_id is not None:
259
+ gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
260
+ is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
261
+ has_eos = is_eos.any(dim=1) # (B,)
262
+ if has_eos.all():
263
+ first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
264
+ max_eos = first_eos_pos.max().item()
265
+ return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
266
+
267
  return x_accum, nfe
modeling_ministral_dlm.py CHANGED
@@ -13,7 +13,7 @@ from torch import nn
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
  from transformers.utils import ModelOutput
15
 
16
- from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
 
18
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
 
@@ -49,6 +49,43 @@ class MinistralDiffOutputWithPast(ModelOutput):
49
  def fused_flex_attention(q, k, v, block_mask=None):
50
  return flex_attention(q, k, v, block_mask=block_mask)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
53
  class MinistralFlexAttention(Ministral3Attention):
54
  def __init__(self, *args, **kwargs):
@@ -69,11 +106,47 @@ class MinistralFlexAttention(Ministral3Attention):
69
 
70
  self.block_size = self.block_size_orig
71
  self.mode = self.config.dlm_paradigm
 
72
 
73
  import torch._dynamo.config as dcfg
74
  dcfg.cache_size_limit = 512
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def set_attention_mode(self, mode, block_size=None):
78
  self.mode = mode
79
  self.block_size = block_size
@@ -225,40 +298,131 @@ class MinistralFlexAttention(Ministral3Attention):
225
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
226
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
227
 
228
- key_states = repeat_kv(key_states, self.num_key_value_groups)
229
- value_states = repeat_kv(value_states, self.num_key_value_groups)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- if self.mode == 'bidirectional':
232
- if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
233
- block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
234
- else:
235
- block_mask = self.bidirectional_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- elif self.mode == 'autoregressive':
238
- if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
239
- block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
240
- else:
241
- block_mask = self.autoregressive_mask
242
 
243
- elif self.mode == 'block_diff':
244
- if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
245
- block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
246
- else:
247
- block_mask = self.block_diff_mask
248
- elif self.mode == 'sbd_block_diff':
249
- if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
250
- block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
 
 
 
 
 
 
 
 
251
  else:
252
- block_mask = self.sbd_block_diff_mask
253
- else:
254
- raise ValueError(f"Unknown attention mode: {self.mode}")
255
 
256
- attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
257
- attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
258
 
259
- attn_output = self.o_proj(attn_output)
260
 
261
- return attn_output, None
262
 
263
 
264
  def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
@@ -713,7 +877,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
713
  )
714
 
715
 
716
- def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0):
717
  out_ids, nfe = generate_with_prefix_cache_block_diff(
718
  model=self,
719
  prompt=prompt_ids,
@@ -727,8 +891,222 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
727
  shift_logits=shift_logits,
728
  neg_entropy=False,
729
  causal_context=causal_context,
 
730
  )
731
 
732
  return out_ids, nfe
733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734
  __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
 
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
  from transformers.utils import ModelOutput
15
 
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
 
18
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
 
 
49
  def fused_flex_attention(q, k, v, block_mask=None):
50
  return flex_attention(q, k, v, block_mask=block_mask)
51
 
52
+
53
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
54
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
55
+ if hasattr(past_key_values, 'crop'):
56
+ past_key_values.crop(max_length)
57
+ else:
58
+ for layer_idx in range(len(past_key_values)):
59
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
60
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
61
+ past_key_values._seen_tokens = max_length
62
+
63
+
64
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
65
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
66
+ for layer_idx in range(len(past_key_values)):
67
+ if hasattr(past_key_values, 'layers'):
68
+ layer_cache = past_key_values.layers[layer_idx]
69
+ k, v = layer_cache.keys, layer_cache.values
70
+ else:
71
+ k = past_key_values.key_cache[layer_idx]
72
+ v = past_key_values.value_cache[layer_idx]
73
+
74
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
75
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
76
+ new_k = torch.cat([clean_k, draft_k], dim=2)
77
+ new_v = torch.cat([clean_v, draft_v], dim=2)
78
+
79
+ if hasattr(past_key_values, 'layers'):
80
+ layer_cache.keys = new_k
81
+ layer_cache.values = new_v
82
+ else:
83
+ past_key_values.key_cache[layer_idx] = new_k
84
+ past_key_values.value_cache[layer_idx] = new_v
85
+
86
+ past_key_values._seen_tokens = clean_len + block_length
87
+
88
+
89
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
90
  class MinistralFlexAttention(Ministral3Attention):
91
  def __init__(self, *args, **kwargs):
 
106
 
107
  self.block_size = self.block_size_orig
108
  self.mode = self.config.dlm_paradigm
109
+ self._quadratic_block_mask = {}
110
 
111
  import torch._dynamo.config as dcfg
112
  dcfg.cache_size_limit = 512
113
 
114
 
115
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
116
+ if block_length not in self._quadratic_block_mask:
117
+ draft_len = block_length * (block_length + 1)
118
+
119
+ def quadratic(b, h, q_idx, kv_idx):
120
+ first_clean = torch.logical_and(
121
+ kv_idx % (block_length + 1) == 0,
122
+ kv_idx < draft_len,
123
+ )
124
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
125
+ block_q = q_idx // (block_length + 1)
126
+ block_kv = kv_idx // (block_length + 1)
127
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
128
+ same_block_except_first = torch.logical_and(
129
+ same_block,
130
+ q_idx % (block_length + 1) != 0,
131
+ )
132
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
133
+ clean_part = kv_idx >= draft_len
134
+ return torch.logical_or(draft_part, clean_part)
135
+
136
+ block_mask = create_block_mask(
137
+ quadratic,
138
+ B=None,
139
+ H=None,
140
+ Q_LEN=draft_len,
141
+ KV_LEN=draft_len + self.config.max_position_embeddings,
142
+ device="cuda",
143
+ )
144
+
145
+ self._quadratic_block_mask[block_length] = block_mask
146
+
147
+ return self._quadratic_block_mask[block_length]
148
+
149
+
150
  def set_attention_mode(self, mode, block_size=None):
151
  self.mode = mode
152
  self.block_size = block_size
 
298
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
299
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
300
 
301
+ tidar_inference_mode = getattr(self.config, "tidar_inference_mode", None)
302
+ if tidar_inference_mode is not None:
303
+ if tidar_inference_mode == "quadratic":
304
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
305
+ if block_length is None:
306
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
307
+ if past_key_values is not None:
308
+ seq_len = key_states.shape[2]
309
+ draft_len = block_length * (block_length + 1)
310
+
311
+ clean_keys = key_states[:, :, :-draft_len]
312
+ draft_keys = key_states[:, :, -draft_len:]
313
+ clean_values = value_states[:, :, :-draft_len]
314
+ draft_values = value_states[:, :, -draft_len:]
315
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
316
+ value_states = torch.cat([draft_values, clean_values], dim=2)
317
+
318
+ block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
319
+ block_length=block_length
320
+ )
321
+ block_mask.seq_lengths = (draft_len, seq_len)
322
+ else:
323
+ seq_len = query_states.shape[2]
324
+ draft_len = block_length * (block_length + 1)
325
+ clean_len = seq_len - draft_len
326
+
327
+ def _causal_mask(b, h, q_idx, kv_idx):
328
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
329
+
330
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
331
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
332
+ first_clean = torch.logical_and(
333
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
334
+ )
335
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
336
+ return torch.logical_or(full_clean, first_clean)
337
+
338
+ def _draft_mask(b, h, q_idx, kv_idx):
339
+ block_q = (q_idx - clean_len) // (block_length + 1)
340
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
341
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
342
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
343
+ same_block_except_first = torch.logical_and(
344
+ same_block,
345
+ (q_idx - clean_len) % (block_length + 1) != 0,
346
+ )
347
+ return torch.logical_and(block_q == block_kv, same_block_except_first)
348
+
349
+ mask = or_masks(_causal_mask, _draft2clean_mask)
350
+ mask = or_masks(mask, _draft_mask)
351
+
352
+ block_mask = create_block_mask(
353
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
354
+ )
355
 
356
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
357
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
358
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
359
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
360
+ attn_output = self.o_proj(attn_output)
361
+ return attn_output, None
362
+
363
+ elif tidar_inference_mode == "default":
364
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
365
+ if block_length is None:
366
+ raise ValueError("SBD default decoding requires block_length in config.")
367
+ seq_len = query_states.shape[2]
368
+ prefix_len = seq_len - block_length
369
+
370
+ def _clean_q_mask(b, h, q_idx, kv_idx):
371
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
372
+
373
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
374
+ return q_idx >= prefix_len
375
+
376
+ block_mask = create_block_mask(
377
+ or_masks(_clean_q_mask, _noisy_q_mask),
378
+ B=None,
379
+ H=None,
380
+ Q_LEN=seq_len,
381
+ KV_LEN=seq_len,
382
+ )
383
+
384
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
385
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
386
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
387
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
388
+ attn_output = self.o_proj(attn_output)
389
+ return attn_output, None
390
+
391
+ else:
392
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
393
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
394
 
395
+ if self.mode == 'bidirectional':
396
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
397
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
398
+ else:
399
+ block_mask = self.bidirectional_mask
400
 
401
+ elif self.mode == 'autoregressive':
402
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
403
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
404
+ else:
405
+ block_mask = self.autoregressive_mask
406
+
407
+ elif self.mode == 'block_diff':
408
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
409
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
410
+ else:
411
+ block_mask = self.block_diff_mask
412
+ elif self.mode == 'sbd_block_diff':
413
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
414
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
415
+ else:
416
+ block_mask = self.sbd_block_diff_mask
417
  else:
418
+ raise ValueError(f"Unknown attention mode: {self.mode}")
 
 
419
 
420
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
421
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
422
 
423
+ attn_output = self.o_proj(attn_output)
424
 
425
+ return attn_output, None
426
 
427
 
428
  def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
 
877
  )
878
 
879
 
880
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None):
881
  out_ids, nfe = generate_with_prefix_cache_block_diff(
882
  model=self,
883
  prompt=prompt_ids,
 
891
  shift_logits=shift_logits,
892
  neg_entropy=False,
893
  causal_context=causal_context,
894
+ eos_token_id=eos_token_id,
895
  )
896
 
897
  return out_ids, nfe
898
 
899
+
900
+ @torch.no_grad()
901
+ def sbd_inference_diffusion_quadratic(
902
+ self,
903
+ clean_input_ids: Optional[torch.Tensor],
904
+ draft_input_ids: torch.Tensor,
905
+ block_length: int,
906
+ draft_only: bool = False,
907
+ past_key_values: Optional[Cache] = None,
908
+ use_cache: bool = False,
909
+ ):
910
+ """SBD quadratic inference (injected by build_hf_tidar_repo)."""
911
+ enc_config = self.encoder.config
912
+ enc_config.use_sbd_objective = True
913
+ enc_config.block_length = block_length
914
+
915
+ if draft_only:
916
+ assert clean_input_ids is not None
917
+
918
+ if use_cache and past_key_values is None:
919
+ past_key_values = DynamicCache()
920
+
921
+ enc_config.tidar_inference_mode = "default"
922
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
923
+ outputs = self.encoder(
924
+ input_ids=input_ids,
925
+ position_ids=None,
926
+ past_key_values=past_key_values,
927
+ use_cache=use_cache,
928
+ is_training=False,
929
+ )
930
+
931
+ hidden_states = outputs.last_hidden_state
932
+ logits = self.diffusion_head(hidden_states)
933
+
934
+ past_key_values = getattr(outputs, "past_key_values", None)
935
+ if use_cache and past_key_values is not None:
936
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
937
+
938
+ return logits, past_key_values
939
+ else:
940
+ enc_config.tidar_inference_mode = "quadratic"
941
+
942
+ draft_len = block_length * (block_length + 1)
943
+ draft_input_ids = torch.cat(
944
+ [
945
+ draft_input_ids.view(-1, block_length, 1),
946
+ torch.full(
947
+ (draft_input_ids.shape[0], block_length, block_length),
948
+ fill_value=self.config.mask_token_id,
949
+ device=draft_input_ids.device,
950
+ ),
951
+ ],
952
+ dim=-1,
953
+ ).view(-1, draft_len)
954
+
955
+ if use_cache:
956
+ assert past_key_values is not None, (
957
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
958
+ )
959
+ assert clean_input_ids is None, (
960
+ "Clean input ids should already be in cache, thus none should be provided."
961
+ )
962
+ clean_len = past_key_values.get_seq_length()
963
+ input_ids = draft_input_ids
964
+ else:
965
+ clean_len = clean_input_ids.shape[1]
966
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
967
+
968
+ per_block_position_ids = torch.arange(
969
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
970
+ )[None,].repeat(block_length, 1)
971
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
972
+
973
+ if use_cache:
974
+ position_ids = per_block_position_ids.view(-1)[None,]
975
+ else:
976
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
977
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
978
+
979
+ outputs = self.encoder(
980
+ input_ids=input_ids,
981
+ position_ids=position_ids,
982
+ past_key_values=past_key_values,
983
+ use_cache=use_cache,
984
+ is_training=False,
985
+ )
986
+
987
+ hidden_states = outputs.last_hidden_state
988
+ logits = self.diffusion_head(hidden_states)
989
+ past_key_values = getattr(outputs, "past_key_values", None)
990
+
991
+ if use_cache and past_key_values is not None:
992
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
993
+
994
+ return logits, past_key_values
995
+
996
+ @torch.no_grad()
997
+ def tidar_generate(
998
+ self,
999
+ prompt_ids: torch.Tensor,
1000
+ max_new_tokens: int = 128,
1001
+ steps: int = 128,
1002
+ block_length: int = 16,
1003
+ threshold: Optional[float] = None,
1004
+ temperature: float = 0.0,
1005
+ mask_token_id: Optional[int] = None,
1006
+ eos_token_id: Optional[int] = None,
1007
+ ):
1008
+ """TiDAR quadratic speculative decoding (injected by build_hf_tidar_repo)."""
1009
+ self.config.use_sbd_objective = True
1010
+ self.config.dlm_paradigm = "sbd"
1011
+
1012
+ if prompt_ids.shape[0] != 1:
1013
+ raise ValueError("TiDAR quadratic decoding currently requires batch_size == 1")
1014
+
1015
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1016
+ if eos_token_id is None:
1017
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1018
+
1019
+ x = torch.full(
1020
+ (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
1021
+ token_mask_id,
1022
+ dtype=torch.long,
1023
+ device=prompt_ids.device,
1024
+ )
1025
+ x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
1026
+
1027
+ if max_new_tokens % block_length != 0:
1028
+ raise ValueError("max_new_tokens must be divisible by block_length")
1029
+ num_blocks = max_new_tokens // block_length
1030
+ if steps % num_blocks != 0:
1031
+ raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
1032
+
1033
+ prompt_len = prompt_ids.shape[1]
1034
+ nfe = 0
1035
+ nfe += 1
1036
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1037
+ clean_input_ids=x[:, :prompt_len],
1038
+ draft_input_ids=x[:, prompt_len : prompt_len + block_length],
1039
+ block_length=block_length,
1040
+ draft_only=True,
1041
+ use_cache=True,
1042
+ )
1043
+
1044
+ logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
1045
+ logits_proposal[:, 1] = logits_proposal[:, 0]
1046
+ logits_proposal = logits_proposal[:, 1:]
1047
+ x0_proposal = torch.argmax(logits_proposal, dim=-1)
1048
+ x[:, prompt_len : prompt_len + block_length] = x0_proposal
1049
+
1050
+ total_accept_token = 0
1051
+ while True:
1052
+ nfe += 1
1053
+ block_start = prompt_len + total_accept_token
1054
+ block_end = block_start + block_length
1055
+ draft_input_ids = x[:, block_start:block_end]
1056
+
1057
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1058
+ clean_input_ids=None,
1059
+ draft_input_ids=draft_input_ids,
1060
+ block_length=block_length,
1061
+ draft_only=False,
1062
+ past_key_values=past_key_values,
1063
+ use_cache=True,
1064
+ )
1065
+
1066
+ useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1067
+ if threshold is None:
1068
+ useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1069
+ else:
1070
+ if not (0.0 <= threshold <= 1.0):
1071
+ raise ValueError("threshold must be between 0 and 1")
1072
+ mix_logits = useful_token_logits[:, :, 0] * threshold + useful_token_logits[:, :, 1] * (1 - threshold)
1073
+ useful_token_logits[:, :, 0] = mix_logits
1074
+ useful_token_logits[:, :, 1] = mix_logits
1075
+
1076
+ if temperature > 0:
1077
+ useful_token_logits = useful_token_logits / temperature
1078
+
1079
+ useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
1080
+ new_draft_input_ids = useful_token_pred[:, 0, 1:]
1081
+ accept_cnt = 1
1082
+
1083
+ while accept_cnt < block_length:
1084
+ if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
1085
+ break
1086
+ new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
1087
+ accept_cnt += 1
1088
+
1089
+ x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
1090
+
1091
+ # EoS early stopping: all accepted tokens are finalized left-to-right,
1092
+ # so if any is EoS we can truncate and return immediately.
1093
+ if eos_token_id is not None:
1094
+ accepted = x[0, block_start : block_start + accept_cnt]
1095
+ eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
1096
+ if len(eos_positions) > 0:
1097
+ first_eos_rel = eos_positions[0].item()
1098
+ total_accept_token += first_eos_rel + 1
1099
+ output_end = prompt_len + total_accept_token
1100
+ return x[:, :output_end], nfe
1101
+
1102
+ x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1103
+ past_key_values.crop(block_start + accept_cnt)
1104
+ total_accept_token += accept_cnt
1105
+
1106
+ if total_accept_token >= max_new_tokens:
1107
+ break
1108
+
1109
+ return x[:, : -(block_length * 2)], nfe
1110
+
1111
+
1112
  __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]