YongganFu commited on
Commit
f318bfe
·
verified ·
1 Parent(s): d42bc62

Upload model

Browse files
chat_utils.py CHANGED
@@ -113,10 +113,13 @@ def generate_with_prefix_cache_block_diff(
113
  shift_logits=False,
114
  neg_entropy=False,
115
  causal_context=False,
 
 
 
116
  ):
117
  dream_style=shift_logits
118
- # Initialize the accumulator
119
  x_accum = prompt.clone()
 
120
 
121
  assert gen_length % block_length == 0
122
  num_blocks = gen_length // block_length
@@ -141,30 +144,66 @@ def generate_with_prefix_cache_block_diff(
141
  if hasattr(layer.self_attn, 'diffusion_lm'):
142
  layer.self_attn.diffusion_lm=True
143
 
 
 
 
 
 
 
 
 
 
 
144
  # For dream_style: store the "next token logit" of the context
145
  next_logits_context = None
146
  if dream_style:
147
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
148
 
149
  for num_block in range(num_blocks):
150
- # Create a new block with mask tokens (no seeding)
 
 
151
  mask_block = torch.ones(
152
  (prompt.shape[0], block_length),
153
  dtype=prompt.dtype,
154
  device=prompt.device
155
  ) * mask_id
 
 
156
 
157
  # Append the block of masks
158
  x_accum = torch.cat([x_accum, mask_block], dim=1)
159
  current_block_start = prompt.size(1) + num_block * block_length
160
  block_slice = slice(current_block_start, current_block_start + block_length)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  # Build the initial mask for this block
163
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
164
 
165
  # Precompute the transfer schedule for this block
166
  if dream_style:
167
- # still denoise *all* positions (0..Lb-1), since none are seeded
168
  schedule_mask = mask_block_idx0
169
  else:
170
  schedule_mask = mask_block_idx0
@@ -221,6 +260,16 @@ def generate_with_prefix_cache_block_diff(
221
  cur[transfer_idx] = x0[transfer_idx]
222
  x_accum[:, block_slice] = cur
223
 
 
 
 
 
 
 
 
 
 
 
224
  if causal_context:
225
  for layer in model_module.encoder.layers:
226
  if hasattr(layer.self_attn, 'diffusion_lm'):
@@ -234,14 +283,31 @@ def generate_with_prefix_cache_block_diff(
234
  use_causal_mask=causal_context
235
  )
236
  past_key_values = output.past_key_values
 
237
 
238
  if causal_context:
239
  for layer in model_module.encoder.layers:
240
  if hasattr(layer.self_attn, 'diffusion_lm'):
241
  layer.self_attn.diffusion_lm=True
 
 
 
 
 
 
 
242
 
243
  if dream_style and num_block < num_blocks - 1:
244
  # refresh context-next logit for the next block
245
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
246
 
 
 
 
 
 
 
 
 
 
247
  return x_accum, nfe
 
113
  shift_logits=False,
114
  neg_entropy=False,
115
  causal_context=False,
116
+ eos_token_id=None,
117
+ max_thinking_tokens=None,
118
+ end_think_token_id=None,
119
  ):
120
  dream_style=shift_logits
 
121
  x_accum = prompt.clone()
122
+ B = prompt.shape[0]
123
 
124
  assert gen_length % block_length == 0
125
  num_blocks = gen_length // block_length
 
144
  if hasattr(layer.self_attn, 'diffusion_lm'):
145
  layer.self_attn.diffusion_lm=True
146
 
147
+ # Causal prefill: next token from last position (same as linear_spec_generate).
148
+ next_token = None
149
+ if causal_context:
150
+ last_logit = output.logits[:, -1, :]
151
+ if temperature > 0:
152
+ probs = torch.softmax(last_logit / temperature, dim=-1)
153
+ next_token = torch.multinomial(probs, num_samples=1)
154
+ else:
155
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
156
+
157
  # For dream_style: store the "next token logit" of the context
158
  next_logits_context = None
159
  if dream_style:
160
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
161
 
162
  for num_block in range(num_blocks):
163
+ # Create a new block with mask tokens; under causal context, seed position 0
164
+ # with the next-token prediction from the previous causal forward (prefill or
165
+ # post-block encode), matching linear_spec_generate.
166
  mask_block = torch.ones(
167
  (prompt.shape[0], block_length),
168
  dtype=prompt.dtype,
169
  device=prompt.device
170
  ) * mask_id
171
+ if causal_context:
172
+ mask_block[:, 0] = next_token[:, 0]
173
 
174
  # Append the block of masks
175
  x_accum = torch.cat([x_accum, mask_block], dim=1)
176
  current_block_start = prompt.size(1) + num_block * block_length
177
  block_slice = slice(current_block_start, current_block_start + block_length)
178
 
179
+ # ---- thinking budget enforcement ----
180
+ # If we've generated >= max_thinking_tokens without a </think>, inject one.
181
+ if end_think_token_id is not None and max_thinking_tokens is not None:
182
+ tokens_before_block = num_block * block_length
183
+ tokens_after_block = tokens_before_block + block_length
184
+ if tokens_after_block > max_thinking_tokens:
185
+ gen_so_far = x_accum[:, prompt.size(1):current_block_start]
186
+ has_end_think = (
187
+ (gen_so_far == end_think_token_id).any(dim=1)
188
+ if gen_so_far.size(1) > 0
189
+ else torch.zeros(B, dtype=torch.bool, device=prompt.device)
190
+ )
191
+ if not has_end_think.all():
192
+ if tokens_before_block < max_thinking_tokens:
193
+ offset = max_thinking_tokens - tokens_before_block
194
+ else:
195
+ offset = 0
196
+ inject_pos = current_block_start + offset
197
+ for b in range(B):
198
+ if not has_end_think[b]:
199
+ x_accum[b, inject_pos] = end_think_token_id
200
+
201
  # Build the initial mask for this block
202
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
203
 
204
  # Precompute the transfer schedule for this block
205
  if dream_style:
206
+ # masked positions only (position 0 may be causal-seeded, not mask_id)
207
  schedule_mask = mask_block_idx0
208
  else:
209
  schedule_mask = mask_block_idx0
 
260
  cur[transfer_idx] = x0[transfer_idx]
261
  x_accum[:, block_slice] = cur
262
 
263
+ if eos_token_id is not None:
264
+ block_tokens = x_accum[:, block_slice] # (B, Lb)
265
+ eos_mask = (block_tokens == eos_token_id) # (B, Lb)
266
+ any_eos = eos_mask.any(dim=1) # (B,)
267
+ if any_eos.any():
268
+ after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
269
+ mask_before = (block_tokens == mask_id) & ~after_eos
270
+ if (any_eos & ~mask_before.any(dim=1)).any():
271
+ break
272
+
273
  if causal_context:
274
  for layer in model_module.encoder.layers:
275
  if hasattr(layer.self_attn, 'diffusion_lm'):
 
283
  use_causal_mask=causal_context
284
  )
285
  past_key_values = output.past_key_values
286
+ nfe += 1
287
 
288
  if causal_context:
289
  for layer in model_module.encoder.layers:
290
  if hasattr(layer.self_attn, 'diffusion_lm'):
291
  layer.self_attn.diffusion_lm=True
292
+ # Next block's first position = greedy/sampled next token from this causal encode
293
+ last_logit = output.logits[:, -1, :]
294
+ if temperature > 0:
295
+ probs = torch.softmax(last_logit / temperature, dim=-1)
296
+ next_token = torch.multinomial(probs, num_samples=1)
297
+ else:
298
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
299
 
