YongganFu commited on
Commit
1c8770d
·
verified ·
1 Parent(s): 80ab10a

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +49 -3
  2. modeling_ministral_dlm.py +659 -2
chat_utils.py CHANGED
@@ -114,10 +114,12 @@ def generate_with_prefix_cache_block_diff(
114
  neg_entropy=False,
115
  causal_context=False,
116
  eos_token_id=None,
 
 
117
  ):
118
  dream_style=shift_logits
119
- # Initialize the accumulator
120
  x_accum = prompt.clone()
 
121
 
122
  assert gen_length % block_length == 0
123
  num_blocks = gen_length // block_length
@@ -142,30 +144,66 @@ def generate_with_prefix_cache_block_diff(
142
  if hasattr(layer.self_attn, 'diffusion_lm'):
143
  layer.self_attn.diffusion_lm=True
144
 
 
 
 
 
 
 
 
 
 
 
145
  # For dream_style: store the "next token logit" of the context
146
  next_logits_context = None
147
  if dream_style:
148
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
149
 
150
  for num_block in range(num_blocks):
151
- # Create a new block with mask tokens (no seeding)
 
 
152
  mask_block = torch.ones(
153
  (prompt.shape[0], block_length),
154
  dtype=prompt.dtype,
155
  device=prompt.device
156
  ) * mask_id
 
 
157
 
158
  # Append the block of masks
159
  x_accum = torch.cat([x_accum, mask_block], dim=1)
160
  current_block_start = prompt.size(1) + num_block * block_length
161
  block_slice = slice(current_block_start, current_block_start + block_length)
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  # Build the initial mask for this block
164
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
165
 
166
  # Precompute the transfer schedule for this block
167
  if dream_style:
168
- # still denoise *all* positions (0..Lb-1), since none are seeded
169
  schedule_mask = mask_block_idx0
170
  else:
171
  schedule_mask = mask_block_idx0
@@ -245,11 +283,19 @@ def generate_with_prefix_cache_block_diff(
245
  use_causal_mask=causal_context
246
  )
247
  past_key_values = output.past_key_values
 
248
 
249
  if causal_context:
250
  for layer in model_module.encoder.layers:
251
  if hasattr(layer.self_attn, 'diffusion_lm'):
252
  layer.self_attn.diffusion_lm=True
 
 
 
 
 
 
 
253
 
254
  if dream_style and num_block < num_blocks - 1:
255
  # refresh context-next logit for the next block
 
114
  neg_entropy=False,
115
  causal_context=False,
116
  eos_token_id=None,
117
+ max_thinking_tokens=None,
118
+ end_think_token_id=None,
119
  ):
120
  dream_style=shift_logits
 
121
  x_accum = prompt.clone()
122
+ B = prompt.shape[0]
123
 
124
  assert gen_length % block_length == 0
125
  num_blocks = gen_length // block_length
 
144
  if hasattr(layer.self_attn, 'diffusion_lm'):
145
  layer.self_attn.diffusion_lm=True
146
 
147
+ # Causal prefill: next token from last position (same as linear_spec_generate).
148
+ next_token = None
149
+ if causal_context:
150
+ last_logit = output.logits[:, -1, :]
151
+ if temperature > 0:
152
+ probs = torch.softmax(last_logit / temperature, dim=-1)
153
+ next_token = torch.multinomial(probs, num_samples=1)
154
+ else:
155
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
156
+
157
  # For dream_style: store the "next token logit" of the context
158
  next_logits_context = None
159
  if dream_style:
160
  next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
161
 
162
  for num_block in range(num_blocks):
163
+ # Create a new block with mask tokens; under causal context, seed position 0
164
+ # with the next-token prediction from the previous causal forward (prefill or
165
+ # post-block encode), matching linear_spec_generate.
166
  mask_block = torch.ones(
167
  (prompt.shape[0], block_length),
168
  dtype=prompt.dtype,
169
  device=prompt.device
170
  ) * mask_id
171
+ if causal_context:
172
+ mask_block[:, 0] = next_token[:, 0]
173
 
174
  # Append the block of masks
175
  x_accum = torch.cat([x_accum, mask_block], dim=1)
176
  current_block_start = prompt.size(1) + num_block * block_length
177
  block_slice = slice(current_block_start, current_block_start + block_length)
178
 