300
  if dream_style and num_block < num_blocks - 1:
301
  # refresh context-next logit for the next block
302
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
303
 
304
+ if eos_token_id is not None:
305
+ gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
306
+ is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
307
+ has_eos = is_eos.any(dim=1) # (B,)
308
+ if has_eos.all():
309
+ first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
310
+ max_eos = first_eos_pos.max().item()
311
+ return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
312
+
313
  return x_accum, nfe
config.json CHANGED
@@ -22,6 +22,7 @@
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
 
25
  "enforce_mask": false,
26
  "eos_token_id": 2,
27
  "global_loss_avg": false,
 
22
  "dlm_paradigm": "bidirectional",
23
  "dlm_type": "llada",
24
  "dp_varying_mask_ratio": false,
25
+ "enable_self_spec": false,
26
  "enforce_mask": false,
27
  "eos_token_id": 2,
28
  "global_loss_avg": false,
configuration_ministral_dlm.py CHANGED
@@ -112,6 +112,9 @@ class MinistralDLMConfig(PretrainedConfig):
112
  Adaptive permutation ratio for each block.
113
  ada_perm_ratio_global (`float`, *optional*):
114
  Adaptive permutation ratio for global.
 
 
 
115
  """
116
 
117
  model_type = "ministral_dlm"
@@ -181,6 +184,7 @@ class MinistralDLMConfig(PretrainedConfig):
181
  ada_perm_ratio_per_block=None,
182
  ada_perm_ratio_global=None,
183
  ada_dlm_loss_ratio=None,
 
184
  **kwargs,
185
  ):
186
  self.vocab_size = vocab_size
@@ -234,6 +238,7 @@ class MinistralDLMConfig(PretrainedConfig):
234
  self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
235
  self.ada_perm_ratio_global = ada_perm_ratio_global
236
  self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
 
237
  super().__init__(
238
  pad_token_id=pad_token_id,
239
  bos_token_id=bos_token_id,
 
112
  Adaptive permutation ratio for each block.
113
  ada_perm_ratio_global (`float`, *optional*):
114
  Adaptive permutation ratio for global.
115
+ enable_self_spec (`bool`, *optional*, defaults to `False`):
116
+ Force MinistralFlexAttention for all paradigms (including bidirectional/autoregressive).
117
+ Required for self speculative generation; leave False for standard eval to use faster SDPA kernels.
118
  """
119
 
120
  model_type = "ministral_dlm"
 
184
  ada_perm_ratio_per_block=None,
185
  ada_perm_ratio_global=None,
186
  ada_dlm_loss_ratio=None,
187
+ enable_self_spec=False,
188
  **kwargs,
189
  ):
190
  self.vocab_size = vocab_size
 
238
  self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
239
  self.ada_perm_ratio_global = ada_perm_ratio_global
240
  self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
241
+ self.enable_self_spec = enable_self_spec
242
  super().__init__(
243
  pad_token_id=pad_token_id,
244
  bos_token_id=bos_token_id,
modeling_ministral_dlm.py CHANGED
@@ -13,7 +13,7 @@ from torch import nn
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
  from transformers.utils import ModelOutput
15
 
16
- from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
 
18
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
 
@@ -31,6 +31,7 @@ from .chat_utils import generate_with_prefix_cache_block_diff
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
 
34
 
35
  @dataclass
36
  class MinistralDiffOutputWithPast(ModelOutput):
@@ -49,11 +50,49 @@ class MinistralDiffOutputWithPast(ModelOutput):
49
  def fused_flex_attention(q, k, v, block_mask=None):
50
  return flex_attention(q, k, v, block_mask=block_mask)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
53
  class MinistralFlexAttention(Ministral3Attention):
54
  def __init__(self, *args, **kwargs):
55
  super().__init__(*args, **kwargs)
56
-
 
57
  self.block_size_orig = self.config.block_size
58
 
59
  if self.config.dlm_paradigm == 'bidirectional':
@@ -69,40 +108,60 @@ class MinistralFlexAttention(Ministral3Attention):
69
 
70
  self.block_size = self.block_size_orig
71
  self.mode = self.config.dlm_paradigm
 
72
 
73
  import torch._dynamo.config as dcfg
74
  dcfg.cache_size_limit = 512
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def set_attention_mode(self, mode, block_size=None):
78
  self.mode = mode
79
  self.block_size = block_size
80
 
81
- def compute_block_mask(self, mode, q_len, block_size=None):
82
 
83
  def bidirectional_mask(b, h, q, kv):
84
  return (q >= kv) | (q < kv)
85
 
86
  def autoregressive_mask(b, h, q, kv):
87
  return (q >= kv)
88
-
89
- def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
90
- """
91
- Constructs the specialized block diffusion attention mask for training
92
- composed of three masks:
93
- - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
94
- - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
95
- - **Block Causal Mask (M_BC)**: Attention to update x0
96
- Args:
97
- b, h: Batch and head indices (ignored for mask logic).
98
- q_idx, kv_idx: Query and Key indices.
99
- seq_len: Total sequence length.
100
- block_size: Defines the block structure.
101
- Returns:
102
- A boolean attention mask.
103
- """
104
 
105
- # Indicate whether token belongs to xt or x0
106
  x0_flag_q = (q_idx >= n)
107
  x0_flag_kv = (kv_idx >= n)
108
 
@@ -165,15 +224,23 @@ class MinistralFlexAttention(Ministral3Attention):
165
  attn_mask = autoregressive_mask
166
  elif mode == 'block_diff':
167
  assert block_size is not None
168
- attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
169
  elif mode == 'sbd_block_diff':
170
  assert block_size is not None
171
- attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, q_len//2)
172
  else:
173
  raise ValueError(f"Unknown attention mode: {mode}")
174
 
 
 
 
 
 
 
 
 
175
  block_mask = create_block_mask(
176
- attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
177
  )
178
 
179
  return block_mask
@@ -225,40 +292,131 @@ class MinistralFlexAttention(Ministral3Attention):
225
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
226
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
227
 
228
- key_states = repeat_kv(key_states, self.num_key_value_groups)
229
- value_states = repeat_kv(value_states, self.num_key_value_groups)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- if self.mode == 'bidirectional':
232
- if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
233
- block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
234
- else:
235
- block_mask = self.bidirectional_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- elif self.mode == 'autoregressive':
238
- if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
239
- block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
240
- else:
241
- block_mask = self.autoregressive_mask
242
 
243
- elif self.mode == 'block_diff':
244
- if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
245
- block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
246
- else:
247
- block_mask = self.block_diff_mask
248
- elif self.mode == 'sbd_block_diff':
249
- if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
250
- block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
 
 
 
 
 
 
 
 
251
  else:
252
- block_mask = self.sbd_block_diff_mask
253
- else:
254
- raise ValueError(f"Unknown attention mode: {self.mode}")
255
 
256
- attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
257
- attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
258
 
259
- attn_output = self.o_proj(attn_output)
260
 
261
- return attn_output, None
262
 
263
 
264
  def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
@@ -285,11 +443,12 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
285
  diffusion_config = copy.deepcopy(config)
286
  diffusion_config.diffusion_lm = True
287
 
 
 
288
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
289
  diffusion_config.attn_class = MinistralFlexAttention
290
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
291
- diffusion_config.attn_class = Ministral3Attention
292
-
293
  if config.dlm_paradigm == 'autoregressive':
294
  diffusion_config.diffusion_lm = False
295
  else:
@@ -713,7 +872,10 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
713
  )
714
 
715
 
716
- def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0):
 
 
 
717
  out_ids, nfe = generate_with_prefix_cache_block_diff(
718
  model=self,
719
  prompt=prompt_ids,
@@ -727,8 +889,956 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
727
  shift_logits=shift_logits,
728
  neg_entropy=False,
729
  causal_context=causal_context,
 
 
 
730
  )
731
 
732
  return out_ids, nfe
733
 
734
- __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
  from transformers.utils import ModelOutput
15
 
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
 
18
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
 
 
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
35
 
36
  @dataclass
37
  class MinistralDiffOutputWithPast(ModelOutput):
 
50
  def fused_flex_attention(q, k, v, block_mask=None):
51
  return flex_attention(q, k, v, block_mask=block_mask)
52
 
53
+
54
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
55
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
56
+ if hasattr(past_key_values, 'crop'):
57
+ past_key_values.crop(max_length)
58
+ else:
59
+ for layer_idx in range(len(past_key_values)):
60
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
61
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
62
+ past_key_values._seen_tokens = max_length
63
+
64
+
65
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
66
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
67
+ for layer_idx in range(len(past_key_values)):
68
+ if hasattr(past_key_values, 'layers'):
69
+ layer_cache = past_key_values.layers[layer_idx]
70
+ k, v = layer_cache.keys, layer_cache.values
71
+ else:
72
+ k = past_key_values.key_cache[layer_idx]
73
+ v = past_key_values.value_cache[layer_idx]
74
+
75
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
76
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
77
+ new_k = torch.cat([clean_k, draft_k], dim=2)
78
+ new_v = torch.cat([clean_v, draft_v], dim=2)
79
+
80
+ if hasattr(past_key_values, 'layers'):
81
+ layer_cache.keys = new_k
82
+ layer_cache.values = new_v
83
+ else:
84
+ past_key_values.key_cache[layer_idx] = new_k
85
+ past_key_values.value_cache[layer_idx] = new_v
86
+
87
+ past_key_values._seen_tokens = clean_len + block_length
88
+
89
+
90
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
91
  class MinistralFlexAttention(Ministral3Attention):
92
  def __init__(self, *args, **kwargs):
93
  super().__init__(*args, **kwargs)
94
+
95
+ self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
96
  self.block_size_orig = self.config.block_size
97
 
98
  if self.config.dlm_paradigm == 'bidirectional':
 
108
 
109
  self.block_size = self.block_size_orig
110
  self.mode = self.config.dlm_paradigm
111
+ self._quadratic_block_mask = {}
112
 
113
  import torch._dynamo.config as dcfg
114
  dcfg.cache_size_limit = 512
115
 
116
 
117
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
118
+ if block_length not in self._quadratic_block_mask:
119
+ draft_len = block_length * (block_length + 1)
120
+
121
+ def quadratic(b, h, q_idx, kv_idx):
122
+ first_clean = torch.logical_and(
123
+ kv_idx % (block_length + 1) == 0,
124
+ kv_idx < draft_len,
125
+ )
126
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
127
+ block_q = q_idx // (block_length + 1)
128
+ block_kv = kv_idx // (block_length + 1)
129
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
130
+ same_block_except_first = torch.logical_and(
131
+ same_block,
132
+ q_idx % (block_length + 1) != 0,
133
+ )
134
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
135
+ clean_part = kv_idx >= draft_len
136
+ return torch.logical_or(draft_part, clean_part)
137
+
138
+ block_mask = create_block_mask(
139
+ quadratic,
140
+ B=None,
141
+ H=None,
142
+ Q_LEN=draft_len,
143
+ KV_LEN=draft_len + self.config.max_position_embeddings,
144
+ device="cuda",
145
+ )
146
+
147
+ self._quadratic_block_mask[block_length] = block_mask
148
+
149
+ return self._quadratic_block_mask[block_length]
150
+
151
+
152
  def set_attention_mode(self, mode, block_size=None):
153
  self.mode = mode
154
  self.block_size = block_size
155
 
156
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
157
 
158
  def bidirectional_mask(b, h, q, kv):
159
  return (q >= kv) | (q < kv)
160
 
161
  def autoregressive_mask(b, h, q, kv):
162
  return (q >= kv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
165
  x0_flag_q = (q_idx >= n)
166
  x0_flag_kv = (kv_idx >= n)
167
 
 
224
  attn_mask = autoregressive_mask
225
  elif mode == 'block_diff':
226
  assert block_size is not None
227
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
228
  elif mode == 'sbd_block_diff':
229
  assert block_size is not None
230
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
231
  else:
232
  raise ValueError(f"Unknown attention mode: {mode}")
233
 
234
+ if q_len is not None:
235
+ Q_LEN = q_len
236
+ else:
237
+ if mode in ['block_diff', 'sbd_block_diff']:
238
+ Q_LEN = self.max_seq_length * 2
239
+ else:
240
+ Q_LEN = self.max_seq_length
241
+
242
  block_mask = create_block_mask(
243
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
244
  )
245
 
246
  return block_mask
 
292
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
293
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
294
 
295
+ self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
296
+ if self_spec_inference_mode is not None:
297
+ if self_spec_inference_mode == "quadratic":
298
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
299
+ if block_length is None:
300
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
301
+ if past_key_values is not None:
302
+ seq_len = key_states.shape[2]
303
+ draft_len = block_length * (block_length + 1)
304
+
305
+ clean_keys = key_states[:, :, :-draft_len]
306
+ draft_keys = key_states[:, :, -draft_len:]
307
+ clean_values = value_states[:, :, :-draft_len]
308
+ draft_values = value_states[:, :, -draft_len:]
309
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
310
+ value_states = torch.cat([draft_values, clean_values], dim=2)
311
+
312
+ block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
313
+ block_length=block_length
314
+ )
315
+ block_mask.seq_lengths = (draft_len, seq_len)
316
+ else:
317
+ seq_len = query_states.shape[2]
318
+ draft_len = block_length * (block_length + 1)
319
+ clean_len = seq_len - draft_len
320
+
321
+ def _causal_mask(b, h, q_idx, kv_idx):
322
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
323
+
324
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
325
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
326
+ first_clean = torch.logical_and(
327
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
328
+ )
329
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
330
+ return torch.logical_or(full_clean, first_clean)
331
+
332
+ def _draft_mask(b, h, q_idx, kv_idx):
333
+ block_q = (q_idx - clean_len) // (block_length + 1)
334
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
335
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
336
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
337
+ same_block_except_first = torch.logical_and(
338
+ same_block,
339
+ (q_idx - clean_len) % (block_length + 1) != 0,
340
+ )
341
+ return torch.logical_and(block_q == block_kv, same_block_except_first)
342
+
343
+ mask = or_masks(_causal_mask, _draft2clean_mask)
344
+ mask = or_masks(mask, _draft_mask)
345
+
346
+ block_mask = create_block_mask(
347
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
348
+ )
349
 
350
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
351
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
352
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
353
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
354
+ attn_output = self.o_proj(attn_output)
355
+ return attn_output, None
356
+
357
+ elif self_spec_inference_mode == "default":
358
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
359
+ if block_length is None:
360
+ raise ValueError("SBD default decoding requires block_length in config.")
361
+ seq_len = query_states.shape[2]
362
+ prefix_len = seq_len - block_length
363
+
364
+ def _clean_q_mask(b, h, q_idx, kv_idx):
365
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
366
+
367
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
368
+ return q_idx >= prefix_len
369
+
370
+ block_mask = create_block_mask(
371
+ or_masks(_clean_q_mask, _noisy_q_mask),
372
+ B=None,
373
+ H=None,
374
+ Q_LEN=seq_len,
375
+ KV_LEN=seq_len,
376
+ )
377
+
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
381
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
382
+ attn_output = self.o_proj(attn_output)
383
+ return attn_output, None
384
+
385
+ else:
386
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
387
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
388
 