179
+ # ---- thinking budget enforcement ----
180
+ # If we've generated >= max_thinking_tokens without a </think>, inject one.
181
+ if end_think_token_id is not None and max_thinking_tokens is not None:
182
+ tokens_before_block = num_block * block_length
183
+ tokens_after_block = tokens_before_block + block_length
184
+ if tokens_after_block > max_thinking_tokens:
185
+ gen_so_far = x_accum[:, prompt.size(1):current_block_start]
186
+ has_end_think = (
187
+ (gen_so_far == end_think_token_id).any(dim=1)
188
+ if gen_so_far.size(1) > 0
189
+ else torch.zeros(B, dtype=torch.bool, device=prompt.device)
190
+ )
191
+ if not has_end_think.all():
192
+ if tokens_before_block < max_thinking_tokens:
193
+ offset = max_thinking_tokens - tokens_before_block
194
+ else:
195
+ offset = 0
196
+ inject_pos = current_block_start + offset
197
+ for b in range(B):
198
+ if not has_end_think[b]:
199
+ x_accum[b, inject_pos] = end_think_token_id
200
+
201
  # Build the initial mask for this block
202
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
203
 
204
  # Precompute the transfer schedule for this block
205
  if dream_style:
206
+ # masked positions only (position 0 may be causal-seeded, not mask_id)
207
  schedule_mask = mask_block_idx0
208
  else:
209
  schedule_mask = mask_block_idx0
 
283
  use_causal_mask=causal_context
284
  )
285
  past_key_values = output.past_key_values
286
+ nfe += 1
287
 
288
  if causal_context:
289
  for layer in model_module.encoder.layers:
290
  if hasattr(layer.self_attn, 'diffusion_lm'):
291
  layer.self_attn.diffusion_lm=True
292
+ # Next block's first position = greedy/sampled next token from this causal encode
293
+ last_logit = output.logits[:, -1, :]
294
+ if temperature > 0:
295
+ probs = torch.softmax(last_logit / temperature, dim=-1)
296
+ next_token = torch.multinomial(probs, num_samples=1)
297
+ else:
298
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
299
 
300
  if dream_style and num_block < num_blocks - 1:
301
  # refresh context-next logit for the next block
modeling_ministral_dlm.py CHANGED
@@ -31,6 +31,7 @@ from .chat_utils import generate_with_prefix_cache_block_diff
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
 
34
 
35
  @dataclass
36
  class MinistralDiffOutputWithPast(ModelOutput):
@@ -871,7 +872,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
871
  )
872
 
873
 
874
- def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None):
875
  if eos_token_id is None:
876
  eos_token_id = getattr(self.config, 'eos_token_id', None)
877
 
@@ -889,6 +890,8 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
889
  neg_entropy=False,
890
  causal_context=causal_context,
891
  eos_token_id=eos_token_id,
 
 
892
  )
893
 
894
  return out_ids, nfe
@@ -997,6 +1000,8 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
997
  max_new_tokens: int = 128,
998
  temperature: float = 0.0,
999
  eos_token_id: Optional[int] = None,
 
 
1000
  ) -> tuple:
1001
  """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
1002
 
@@ -1044,6 +1049,18 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1044
  else:
1045
  next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
1046
 
 
 
 
 
 
 
 
 
 
 
 
 
1047
  generated_tokens.append(next_token)
1048
 
1049
  if eos_token_id is not None and (next_token == eos_token_id).all():
@@ -1080,6 +1097,8 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1080
  temperature: float = 0.0,
1081
  mask_token_id: Optional[int] = None,
1082
  eos_token_id: Optional[int] = None,
 
 
1083
  ):
1084
  self.config.use_sbd_objective = True
1085
  self.config.dlm_paradigm = "sbd"
@@ -1176,6 +1195,23 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1176
 
1177
  x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1178
  past_key_values.crop(block_start + accept_cnt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1179
  total_accept_token += accept_cnt
1180
 
1181
  if total_accept_token >= max_new_tokens:
@@ -1184,4 +1220,625 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1184
  return x[:, : -(block_length * 2)], nfe
1185
 
1186
 
1187
- __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
35
 
36
  @dataclass
37
  class MinistralDiffOutputWithPast(ModelOutput):
 
872
  )
873
 
874
 
875
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None, max_thinking_tokens=None, end_think_token_id=None):
876
  if eos_token_id is None:
877
  eos_token_id = getattr(self.config, 'eos_token_id', None)
878
 
 
890
  neg_entropy=False,
891
  causal_context=causal_context,
892
  eos_token_id=eos_token_id,
893
+ max_thinking_tokens=max_thinking_tokens,
894
+ end_think_token_id=end_think_token_id,
895
  )
896
 
897
  return out_ids, nfe
 
1000
  max_new_tokens: int = 128,
1001
  temperature: float = 0.0,
1002
  eos_token_id: Optional[int] = None,
1003
+ max_thinking_tokens: Optional[int] = None,
1004
+ end_think_token_id: Optional[int] = None,
1005
  ) -> tuple:
1006
  """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