389
+ if self.mode == 'bidirectional':
390
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
391
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
392
+ else:
393
+ block_mask = self.bidirectional_mask
394
 
395
+ elif self.mode == 'autoregressive':
396
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
397
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
398
+ else:
399
+ block_mask = self.autoregressive_mask
400
+
401
+ elif self.mode == 'block_diff':
402
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
403
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
404
+ else:
405
+ block_mask = self.block_diff_mask
406
+ elif self.mode == 'sbd_block_diff':
407
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
408
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
409
+ else:
410
+ block_mask = self.sbd_block_diff_mask
411
  else:
412
+ raise ValueError(f"Unknown attention mode: {self.mode}")
 
 
413
 
414
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
415
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
416
 
417
+ attn_output = self.o_proj(attn_output)
418
 
419
+ return attn_output, None
420
 
421
 
422
  def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
 
443
  diffusion_config = copy.deepcopy(config)
444
  diffusion_config.diffusion_lm = True
445
 
446
+ use_flex = getattr(config, 'enable_self_spec', False)
447
+
448
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
449
  diffusion_config.attn_class = MinistralFlexAttention
450
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
451
+ diffusion_config.attn_class = MinistralFlexAttention if use_flex else Ministral3Attention
 
452
  if config.dlm_paradigm == 'autoregressive':
453
  diffusion_config.diffusion_lm = False
454
  else:
 
872
  )
873
 
874
 
875
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None, max_thinking_tokens=None, end_think_token_id=None):
876
+ if eos_token_id is None:
877
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
878
+
879
  out_ids, nfe = generate_with_prefix_cache_block_diff(
880
  model=self,
881
  prompt=prompt_ids,
 
889
  shift_logits=shift_logits,
890
  neg_entropy=False,
891
  causal_context=causal_context,
892
+ eos_token_id=eos_token_id,
893
+ max_thinking_tokens=max_thinking_tokens,
894
+ end_think_token_id=end_think_token_id,
895
  )
896
 
897
  return out_ids, nfe
898
 