1007
 
 
1049
  else:
1050
  next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
1051
 
1052
+ # ---- thinking budget enforcement ----
1053
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1054
+ if step >= max_thinking_tokens:
1055
+ if generated_tokens:
1056
+ gen_tensor = torch.cat(generated_tokens, dim=1)
1057
+ has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
1058
+ else:
1059
+ has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
1060
+ for b in range(batch_size):
1061
+ if not has_end_think[b]:
1062
+ next_token[b] = end_think_token_id
1063
+
1064
  generated_tokens.append(next_token)
1065
 
1066
  if eos_token_id is not None and (next_token == eos_token_id).all():
 
1097
  temperature: float = 0.0,
1098
  mask_token_id: Optional[int] = None,
1099
  eos_token_id: Optional[int] = None,
1100
+ max_thinking_tokens: Optional[int] = None,
1101
+ end_think_token_id: Optional[int] = None,
1102
  ):
1103
  self.config.use_sbd_objective = True
1104
  self.config.dlm_paradigm = "sbd"
 
1195
 
1196
  x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1197
  past_key_values.crop(block_start + accept_cnt)
1198
+
1199
+ # ---- thinking budget enforcement ----
1200
+ # Insert end_think as the first token of the next draft block,
1201
+ # shifting all subsequent tokens right by 1 (discarding the last).
1202
+ # The first draft token is always accepted unconditionally, so
1203
+ # end_think is guaranteed to be finalized in the next iteration
1204
+ # without needing to re-encode or touch the KV cache.
1205
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1206
+ tokens_so_far = total_accept_token + accept_cnt
1207
+ if tokens_so_far > max_thinking_tokens:
1208
+ gen_so_far = x[0, prompt_len : prompt_len + tokens_so_far]
1209
+ has_end_think = (gen_so_far == end_think_token_id).any()
1210
+ if not has_end_think:
1211
+ insert_pos = block_start + accept_cnt
1212
+ x[0, insert_pos + 1:] = x[0, insert_pos:-1].clone()
1213
+ x[0, insert_pos] = end_think_token_id
1214
+
1215
  total_accept_token += accept_cnt
1216
 
1217
  if total_accept_token >= max_new_tokens:
 
1220
  return x[:, : -(block_length * 2)], nfe
1221
 
1222
 