899
+
900
+ @torch.no_grad()
901
+ def sbd_inference_diffusion_quadratic(
902
+ self,
903
+ clean_input_ids: Optional[torch.Tensor],
904
+ draft_input_ids: torch.Tensor,
905
+ block_length: int,
906
+ draft_only: bool = False,
907
+ past_key_values: Optional[Cache] = None,
908
+ use_cache: bool = False,
909
+ ):
910
+ enc_config = self.encoder.config
911
+ enc_config.use_sbd_objective = True
912
+ enc_config.block_length = block_length
913
+
914
+ if draft_only:
915
+ assert clean_input_ids is not None
916
+
917
+ if use_cache and past_key_values is None:
918
+ past_key_values = DynamicCache()
919
+
920
+ enc_config.self_spec_inference_mode = "default"
921
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
922
+ outputs = self.encoder(
923
+ input_ids=input_ids,
924
+ position_ids=None,
925
+ past_key_values=past_key_values,
926
+ use_cache=use_cache,
927
+ is_training=False,
928
+ )
929
+
930
+ hidden_states = outputs.last_hidden_state
931
+ logits = self.diffusion_head(hidden_states)
932
+
933
+ past_key_values = getattr(outputs, "past_key_values", None)
934
+ if use_cache and past_key_values is not None:
935
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
936
+
937
+ return logits, past_key_values
938
+ else:
939
+ enc_config.self_spec_inference_mode = "quadratic"
940
+
941
+ draft_len = block_length * (block_length + 1)
942
+ draft_input_ids = torch.cat(
943
+ [
944
+ draft_input_ids.view(-1, block_length, 1),
945
+ torch.full(
946
+ (draft_input_ids.shape[0], block_length, block_length),
947
+ fill_value=self.config.mask_token_id,
948
+ device=draft_input_ids.device,
949
+ ),
950
+ ],
951
+ dim=-1,
952
+ ).view(-1, draft_len)
953
+
954
+ if use_cache:
955
+ assert past_key_values is not None, (
956
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
957
+ )
958
+ assert clean_input_ids is None, (
959
+ "Clean input ids should already be in cache, thus none should be provided."
960
+ )
961
+ clean_len = past_key_values.get_seq_length()
962
+ input_ids = draft_input_ids
963
+ else:
964
+ clean_len = clean_input_ids.shape[1]
965
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
966
+
967
+ per_block_position_ids = torch.arange(
968
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
969
+ )[None,].repeat(block_length, 1)
970
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
971
+
972
+ if use_cache:
973
+ position_ids = per_block_position_ids.view(-1)[None,]
974
+ else:
975
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
976
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
977
+
978
+ outputs = self.encoder(
979
+ input_ids=input_ids,
980
+ position_ids=position_ids,
981
+ past_key_values=past_key_values,
982
+ use_cache=use_cache,
983
+ is_training=False,
984
+ )
985
+
986
+ hidden_states = outputs.last_hidden_state
987
+ logits = self.diffusion_head(hidden_states)
988
+ past_key_values = getattr(outputs, "past_key_values", None)
989
+
990
+ if use_cache and past_key_values is not None:
991
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
992
+
993
+ return logits, past_key_values
994
+
995
+
996
+ @torch.no_grad()
997
+ def ar_generate(
998
+ self,
999
+ prompt_ids: torch.Tensor,
1000
+ max_new_tokens: int = 128,
1001
+ temperature: float = 0.0,
1002
+ eos_token_id: Optional[int] = None,
1003
+ max_thinking_tokens: Optional[int] = None,
1004
+ end_think_token_id: Optional[int] = None,
1005
+ ) -> tuple:
1006
+ """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
1007
+
1008
+ Bypasses MinistralDiffEncoderModel.forward() to avoid diffusion-specific
1009
+ code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
1010
+ position_ids, and use_cache so the KV cache and causal masking behave
1011
+ identically to MistralForCausalLM / vLLM.
1012
+
1013
+ Returns:
1014
+ (output_ids, nfe) where output_ids includes the prompt.
1015
+ """
1016
+ for layer in self.encoder.layers:
1017
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1018
+ layer.self_attn.diffusion_lm = False
1019
+
1020
+ if eos_token_id is None:
1021
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
1022
+
1023
+ device = prompt_ids.device
1024
+ batch_size, prompt_len = prompt_ids.shape
1025
+
1026
+ past_key_values = DynamicCache()
1027
+ cache_position = torch.arange(prompt_len, device=device)
1028
+ position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
1029
+
1030
+ enc_out = self.encoder(
1031
+ input_ids=prompt_ids,
1032
+ position_ids=position_ids,
1033
+ past_key_values=past_key_values,
1034
+ use_cache=True,
1035
+ cache_position=cache_position,
1036
+ )
1037
+ past_key_values = enc_out.past_key_values
1038
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1039
+
1040
+ generated_tokens = []
1041
+ nfe = 0
1042
+
1043
+ for step in range(max_new_tokens):
1044
+ nfe += 1
1045
+
1046
+ if temperature > 0:
1047
+ probs = torch.softmax(next_logit / temperature, dim=-1)
1048
+ next_token = torch.multinomial(probs, num_samples=1)
1049
+ else:
1050
+ next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
1051
+
1052
+ # ---- thinking budget enforcement ----
1053
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1054
+ if step >= max_thinking_tokens:
1055
+ if generated_tokens:
1056
+ gen_tensor = torch.cat(generated_tokens, dim=1)
1057
+ has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
1058
+ else:
1059
+ has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
1060
+ for b in range(batch_size):
1061
+ if not has_end_think[b]:
1062
+ next_token[b] = end_think_token_id
1063
+
1064
+ generated_tokens.append(next_token)
1065
+
1066
+ if eos_token_id is not None and (next_token == eos_token_id).all():
1067
+ break
1068
+
1069
+ if step < max_new_tokens - 1:
1070
+ cur_pos = prompt_len + step
1071
+ step_cache_pos = torch.tensor([cur_pos], device=device)
1072
+ step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
1073
+
1074
+ enc_out = self.encoder(
1075
+ input_ids=next_token,
1076
+ position_ids=step_pos_ids,
1077
+ past_key_values=past_key_values,
1078
+ use_cache=True,
1079
+ cache_position=step_cache_pos,
1080
+ )
1081
+ past_key_values = enc_out.past_key_values
1082
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1083
+
1084
+ all_generated = torch.cat(generated_tokens, dim=1)
1085
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1086
+ return output_ids, nfe
1087
+
1088
+
1089
+ @torch.no_grad()
1090
+ def self_spec_generate(
1091
+ self,
1092
+ prompt_ids: torch.Tensor,
1093
+ max_new_tokens: int = 128,
1094
+ steps: int = 128,
1095
+ block_length: int = 16,
1096
+ ar_mix_weight: Optional[float] = None,
1097
+ temperature: float = 0.0,
1098
+ mask_token_id: Optional[int] = None,
1099
+ eos_token_id: Optional[int] = None,
1100
+ max_thinking_tokens: Optional[int] = None,
1101
+ end_think_token_id: Optional[int] = None,
1102
+ ):
1103
+ self.config.use_sbd_objective = True
1104
+ self.config.dlm_paradigm = "sbd"
1105
+
1106
+ if prompt_ids.shape[0] != 1:
1107
+ raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
1108
+
1109
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1110
+ if eos_token_id is None:
1111
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1112
+
1113
+ x = torch.full(
1114
+ (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
1115
+ token_mask_id,
1116
+ dtype=torch.long,
1117
+ device=prompt_ids.device,
1118
+ )
1119
+ x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
1120
+
1121
+ if max_new_tokens % block_length != 0:
1122
+ raise ValueError("max_new_tokens must be divisible by block_length")
1123
+ num_blocks = max_new_tokens // block_length
1124
+ if steps % num_blocks != 0:
1125
+ raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
1126
+
1127
+ prompt_len = prompt_ids.shape[1]
1128
+ nfe = 0
1129
+ nfe += 1
1130
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1131
+ clean_input_ids=x[:, :prompt_len],
1132
+ draft_input_ids=x[:, prompt_len : prompt_len + block_length],
1133
+ block_length=block_length,
1134
+ draft_only=True,
1135
+ use_cache=True,
1136
+ )
1137
+
1138
+ logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
1139
+ logits_proposal[:, 1] = logits_proposal[:, 0]
1140
+ logits_proposal = logits_proposal[:, 1:]
1141
+ x0_proposal = torch.argmax(logits_proposal, dim=-1)
1142
+ x[:, prompt_len : prompt_len + block_length] = x0_proposal
1143
+
1144
+ total_accept_token = 0
1145
+ while True:
1146
+ nfe += 1
1147
+ block_start = prompt_len + total_accept_token
1148
+ block_end = block_start + block_length
1149
+ draft_input_ids = x[:, block_start:block_end]
1150
+
1151
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1152
+ clean_input_ids=None,
1153
+ draft_input_ids=draft_input_ids,
1154
+ block_length=block_length,
1155
+ draft_only=False,
1156
+ past_key_values=past_key_values,
1157
+ use_cache=True,
1158
+ )
1159
+
1160
+ useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1161
+ if ar_mix_weight is None:
1162
+ useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1163
+ else:
1164
+ if not (0.0 <= ar_mix_weight <= 1.0):
1165
+ raise ValueError("ar_mix_weight must be between 0 and 1")
1166
+ mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
1167
+ useful_token_logits[:, :, 0] = mix_logits
1168
+ useful_token_logits[:, :, 1] = mix_logits
1169
+
1170
+ if temperature > 0:
1171
+ useful_token_logits = useful_token_logits / temperature
1172
+
1173
+ useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
1174
+ new_draft_input_ids = useful_token_pred[:, 0, 1:]
1175
+ accept_cnt = 1
1176
+
1177
+ while accept_cnt < block_length:
1178
+ if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
1179
+ break
1180
+ new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
1181
+ accept_cnt += 1
1182
+
1183
+ x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
1184
+
1185
+ # EoS early stopping: all accepted tokens are finalized left-to-right,
1186
+ # so if any is EoS we can truncate and return immediately.
1187
+ if eos_token_id is not None:
1188
+ accepted = x[0, block_start : block_start + accept_cnt]
1189
+ eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
1190
+ if len(eos_positions) > 0:
1191
+ first_eos_rel = eos_positions[0].item()
1192
+ total_accept_token += first_eos_rel + 1
1193
+ output_end = prompt_len + total_accept_token
1194
+ return x[:, :output_end], nfe
1195
+
1196
+ x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1197
+ past_key_values.crop(block_start + accept_cnt)
1198
+
1199
+ # ---- thinking budget enforcement ----
1200
+ # Insert end_think as the first token of the next draft block,
1201
+ # shifting all subsequent tokens right by 1 (discarding the last).
1202
+ # The first draft token is always accepted unconditionally, so
1203
+ # end_think is guaranteed to be finalized in the next iteration
1204
+ # without needing to re-encode or touch the KV cache.
1205
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1206
+ tokens_so_far = total_accept_token + accept_cnt
1207
+ if tokens_so_far > max_thinking_tokens:
1208
+ gen_so_far = x[0, prompt_len : prompt_len + tokens_so_far]
1209
+ has_end_think = (gen_so_far == end_think_token_id).any()
1210
+ if not has_end_think:
1211
+ insert_pos = block_start + accept_cnt
1212
+ x[0, insert_pos + 1:] = x[0, insert_pos:-1].clone()
1213
+ x[0, insert_pos] = end_think_token_id
1214
+
1215
+ total_accept_token += accept_cnt
1216
+
1217
+ if total_accept_token >= max_new_tokens:
1218
+ break
1219
+
1220
+ return x[:, : -(block_length * 2)], nfe
1221
+
1222
+
1223
+ @torch.no_grad()
1224
+ def linear_spec_generate(
1225
+ self,
1226
+ prompt_ids: torch.Tensor,
1227
+ max_new_tokens: int = 128,
1228
+ block_length: int = 32,
1229
+ temperature: float = 0.0,
1230
+ mask_token_id: Optional[int] = None,
1231
+ eos_token_id: Optional[int] = None,
1232
+ max_thinking_tokens: Optional[int] = None,
1233
+ end_think_token_id: Optional[int] = None,
1234
+ threshold: float = 0.0,
1235
+ ):
1236
+ """Linear speculative decoding: diffusion draft + AR verification.
1237
+
1238
+ Each step:
1239
+ 1. Draft: forward [last_accepted, mask, ...] with bidirectional attention
1240
+ (diffusion_lm=True, use_cache=False). Shift AR logits to get
1241
+ per-position predictions; apply confidence filtering.
1242
+ 2. Verify: forward the drafted block with causal attention
1243
+ (diffusion_lm=False, use_cache=True, use_causal_mask=True).
1244
+ Accept consecutive AR-matching tokens plus one bonus token.
1245
+
1246
+ Args:
1247
+ prompt_ids: Input token IDs of shape (1, prompt_len).
1248
+ max_new_tokens: Maximum number of tokens to generate.
1249
+ block_length: Number of tokens per draft/verify block.
1250
+ temperature: Sampling temperature (0 = greedy).
1251
+ mask_token_id: Override for config.mask_token_id.
1252
+ eos_token_id: Override for config.eos_token_id.
1253
+ max_thinking_tokens: Budget for thinking tokens before forcing end_think.
1254
+ end_think_token_id: Token ID inserted when thinking budget is exceeded.
1255
+ threshold: Confidence threshold for accepting draft predictions.
1256
+
1257
+ Returns:
1258
+ (output_ids, nfe): output_ids includes the prompt; nfe is the number
1259
+ of forward evaluations (matching self_spec_generate interface).
1260
+ """
1261
+ if prompt_ids.shape[0] != 1:
1262
+ raise ValueError("Linear speculative decoding requires batch_size == 1")
1263
+
1264
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1265
+ if eos_token_id is None:
1266
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1267
+
1268
+ device = prompt_ids.device
1269
+ prompt_len = prompt_ids.shape[1]
1270
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1271
+
1272
+ def _set_diffusion_lm(val: bool):
1273
+ for layer in self.encoder.layers:
1274
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1275
+ layer.self_attn.diffusion_lm = val
1276
+
1277
+ # ===== Prefill (causal) =====
1278
+ _set_diffusion_lm(False)
1279
+
1280
+ enc_out = self.encoder(
1281
+ input_ids=prompt_ids,
1282
+ past_key_values=DynamicCache(),
1283
+ use_cache=True,
1284
+ use_causal_mask=True,
1285
+ )
1286
+ past_key_values = enc_out.past_key_values
1287
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1288
+ nfe = 1
1289
+
1290
+ if temperature > 0:
1291
+ probs = torch.softmax(last_logit / temperature, dim=-1)
1292
+ next_token = torch.multinomial(probs, num_samples=1)
1293
+ else:
1294
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1295
+
1296
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1297
+ output_ids = torch.cat([prompt_ids, next_token], dim=1)
1298
+ return output_ids, nfe
1299
+
1300
+ generated = [next_token]
1301
+ total_gen = 1
1302
+
1303
+ # ===== Main loop =====
1304
+ while total_gen < max_new_tokens:
1305
+ cache_len = past_key_values.get_seq_length()
1306
+
1307
+ block = torch.full(
1308
+ (1, block_length), token_mask_id, dtype=torch.long, device=device
1309
+ )
1310
+ block[0, 0] = next_token.item()
1311
+
1312
+ # -------- Draft (bidirectional, don't update cache) --------
1313
+ _set_diffusion_lm(True)
1314
+ while True:
1315
+ is_mask = block == token_mask_id
1316
+ if not is_mask.any():
1317
+ break
1318
+
1319
+ enc_out = self.encoder(
1320
+ input_ids=block,
1321
+ past_key_values=past_key_values,
1322
+ use_cache=False,
1323
+ )
1324
+ nfe += 1
1325
+
1326
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1327
+ if dream_style:
1328
+ # DREAM: logit[i] predicts position i+1 → shift to self-prediction
1329
+ draft_logits = torch.cat(
1330
+ [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1
1331
+ )
1332
+ # LLaDA: logit[i] already predicts position i → no shift needed
1333
+
1334
+ if temperature > 0:
1335
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1336
+ draft_tokens = torch.multinomial(
1337
+ draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
1338
+ ).view(1, block_length)
1339
+ else:
1340
+ draft_tokens = draft_logits.argmax(dim=-1)
1341
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1342
+
1343
+ if threshold > 0:
1344
+ draft_conf = torch.gather(
1345
+ draft_probs, -1, draft_tokens.unsqueeze(-1)
1346
+ ).squeeze(-1)
1347
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1348
+ unmask = draft_conf >= threshold
1349
+
1350
+ # Ensure each iteration makes progress even when every masked
1351
+ # position falls below the confidence threshold.
1352
+ if not unmask.any():
1353
+ best_idx = draft_conf.view(-1).argmax()
1354
+ unmask = torch.zeros_like(is_mask, dtype=torch.bool)
1355
+ unmask.view(-1)[best_idx] = True
1356
+
1357
+ block[unmask] = draft_tokens[unmask]
1358
+ else:
1359
+ block[is_mask] = draft_tokens[is_mask]
1360
+ break
1361
+
1362
+ # -------- Verify (causal, update cache) --------
1363
+ _set_diffusion_lm(False)
1364
+ enc_out = self.encoder(
1365
+ input_ids=block,
1366
+ past_key_values=past_key_values,
1367
+ use_cache=True,
1368
+ use_causal_mask=True,
1369
+ )
1370
+ past_key_values = enc_out.past_key_values
1371
+ nfe += 1
1372
+
1373
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1374
+ if temperature > 0:
1375
+ verify_probs = torch.softmax(verify_logits / temperature, dim=-1)
1376
+ ar_tokens = torch.multinomial(
1377
+ verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
1378
+ ).view(1, block_length)
1379
+ else:
1380
+ ar_tokens = verify_logits.argmax(dim=-1)
1381
+
1382
+ accepted = 0
1383
+ for i in range(block_length - 1):
1384
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1385
+ accepted += 1
1386
+ else:
1387
+ break
1388
+ accepted += 1 # bonus token from AR verification
1389
+
1390
+ accepted_toks = ar_tokens[:, :accepted]
1391
+ generated.append(accepted_toks)
1392
+ total_gen += accepted
1393
+
1394
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1395
+
1396
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1397
+
1398
+ # -------- EOS check --------
1399
+ if eos_token_id is not None:
1400
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1401
+ if len(eos_pos) > 0:
1402
+ first_eos = eos_pos[0].item()
1403
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1404
+ total_gen = total_gen - accepted + first_eos + 1
1405
+ break
1406
+
1407
+ # -------- Thinking budget enforcement --------
1408
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1409
+ if total_gen > max_thinking_tokens:
1410
+ all_gen = torch.cat(generated, dim=1)
1411
+ if not (all_gen == end_think_token_id).any():
1412
+ next_token = torch.tensor(
1413
+ [[end_think_token_id]], device=device
1414
+ )
1415
+
1416
+ if total_gen >= max_new_tokens:
1417
+ break
1418
+
1419
+ all_generated = torch.cat(generated, dim=1)
1420
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1421
+
1422
+ return output_ids, nfe
1423
+
1424
+
1425
+ @torch.no_grad()
1426
+ def linear_spec_generate_mp(
1427
+ self,
1428
+ prompt_ids: torch.Tensor,
1429
+ max_new_tokens: int = 512,
1430
+ block_length: int = 32,
1431
+ temperature: float = 0.0,
1432
+ mask_token_id: Optional[int] = None,
1433
+ eos_token_id: Optional[int] = None,
1434
+ max_paths: int = 16,
1435
+ uncertain_threshold: float = 0.7,
1436
+ top_k_candidates: int = 2,
1437
+ threshold: float = 0.0,
1438
+ max_thinking_tokens: Optional[int] = None,
1439
+ end_think_token_id: Optional[int] = None,
1440
+ ):
1441
+ """Linear speculative decoding with multi-path tree verification.
1442
+
1443
+ Self-contained method — no external file dependencies beyond the model itself.
1444
+
1445
+ Each iteration costs 2 NFE (1 draft + 1 verify):
1446
+ 1. Draft: single-step bidirectional diffusion fills a block of masks.
1447
+ 2. Verify: tree-structured AR verification with multiple candidate paths.
1448
+
1449
+ Multi-path verification identifies low-confidence draft positions and
1450
+ explores top-k alternative tokens. All candidate paths share a trie
1451
+ prefix and are verified in one forward pass via a 4D tree-ancestry
1452
+ attention mask (~40 tokens), picking the path with the longest
1453
+ accepted prefix.
1454
+
1455
+ Benchmark results (NeMo Skills prompt, enable_thinking=False):
1456
+ GSM8K bl=32: +17.1% UW-TPF vs vanilla (acc 93.9%)
1457
+ MBPP bl=64: +17.8% UW-TPF vs vanilla (pass@1 78.2%)
1458
+
1459
+ Args:
1460
+ prompt_ids: (1, prompt_len) input token IDs.
1461
+ max_new_tokens: Maximum tokens to generate.
1462
+ block_length: Draft block size. Use 32 for math, 64 for code.
1463
+ temperature: Sampling temperature (0.0 = greedy).
1464
+ eos_token_id: Stop token ID.
1465
+ max_paths: Tree verification budget. 16 = up to 4 uncertain
1466
+ positions x 2 candidates each.
1467
+ uncertain_threshold: Confidence below which a position is
1468
+ considered uncertain and expanded with alternatives.
1469
+ top_k_candidates: Number of alternative tokens to try at each
1470
+ uncertain position.
1471
+
1472
+ Returns:
1473
+ output_ids: (1, prompt_len + generated_len) full sequence.
1474
+ nfe: Total number of forward evaluations.
1475
+ """
1476
+ from itertools import product as _product
1477
+
1478
+ if prompt_ids.shape[0] != 1:
1479
+ raise ValueError("Requires batch_size == 1")
1480
+
1481
+ device = prompt_ids.device
1482
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1483
+ if eos_token_id is None:
1484
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1485
+
1486
+ def _set_dlm(val: bool):
1487
+ for layer in self.encoder.layers:
1488
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1489
+ layer.self_attn.diffusion_lm = val
1490
+
1491
+ def _crop_cache(kv, length):
1492
+ for li in range(len(kv)):
1493
+ kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1494
+ kv.value_cache[li] = kv.value_cache[li][:, :, :length]
1495
+ kv._seen_tokens = length
1496
+
1497
+ # ----- tree verify helpers (inlined) -----
1498
+
1499
+ def _mp_verify(block, draft_probs, draft_conf, past_kv, cache_len):
1500
+ """Multi-path verify via batch-stacking (flash-attention compatible).
1501
+
1502
+ Unlike tree attention (4D mask), batch-stacking expands the KV cache
1503
+ batch dimension and runs all candidate paths as separate batch entries.
1504
+ This keeps flash attention + GQA enabled, avoiding OOM from the 4D
1505
+ mask path which disables both.
1506
+
1507
+ Returns (accepted_toks, n_accepted, past_kv, next_tok) or None.
1508
+ """
1509
+ bl = block.shape[1]
1510
+
1511
+ # Identify uncertain positions
1512
+ is_filled = block[0] != token_mask_id
1513
+ pos_conf = torch.zeros(bl, device=device)
1514
+ pos_conf[0] = float('inf')
1515
+ for p in range(1, bl):
1516
+ if is_filled[p]:
1517
+ c = draft_conf[0, p].item()
1518
+ pos_conf[p] = c if c != float('-inf') else float('inf')
1519
+ else:
1520
+ pos_conf[p] = float('-inf')
1521
+
1522
+ unc_mask = (pos_conf < uncertain_threshold) & (pos_conf > float('-inf'))
1523
+ unc_pos = unc_mask.nonzero(as_tuple=True)[0].tolist()
1524
+ if not unc_pos:
1525
+ return None
1526
+
1527
+ import math as _math
1528
+ max_unc = min(len(unc_pos), max(1, int(_math.log2(max_paths))))
1529
+ unc_pos = sorted(unc_pos)[:max_unc]
1530
+
1531
+ # Build candidate blocks
1532
+ topk_at = {}
1533
+ for p in unc_pos:
1534
+ _, ids = draft_probs[0, p].topk(top_k_candidates)
1535
+ topk_at[p] = ids.tolist()
1536
+
1537
+ combos = list(_product(*(topk_at[p] for p in sorted(topk_at))))[:max_paths]
1538
+ num_paths = len(combos)
1539
+ if num_paths <= 1:
1540
+ return None
1541
+
1542
+ candidate_blocks = block.expand(num_paths, -1).clone()
1543
+ pos_list = sorted(topk_at.keys())
1544
+ for pi, combo in enumerate(combos):
1545
+ for ci, p in enumerate(pos_list):
1546
+ candidate_blocks[pi, p] = combo[ci]
1547
+
1548
+ # Expand KV cache batch dimension (shared, no copy)
1549
+ for li in range(len(past_kv.key_cache)):
1550
+ past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1551
+ past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
1552
+
1553
+ # Batched causal verify — uses flash attention + GQA
1554
+ _set_dlm(False)
1555
+ enc_out = self.encoder(
1556
+ input_ids=candidate_blocks,
1557
+ past_key_values=past_kv,
1558
+ use_cache=True,
1559
+ use_causal_mask=True,
1560
+ )
1561
+ past_kv = enc_out.past_key_values
1562
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1563
+
1564
+ if temperature > 0:
1565
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1566
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(num_paths, bl)
1567
+ else:
1568
+ ar_tokens = vlogits.argmax(dim=-1)
1569
+
1570
+ # Find best path (longest accepted prefix)
1571
+ best_acc, best_pidx = 0, 0
1572
+ for pi in range(num_paths):
1573
+ acc = 0
1574
+ for i in range(bl - 1):
1575
+ if ar_tokens[pi, i].item() == candidate_blocks[pi, i + 1].item():
1576
+ acc += 1
1577
+ else:
1578
+ break
1579
+ acc += 1
1580
+ if acc > best_acc:
1581
+ best_acc, best_pidx = acc, pi
1582
+
1583
+ accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1584
+
1585
+ # Extract winning path's KV cache slice
1586
+ for li in range(len(past_kv.key_cache)):
1587
+ past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1588
+ past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
1589
+ _crop_cache(past_kv, cache_len + best_acc)
1590
+
1591
+ return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
1592
+
1593
+ # ── Prefill (causal) ──
1594
+ _set_dlm(False)
1595
+ enc_out = self.encoder(
1596
+ input_ids=prompt_ids, past_key_values=DynamicCache(),
1597
+ use_cache=True, use_causal_mask=True,
1598
+ )
1599
+ past_key_values = enc_out.past_key_values
1600
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1601
+ nfe = 1
1602
+
1603
+ if temperature > 0:
1604
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), 1)
1605
+ else:
1606
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1607
+
1608
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1609
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1610
+
1611
+ generated = [next_token]
1612
+ total_gen = 1
1613
+
1614
+ # ── Main draft-verify loop ──
1615
+ while total_gen < max_new_tokens:
1616
+ cache_len = past_key_values.get_seq_length()
1617
+
1618
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1619
+ block[0, 0] = next_token.item()
1620
+
1621
+ # Draft: single-step bidirectional diffusion (1 NFE)
1622
+ _set_dlm(True)
1623
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1624
+ nfe += 1
1625
+
1626
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1627
+ if temperature > 0:
1628
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1629
+ draft_tokens = torch.multinomial(
1630
+ draft_probs.view(-1, draft_probs.shape[-1]), 1
1631
+ ).view(1, block_length)
1632
+ else:
1633
+ draft_tokens = draft_logits.argmax(dim=-1)
1634
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1635
+
1636
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1637
+ is_mask = block == token_mask_id
1638
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1639
+ block[is_mask] = draft_tokens[is_mask]
1640
+
1641
+ # Verify: multi-path batch-stacking (1 NFE, flash-attention compatible)
1642
+ result = _mp_verify(block, draft_probs, draft_conf, past_key_values, cache_len)
1643
+
1644
+ if result is not None:
1645
+ accepted_toks, accepted, past_key_values, next_token = result
1646
+ nfe += 1
1647
+ else:
1648
+ # No uncertain positions — single-path causal verify
1649
+ _set_dlm(False)
1650
+ enc_out = self.encoder(
1651
+ input_ids=block, past_key_values=past_key_values,
1652
+ use_cache=True, use_causal_mask=True,
1653
+ )
1654
+ past_key_values = enc_out.past_key_values
1655
+ nfe += 1
1656
+
1657
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1658
+ if temperature > 0:
1659
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1660
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(1, block_length)
1661
+ else:
1662
+ ar_tokens = vlogits.argmax(dim=-1)
1663
+
1664
+ accepted = 0
1665
+ for i in range(block_length - 1):
1666
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1667
+ accepted += 1
1668
+ else:
1669
+ break
1670
+ accepted += 1
1671
+
1672
+ accepted_toks = ar_tokens[:, :accepted]
1673
+ _crop_cache(past_key_values, cache_len + accepted)
1674
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1675
+
1676
+ generated.append(accepted_toks)
1677
+ total_gen += accepted
1678
+
1679
+ if eos_token_id is not None:
1680
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1681
+ if len(eos_pos) > 0:
1682
+ first_eos = eos_pos[0].item()
1683
+ generated[-1] = accepted_toks[:, :first_eos + 1]
1684
+ total_gen = total_gen - accepted + first_eos + 1
1685
+ break
1686
+
1687
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1688
+ if total_gen > max_thinking_tokens:
1689
+ all_gen = torch.cat(generated, dim=1)
1690
+ if not (all_gen == end_think_token_id).any():
1691
+ next_token = torch.tensor(
1692
+ [[end_think_token_id]], device=device
1693
+ )
1694
+
1695
+ if total_gen >= max_new_tokens:
1696
+ break
1697
+
1698
+ all_generated = torch.cat(generated, dim=1)
1699
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1700
+ return output_ids, nfe
1701
+
1702
+
1703
+ @torch.no_grad()
1704
+ def linear_spec_generate_lora(
1705
+ self,
1706
+ prompt_ids: torch.Tensor,
1707
+ max_new_tokens: int = 128,
1708
+ block_length: int = 32,
1709
+ temperature: float = 0.0,
1710
+ mask_token_id: Optional[int] = None,
1711
+ eos_token_id: Optional[int] = None,
1712
+ threshold: float = 0.0,
1713
+ rebuild_kv: str = 'none',
1714
+ max_thinking_tokens: Optional[int] = None,
1715
+ end_think_token_id: Optional[int] = None,
1716
+ ):
1717
+ """Linear speculative decoding: diffusion draft + AR verify.
1718
+ LoRA adapter toggling: ON for draft (bidirectional), OFF for verify (causal).
1719
+ Returns (output_ids, nfe).
1720
+ """
1721
+ if prompt_ids.shape[0] != 1:
1722
+ raise ValueError("linear_spec_generate requires batch_size == 1")
1723
+
1724
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1725
+ if eos_token_id is None:
1726
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1727
+
1728
+ device = prompt_ids.device
1729
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1730
+
1731
+ def _set_diffusion_lm(val: bool):
1732
+ for layer in self.encoder.layers:
1733
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1734
+ layer.self_attn.diffusion_lm = val
1735
+
1736
+ def _toggle_adapters(model, enable: bool):
1737
+ for module in model.modules():
1738
+ if hasattr(module, '_disable_adapters'):
1739
+ module._disable_adapters = not enable
1740
+
1741
+ # Prefill (causal, LoRA OFF)
1742
+ _set_diffusion_lm(False)
1743
+ _toggle_adapters(self, False)
1744
+ enc_out = self.encoder(
1745
+ input_ids=prompt_ids,
1746
+ past_key_values=DynamicCache(),
1747
+ use_cache=True,
1748
+ use_causal_mask=True,
1749
+ )
1750
+ past_key_values = enc_out.past_key_values
1751
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1752
+ nfe = 1
1753
+
1754
+ if temperature > 0:
1755
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
1756
+ else:
1757
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1758
+
1759
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1760
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1761
+
1762
+ generated = [next_token]
1763
+ total_gen = 1
1764
+
1765
+ while total_gen < max_new_tokens:
1766
+ cache_len = past_key_values.get_seq_length()
1767
+
1768
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1769
+ block[0, 0] = next_token.item()
1770
+
1771
+ # Draft (bidirectional, LoRA ON)
1772
+ _set_diffusion_lm(True)
1773
+ _toggle_adapters(self, True)
1774
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1775
+ nfe += 1
1776
+
1777
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1778
+ if dream_style:
1779
+ draft_logits = torch.cat([draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1)
1780
+
1781
+ if temperature > 0:
1782
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1783
+ draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1).view(1, block_length)
1784
+ else:
1785
+ draft_tokens = draft_logits.argmax(dim=-1)
1786
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1787
+
1788
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1789
+ is_mask = block == token_mask_id
1790
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1791
+ unmask = draft_conf > threshold
1792
+ if unmask.sum() > 0:
1793
+ block[unmask] = draft_tokens[unmask]
1794
+
1795
+ # Verify (causal, LoRA OFF)
1796
+ _set_diffusion_lm(False)
1797
+ _toggle_adapters(self, False)
1798
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=True, use_causal_mask=True)
1799
+ past_key_values = enc_out.past_key_values
1800
+ nfe += 1
1801
+
1802
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1803
+ if temperature > 0:
1804
+ ar_tokens = torch.multinomial(torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]), num_samples=1).view(1, block_length)
1805
+ else:
1806
+ ar_tokens = verify_logits.argmax(dim=-1)
1807
+
1808
+ accepted = 0
1809
+ for i in range(block_length - 1):
1810
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1811
+ accepted += 1
1812
+ else:
1813
+ break
1814
+ accepted += 1 # bonus token
1815
+
1816
+ accepted_toks = ar_tokens[:, :accepted]
1817
+ generated.append(accepted_toks)
1818
+ total_gen += accepted
1819
+
1820
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1821
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1822
+
1823
+ # EOS check
1824
+ if eos_token_id is not None:
1825
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1826
+ if len(eos_pos) > 0:
1827
+ first_eos = eos_pos[0].item()
1828
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1829
+ total_gen = total_gen - accepted + first_eos + 1
1830
+ break
1831
+
1832
+ # Thinking budget enforcement
1833
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1834
+ if total_gen > max_thinking_tokens:
1835
+ all_gen = torch.cat(generated, dim=1)
1836
+ if not (all_gen == end_think_token_id).any():
1837
+ next_token = torch.tensor([[end_think_token_id]], device=device)
1838
+
1839
+ if total_gen >= max_new_tokens:
1840
+ break
1841
+
1842
+ all_generated = torch.cat(generated, dim=1)
1843
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1844
+ return output_ids, nfe