1223
+ @torch.no_grad()
1224
+ def linear_spec_generate(
1225
+ self,
1226
+ prompt_ids: torch.Tensor,
1227
+ max_new_tokens: int = 128,
1228
+ block_length: int = 32,
1229
+ temperature: float = 0.0,
1230
+ mask_token_id: Optional[int] = None,
1231
+ eos_token_id: Optional[int] = None,
1232
+ max_thinking_tokens: Optional[int] = None,
1233
+ end_think_token_id: Optional[int] = None,
1234
+ threshold: float = 0.0,
1235
+ ):
1236
+ """Linear speculative decoding: diffusion draft + AR verification.
1237
+
1238
+ Each step:
1239
+ 1. Draft: forward [last_accepted, mask, ...] with bidirectional attention
1240
+ (diffusion_lm=True, use_cache=False). Shift AR logits to get
1241
+ per-position predictions; apply confidence filtering.
1242
+ 2. Verify: forward the drafted block with causal attention
1243
+ (diffusion_lm=False, use_cache=True, use_causal_mask=True).
1244
+ Accept consecutive AR-matching tokens plus one bonus token.
1245
+
1246
+ Args:
1247
+ prompt_ids: Input token IDs of shape (1, prompt_len).
1248
+ max_new_tokens: Maximum number of tokens to generate.
1249
+ block_length: Number of tokens per draft/verify block.
1250
+ temperature: Sampling temperature (0 = greedy).
1251
+ mask_token_id: Override for config.mask_token_id.
1252
+ eos_token_id: Override for config.eos_token_id.
1253
+ max_thinking_tokens: Budget for thinking tokens before forcing end_think.
1254
+ end_think_token_id: Token ID inserted when thinking budget is exceeded.
1255
+ threshold: Confidence threshold for accepting draft predictions.
1256
+
1257
+ Returns:
1258
+ (output_ids, nfe): output_ids includes the prompt; nfe is the number
1259
+ of forward evaluations (matching self_spec_generate interface).
1260
+ """
1261
+ if prompt_ids.shape[0] != 1:
1262
+ raise ValueError("Linear speculative decoding requires batch_size == 1")
1263
+
1264
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1265
+ if eos_token_id is None:
1266
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1267
+
1268
+ device = prompt_ids.device
1269
+ prompt_len = prompt_ids.shape[1]
1270
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1271
+
1272
+ def _set_diffusion_lm(val: bool):
1273
+ for layer in self.encoder.layers:
1274
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1275
+ layer.self_attn.diffusion_lm = val
1276
+
1277
+ # ===== Prefill (causal) =====
1278
+ _set_diffusion_lm(False)
1279
+
1280
+ enc_out = self.encoder(
1281
+ input_ids=prompt_ids,
1282
+ past_key_values=DynamicCache(),
1283
+ use_cache=True,
1284
+ use_causal_mask=True,
1285
+ )
1286
+ past_key_values = enc_out.past_key_values
1287
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1288
+ nfe = 1
1289
+
1290
+ if temperature > 0:
1291
+ probs = torch.softmax(last_logit / temperature, dim=-1)
1292
+ next_token = torch.multinomial(probs, num_samples=1)
1293
+ else:
1294
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1295
+
1296
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1297
+ output_ids = torch.cat([prompt_ids, next_token], dim=1)
1298
+ return output_ids, nfe
1299
+
1300
+ generated = [next_token]
1301
+ total_gen = 1
1302
+
1303
+ # ===== Main loop =====
1304
+ while total_gen < max_new_tokens:
1305
+ cache_len = past_key_values.get_seq_length()
1306
+
1307
+ block = torch.full(
1308
+ (1, block_length), token_mask_id, dtype=torch.long, device=device
1309
+ )
1310
+ block[0, 0] = next_token.item()
1311
+
1312
+ # -------- Draft (bidirectional, don't update cache) --------
1313
+ _set_diffusion_lm(True)
1314
+ while True:
1315
+ is_mask = block == token_mask_id
1316
+ if not is_mask.any():
1317
+ break
1318
+
1319
+ enc_out = self.encoder(
1320
+ input_ids=block,
1321
+ past_key_values=past_key_values,
1322
+ use_cache=False,
1323
+ )
1324
+ nfe += 1
1325
+
1326
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1327
+ if dream_style:
1328
+ # DREAM: logit[i] predicts position i+1 → shift to self-prediction
1329
+ draft_logits = torch.cat(
1330
+ [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1
1331
+ )
1332
+ # LLaDA: logit[i] already predicts position i → no shift needed
1333
+
1334
+ if temperature > 0:
1335
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1336
+ draft_tokens = torch.multinomial(
1337
+ draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
1338
+ ).view(1, block_length)
1339
+ else:
1340
+ draft_tokens = draft_logits.argmax(dim=-1)
1341
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1342
+
1343
+ if threshold > 0:
1344
+ draft_conf = torch.gather(
1345
+ draft_probs, -1, draft_tokens.unsqueeze(-1)
1346
+ ).squeeze(-1)
1347
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1348
+ unmask = draft_conf >= threshold
1349
+
1350
+ # Ensure each iteration makes progress even when every masked
1351
+ # position falls below the confidence threshold.
1352
+ if not unmask.any():
1353
+ best_idx = draft_conf.view(-1).argmax()
1354
+ unmask = torch.zeros_like(is_mask, dtype=torch.bool)
1355
+ unmask.view(-1)[best_idx] = True
1356
+
1357
+ block[unmask] = draft_tokens[unmask]
1358
+ else:
1359
+ block[is_mask] = draft_tokens[is_mask]
1360
+ break
1361
+
1362
+ # -------- Verify (causal, update cache) --------
1363
+ _set_diffusion_lm(False)
1364
+ enc_out = self.encoder(
1365
+ input_ids=block,
1366
+ past_key_values=past_key_values,
1367
+ use_cache=True,
1368
+ use_causal_mask=True,
1369
+ )
1370
+ past_key_values = enc_out.past_key_values
1371
+ nfe += 1
1372
+
1373
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1374
+ if temperature > 0:
1375
+ verify_probs = torch.softmax(verify_logits / temperature, dim=-1)
1376
+ ar_tokens = torch.multinomial(
1377
+ verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
1378
+ ).view(1, block_length)
1379
+ else:
1380
+ ar_tokens = verify_logits.argmax(dim=-1)
1381
+
1382
+ accepted = 0
1383
+ for i in range(block_length - 1):
1384
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1385
+ accepted += 1
1386
+ else:
1387
+ break
1388
+ accepted += 1 # bonus token from AR verification
1389
+
1390
+ accepted_toks = ar_tokens[:, :accepted]
1391
+ generated.append(accepted_toks)
1392
+ total_gen += accepted
1393
+
1394
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1395
+
1396
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1397
+
1398
+ # -------- EOS check --------
1399
+ if eos_token_id is not None:
1400
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1401
+ if len(eos_pos) > 0:
1402
+ first_eos = eos_pos[0].item()
1403
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1404
+ total_gen = total_gen - accepted + first_eos + 1
1405
+ break
1406
+
1407
+ # -------- Thinking budget enforcement --------
1408
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1409
+ if total_gen > max_thinking_tokens:
1410
+ all_gen = torch.cat(generated, dim=1)
1411
+ if not (all_gen == end_think_token_id).any():
1412
+ next_token = torch.tensor(
1413
+ [[end_think_token_id]], device=device
1414
+ )
1415
+
1416
+ if total_gen >= max_new_tokens:
1417
+ break
1418
+
1419
+ all_generated = torch.cat(generated, dim=1)
1420
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1421
+
1422
+ return output_ids, nfe
1423
+
1424
+
1425
+ @torch.no_grad()
1426
+ def linear_spec_generate_mp(
1427
+ self,
1428
+ prompt_ids: torch.Tensor,
1429
+ max_new_tokens: int = 512,
1430
+ block_length: int = 32,
1431
+ temperature: float = 0.0,
1432
+ mask_token_id: Optional[int] = None,
1433
+ eos_token_id: Optional[int] = None,
1434
+ max_paths: int = 16,
1435
+ uncertain_threshold: float = 0.7,
1436
+ top_k_candidates: int = 2,
1437
+ threshold: float = 0.0,
1438
+ max_thinking_tokens: Optional[int] = None,
1439
+ end_think_token_id: Optional[int] = None,
1440
+ ):
1441
+ """Linear speculative decoding with multi-path tree verification.
1442
+
1443
+ Self-contained method — no external file dependencies beyond the model itself.
1444
+
1445
+ Each iteration costs 2 NFE (1 draft + 1 verify):
1446
+ 1. Draft: single-step bidirectional diffusion fills a block of masks.
1447
+ 2. Verify: tree-structured AR verification with multiple candidate paths.
1448
+
1449
+ Multi-path verification identifies low-confidence draft positions and
1450
+ explores top-k alternative tokens. All candidate paths share a trie
1451
+ prefix and are verified in one forward pass via a 4D tree-ancestry
1452
+ attention mask (~40 tokens), picking the path with the longest
1453
+ accepted prefix.
1454
+
1455
+ Benchmark results (NeMo Skills prompt, enable_thinking=False):
1456
+ GSM8K bl=32: +17.1% UW-TPF vs vanilla (acc 93.9%)
1457
+ MBPP bl=64: +17.8% UW-TPF vs vanilla (pass@1 78.2%)
1458
+
1459
+ Args:
1460
+ prompt_ids: (1, prompt_len) input token IDs.
1461
+ max_new_tokens: Maximum tokens to generate.
1462
+ block_length: Draft block size. Use 32 for math, 64 for code.
1463
+ temperature: Sampling temperature (0.0 = greedy).
1464
+ eos_token_id: Stop token ID.
1465
+ max_paths: Tree verification budget. 16 = up to 4 uncertain
1466
+ positions x 2 candidates each.
1467
+ uncertain_threshold: Confidence below which a position is
1468
+ considered uncertain and expanded with alternatives.
1469
+ top_k_candidates: Number of alternative tokens to try at each
1470
+ uncertain position.
1471
+
1472
+ Returns:
1473
+ output_ids: (1, prompt_len + generated_len) full sequence.
1474
+ nfe: Total number of forward evaluations.
1475
+ """
1476
+ from itertools import product as _product
1477
+
1478
+ if prompt_ids.shape[0] != 1:
1479
+ raise ValueError("Requires batch_size == 1")
1480
+
1481
+ device = prompt_ids.device
1482
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1483
+ if eos_token_id is None:
1484
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1485
+
1486
+ def _set_dlm(val: bool):
1487
+ for layer in self.encoder.layers:
1488
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1489
+ layer.self_attn.diffusion_lm = val
1490
+
1491
+ def _crop_cache(kv, length):
1492
+ for li in range(len(kv)):
1493
+ kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1494
+ kv.value_cache[li] = kv.value_cache[li][:, :, :length]
1495
+ kv._seen_tokens = length
1496
+
1497
+ # ----- tree verify helpers (inlined) -----
1498
+
1499
+ def _mp_verify(block, draft_probs, draft_conf, past_kv, cache_len):
1500
+ """Multi-path verify via batch-stacking (flash-attention compatible).
1501
+
1502
+ Unlike tree attention (4D mask), batch-stacking expands the KV cache
1503
+ batch dimension and runs all candidate paths as separate batch entries.
1504
+ This keeps flash attention + GQA enabled, avoiding OOM from the 4D
1505
+ mask path which disables both.
1506
+
1507
+ Returns (accepted_toks, n_accepted, past_kv, next_tok) or None.
1508
+ """
1509
+ bl = block.shape[1]
1510
+
1511
+ # Identify uncertain positions
1512
+ is_filled = block[0] != token_mask_id
1513
+ pos_conf = torch.zeros(bl, device=device)
1514
+ pos_conf[0] = float('inf')
1515
+ for p in range(1, bl):
1516
+ if is_filled[p]:
1517
+ c = draft_conf[0, p].item()
1518
+ pos_conf[p] = c if c != float('-inf') else float('inf')
1519
+ else:
1520
+ pos_conf[p] = float('-inf')
1521
+
1522
+ unc_mask = (pos_conf < uncertain_threshold) & (pos_conf > float('-inf'))
1523
+ unc_pos = unc_mask.nonzero(as_tuple=True)[0].tolist()
1524
+ if not unc_pos:
1525
+ return None
1526
+
1527
+ import math as _math
1528
+ max_unc = min(len(unc_pos), max(1, int(_math.log2(max_paths))))
1529
+ unc_pos = sorted(unc_pos)[:max_unc]
1530
+
1531
+ # Build candidate blocks
1532
+ topk_at = {}
1533
+ for p in unc_pos:
1534
+ _, ids = draft_probs[0, p].topk(top_k_candidates)
1535
+ topk_at[p] = ids.tolist()
1536
+
1537
+ combos = list(_product(*(topk_at[p] for p in sorted(topk_at))))[:max_paths]
1538
+ num_paths = len(combos)
1539
+ if num_paths <= 1:
1540
+ return None
1541
+
1542
+ candidate_blocks = block.expand(num_paths, -1).clone()
1543
+ pos_list = sorted(topk_at.keys())
1544
+ for pi, combo in enumerate(combos):
1545
+ for ci, p in enumerate(pos_list):
1546
+ candidate_blocks[pi, p] = combo[ci]
1547
+
1548
+ # Expand KV cache batch dimension (shared, no copy)
1549
+ for li in range(len(past_kv.key_cache)):
1550
+ past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1551
+ past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
1552
+
1553
+ # Batched causal verify — uses flash attention + GQA
1554
+ _set_dlm(False)
1555
+ enc_out = self.encoder(
1556
+ input_ids=candidate_blocks,
1557
+ past_key_values=past_kv,
1558
+ use_cache=True,
1559
+ use_causal_mask=True,
1560
+ )
1561
+ past_kv = enc_out.past_key_values
1562
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1563
+
1564
+ if temperature > 0:
1565
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1566
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(num_paths, bl)
1567
+ else:
1568
+ ar_tokens = vlogits.argmax(dim=-1)
1569
+
1570
+ # Find best path (longest accepted prefix)
1571
+ best_acc, best_pidx = 0, 0
1572
+ for pi in range(num_paths):
1573
+ acc = 0
1574
+ for i in range(bl - 1):
1575
+ if ar_tokens[pi, i].item() == candidate_blocks[pi, i + 1].item():
1576
+ acc += 1
1577
+ else:
1578
+ break
1579
+ acc += 1
1580
+ if acc > best_acc:
1581
+ best_acc, best_pidx = acc, pi
1582
+
1583
+ accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1584
+
1585
+ # Extract winning path's KV cache slice
1586
+ for li in range(len(past_kv.key_cache)):
1587
+ past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1588
+ past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
1589
+ _crop_cache(past_kv, cache_len + best_acc)
1590
+
1591
+ return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
1592
+
1593
+ # ── Prefill (causal) ──
1594
+ _set_dlm(False)
1595
+ enc_out = self.encoder(
1596
+ input_ids=prompt_ids, past_key_values=DynamicCache(),
1597
+ use_cache=True, use_causal_mask=True,
1598
+ )
1599
+ past_key_values = enc_out.past_key_values
1600
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1601
+ nfe = 1
1602
+
1603
+ if temperature > 0:
1604
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), 1)
1605
+ else:
1606
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1607
+
1608
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1609
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1610
+
1611
+ generated = [next_token]
1612
+ total_gen = 1
1613
+
1614
+ # ── Main draft-verify loop ──
1615
+ while total_gen < max_new_tokens:
1616
+ cache_len = past_key_values.get_seq_length()
1617
+
1618
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1619
+ block[0, 0] = next_token.item()
1620
+
1621
+ # Draft: single-step bidirectional diffusion (1 NFE)
1622
+ _set_dlm(True)
1623
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1624
+ nfe += 1
1625
+
1626
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1627
+ if temperature > 0:
1628
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1629
+ draft_tokens = torch.multinomial(
1630
+ draft_probs.view(-1, draft_probs.shape[-1]), 1
1631
+ ).view(1, block_length)
1632
+ else:
1633
+ draft_tokens = draft_logits.argmax(dim=-1)
1634
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1635
+
1636
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1637
+ is_mask = block == token_mask_id
1638
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1639
+ block[is_mask] = draft_tokens[is_mask]
1640
+
1641
+ # Verify: multi-path batch-stacking (1 NFE, flash-attention compatible)
1642
+ result = _mp_verify(block, draft_probs, draft_conf, past_key_values, cache_len)
1643
+
1644
+ if result is not None:
1645
+ accepted_toks, accepted, past_key_values, next_token = result
1646
+ nfe += 1
1647
+ else:
1648
+ # No uncertain positions — single-path causal verify
1649
+ _set_dlm(False)
1650
+ enc_out = self.encoder(
1651
+ input_ids=block, past_key_values=past_key_values,
1652
+ use_cache=True, use_causal_mask=True,
1653
+ )
1654
+ past_key_values = enc_out.past_key_values
1655
+ nfe += 1
1656
+
1657
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1658
+ if temperature > 0:
1659
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1660
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(1, block_length)
1661
+ else:
1662
+ ar_tokens = vlogits.argmax(dim=-1)
1663
+
1664
+ accepted = 0
1665
+ for i in range(block_length - 1):
1666
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1667
+ accepted += 1
1668
+ else:
1669
+ break
1670
+ accepted += 1
1671
+
1672
+ accepted_toks = ar_tokens[:, :accepted]
1673
+ _crop_cache(past_key_values, cache_len + accepted)
1674
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1675
+
1676
+ generated.append(accepted_toks)
1677
+ total_gen += accepted
1678
+
1679
+ if eos_token_id is not None:
1680
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1681
+ if len(eos_pos) > 0:
1682
+ first_eos = eos_pos[0].item()
1683
+ generated[-1] = accepted_toks[:, :first_eos + 1]
1684
+ total_gen = total_gen - accepted + first_eos + 1
1685
+ break
1686
+
1687
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1688
+ if total_gen > max_thinking_tokens:
1689
+ all_gen = torch.cat(generated, dim=1)
1690
+ if not (all_gen == end_think_token_id).any():
1691
+ next_token = torch.tensor(
1692
+ [[end_think_token_id]], device=device
1693
+ )
1694
+
1695
+ if total_gen >= max_new_tokens:
1696
+ break
1697
+
1698
+ all_generated = torch.cat(generated, dim=1)
1699
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1700
+ return output_ids, nfe
1701
+
1702
+
1703
+ @torch.no_grad()
1704
+ def linear_spec_generate_lora(
1705
+ self,
1706
+ prompt_ids: torch.Tensor,
1707
+ max_new_tokens: int = 128,
1708
+ block_length: int = 32,
1709
+ temperature: float = 0.0,
1710
+ mask_token_id: Optional[int] = None,
1711
+ eos_token_id: Optional[int] = None,
1712
+ threshold: float = 0.0,
1713
+ rebuild_kv: str = 'none',
1714
+ max_thinking_tokens: Optional[int] = None,
1715
+ end_think_token_id: Optional[int] = None,
1716
+ ):
1717
+ """Linear speculative decoding: diffusion draft + AR verify.
1718
+ LoRA adapter toggling: ON for draft (bidirectional), OFF for verify (causal).
1719
+ Returns (output_ids, nfe).
1720
+ """
1721
+ if prompt_ids.shape[0] != 1:
1722
+ raise ValueError("linear_spec_generate requires batch_size == 1")
1723
+
1724
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1725
+ if eos_token_id is None:
1726
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1727
+
1728
+ device = prompt_ids.device
1729
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1730
+
1731
+ def _set_diffusion_lm(val: bool):
1732
+ for layer in self.encoder.layers:
1733
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1734
+ layer.self_attn.diffusion_lm = val
1735
+
1736
+ def _toggle_adapters(model, enable: bool):
1737
+ for module in model.modules():
1738
+ if hasattr(module, '_disable_adapters'):
1739
+ module._disable_adapters = not enable
1740
+
1741
+ # Prefill (causal, LoRA OFF)
1742
+ _set_diffusion_lm(False)
1743
+ _toggle_adapters(self, False)
1744
+ enc_out = self.encoder(
1745
+ input_ids=prompt_ids,
1746
+ past_key_values=DynamicCache(),
1747
+ use_cache=True,
1748
+ use_causal_mask=True,
1749
+ )
1750
+ past_key_values = enc_out.past_key_values
1751
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1752
+ nfe = 1
1753
+
1754
+ if temperature > 0:
1755
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
1756
+ else:
1757
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1758
+
1759
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1760
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1761
+
1762
+ generated = [next_token]
1763
+ total_gen = 1
1764
+
1765
+ while total_gen < max_new_tokens:
1766
+ cache_len = past_key_values.get_seq_length()
1767
+
1768
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1769
+ block[0, 0] = next_token.item()
1770
+
1771
+ # Draft (bidirectional, LoRA ON)
1772
+ _set_diffusion_lm(True)
1773
+ _toggle_adapters(self, True)
1774
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1775
+ nfe += 1
1776
+
1777
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1778
+ if dream_style:
1779
+ draft_logits = torch.cat([draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1)
1780
+
1781
+ if temperature > 0:
1782
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1783
+ draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1).view(1, block_length)
1784
+ else:
1785
+ draft_tokens = draft_logits.argmax(dim=-1)
1786
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1787
+
1788
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1789
+ is_mask = block == token_mask_id
1790
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1791
+ unmask = draft_conf > threshold
1792
+ if unmask.sum() > 0:
1793
+ block[unmask] = draft_tokens[unmask]
1794
+
1795
+ # Verify (causal, LoRA OFF)
1796
+ _set_diffusion_lm(False)
1797
+ _toggle_adapters(self, False)
1798
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=True, use_causal_mask=True)
1799
+ past_key_values = enc_out.past_key_values
1800
+ nfe += 1
1801
+
1802
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1803
+ if temperature > 0:
1804
+ ar_tokens = torch.multinomial(torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]), num_samples=1).view(1, block_length)
1805
+ else:
1806
+ ar_tokens = verify_logits.argmax(dim=-1)
1807
+
1808
+ accepted = 0
1809
+ for i in range(block_length - 1):
1810
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1811
+ accepted += 1
1812
+ else:
1813
+ break
1814
+ accepted += 1 # bonus token
1815
+
1816
+ accepted_toks = ar_tokens[:, :accepted]
1817
+ generated.append(accepted_toks)
1818
+ total_gen += accepted
1819
+
1820
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1821
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1822
+
1823
+ # EOS check
1824
+ if eos_token_id is not None:
1825
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1826
+ if len(eos_pos) > 0:
1827
+ first_eos = eos_pos[0].item()
1828
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1829
+ total_gen = total_gen - accepted + first_eos + 1
1830
+ break
1831
+
1832
+ # Thinking budget enforcement
1833
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1834
+ if total_gen > max_thinking_tokens:
1835
+ all_gen = torch.cat(generated, dim=1)
1836
+ if not (all_gen == end_think_token_id).any():
1837
+ next_token = torch.tensor([[end_think_token_id]], device=device)
1838
+
1839
+ if total_gen >= max_new_tokens:
1840
+ break
1841
+
1842
+ all_generated = torch.cat(generated, dim=1)
1843
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1844
+ return output_ids, nfe