YongganFu commited on
Commit
3ef6080
·
verified ·
1 Parent(s): 86af60c

Upload model

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
chat_utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def add_gumbel_noise(logits, temperature):
7
+ '''
8
+ The Gumbel max is a method for sampling categorical distributions.
9
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
10
+ Thus, we use float64.
11
+ '''
12
+ if temperature == 0:
13
+ return logits
14
+ logits = logits.to(torch.float64)
15
+ noise = torch.rand_like(logits, dtype=torch.float64)
16
+ gumbel_noise = (- torch.log(noise)) ** temperature
17
+ return logits.exp() / gumbel_noise
18
+
19
+
20
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
21
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
22
+ x0 = torch.argmax(logits_with_noise, dim=-1)
23
+
24
+ if remasking == 'low_confidence':
25
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
26
+ p = F.softmax(logits, dim=-1)
27
+ x0_p = torch.squeeze(
28
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
29
+ elif remasking == 'top_p_margin':
30
+ # Compute probabilities
31
+ p = F.softmax(logits, dim=-1) # (B, L, V)
32
+ # Top-2 per position
33
+ top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
34
+ margin = top2[..., 0] - top2[..., 1] # (B, L)
35
+
36
+ # Normalize margin to [0,1] over MASKED positions per row
37
+ plus_inf = torch.full_like(margin, float('inf'))
38
+ minus_inf = torch.full_like(margin, float('-inf'))
39
+ masked_for_min = torch.where(mask_index, margin, plus_inf)
40
+ masked_for_max = torch.where(mask_index, margin, minus_inf)
41
+ row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
42
+ row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
43
+ denom = (row_max - row_min)
44
+
45
+ # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
46
+ normalized = torch.zeros_like(margin)
47
+ nonzero = denom > 0
48
+ normalized = torch.where(
49
+ mask_index & nonzero,
50
+ (margin - row_min) / (denom + 1e-12),
51
+ normalized
52
+ )
53
+ normalized = torch.where(
54
+ mask_index & (~nonzero),
55
+ torch.ones_like(normalized),
56
+ normalized
57
+ )
58
+ x0_p = normalized # ∈ [0,1] on masked positions
59
+ elif remasking == 'random':
60
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
61
+ else:
62
+ raise NotImplementedError(remasking)
63
+
64
+ # Calculate negative entropy if requested
65
+ if neg_entropy:
66
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
67
+ p = F.softmax(logits, dim=-1)
68
+ epsilon = 1e-10
69
+ log_probs = torch.log(p + epsilon)
70
+ confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
71
+ else:
72
+ confidence_scores = x0_p
73
+
74
+ x0 = torch.where(mask_index, x0, x)
75
+ confidence = torch.where(mask_index, confidence_scores, -np.inf)
76
+
77
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
78
+ if threshold is not None:
79
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
80
+ # print(f'confidence: {confidence}')
81
+ for j in range(confidence.shape[0]):
82
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
83
+ transfer_index[j, select_index] = True
84
+ if threshold is not None:
85
+ for k in range(1, num_transfer_tokens[j]):
86
+ if confidence[j, select_index[k]] < threshold:
87
+ transfer_index[j, select_index[k]] = False
88
+ return x0, transfer_index
89
+
90
+
91
+ def get_num_transfer_tokens(mask_index, steps: int):
92
+ mask_num = mask_index.sum(dim=1, keepdim=True)
93
+ base = mask_num // steps
94
+ remainder = mask_num % steps
95
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
96
+ for i in range(mask_num.size(0)):
97
+ num_transfer_tokens[i, : int(remainder[i])] += 1
98
+ return num_transfer_tokens
99
+
100
+
101
+ @torch.no_grad()
102
+ def generate_with_prefix_cache_block_diff(
103
+ model,
104
+ prompt,
105
+ steps=128,
106
+ gen_length=128,
107
+ block_length=128,
108
+ temperature=0.,
109
+ remasking='low_confidence',
110
+ mask_id=126336,
111
+ threshold=None,
112
+ factor=None,
113
+ shift_logits=False,
114
+ neg_entropy=False,
115
+ causal_context=False,
116
+ eos_token_id=None,
117
+ max_thinking_tokens=None,
118
+ end_think_token_id=None,
119
+ ):
120
+ dream_style=shift_logits
121
+ x_accum = prompt.clone()
122
+ B = prompt.shape[0]
123
+
124
+ assert gen_length % block_length == 0
125
+ num_blocks = gen_length // block_length
126
+
127
+ assert steps % num_blocks == 0
128
+ steps_per_block = steps // num_blocks
129
+
130
+ nfe = 0
131
+
132
+ if causal_context:
133
+ model_module = model.module if hasattr(model, "module") else model
134
+ for layer in model_module.encoder.layers:
135
+ if hasattr(layer.self_attn, 'diffusion_lm'):
136
+ layer.self_attn.diffusion_lm=False
137
+
138
+ # Compute KV cache for the prompt initially
139
+ output = model(prompt, use_cache=True, use_causal_mask=causal_context)
140
+ past_key_values = output.past_key_values
141
+
142
+ if causal_context:
143
+ for layer in model_module.encoder.layers:
144
+ if hasattr(layer.self_attn, 'diffusion_lm'):
145
+ layer.self_attn.diffusion_lm=True
146
+
147
+ # Causal prefill: next token from last position (same as linear_spec_generate).
148
+ next_token = None
149
+ if causal_context:
150
+ last_logit = output.logits[:, -1, :]
151
+ if temperature > 0:
152
+ probs = torch.softmax(last_logit / temperature, dim=-1)
153
+ next_token = torch.multinomial(probs, num_samples=1)
154
+ else:
155
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
156
+
157
+ # For dream_style: store the "next token logit" of the context
158
+ next_logits_context = None
159
+ if dream_style:
160
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
161
+
162
+ for num_block in range(num_blocks):
163
+ # Create a new block with mask tokens; under causal context, seed position 0
164
+ # with the next-token prediction from the previous causal forward (prefill or
165
+ # post-block encode), matching linear_spec_generate.
166
+ mask_block = torch.ones(
167
+ (prompt.shape[0], block_length),
168
+ dtype=prompt.dtype,
169
+ device=prompt.device
170
+ ) * mask_id
171
+ if causal_context:
172
+ mask_block[:, 0] = next_token[:, 0]
173
+
174
+ # Append the block of masks
175
+ x_accum = torch.cat([x_accum, mask_block], dim=1)
176
+ current_block_start = prompt.size(1) + num_block * block_length
177
+ block_slice = slice(current_block_start, current_block_start + block_length)
178
+
179
+ # ---- thinking budget enforcement ----
180
+ # If we've generated >= max_thinking_tokens without a </think>, inject one.
181
+ if end_think_token_id is not None and max_thinking_tokens is not None:
182
+ tokens_before_block = num_block * block_length
183
+ tokens_after_block = tokens_before_block + block_length
184
+ if tokens_after_block > max_thinking_tokens:
185
+ gen_so_far = x_accum[:, prompt.size(1):current_block_start]
186
+ has_end_think = (
187
+ (gen_so_far == end_think_token_id).any(dim=1)
188
+ if gen_so_far.size(1) > 0
189
+ else torch.zeros(B, dtype=torch.bool, device=prompt.device)
190
+ )
191
+ if not has_end_think.all():
192
+ if tokens_before_block < max_thinking_tokens:
193
+ offset = max_thinking_tokens - tokens_before_block
194
+ else:
195
+ offset = 0
196
+ inject_pos = current_block_start + offset
197
+ for b in range(B):
198
+ if not has_end_think[b]:
199
+ x_accum[b, inject_pos] = end_think_token_id
200
+
201
+ # Build the initial mask for this block
202
+ mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
203
+
204
+ # Precompute the transfer schedule for this block
205
+ if dream_style:
206
+ # masked positions only (position 0 may be causal-seeded, not mask_id)
207
+ schedule_mask = mask_block_idx0
208
+ else:
209
+ schedule_mask = mask_block_idx0
210
+
211
+ num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
212
+
213
+ # Denoise the current block
214
+ for i in range(steps_per_block):
215
+ mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
216
+ if mask_block_idx.sum() == 0:
217
+ break
218
+
219
+ nfe += 1
220
+
221
+ # Forward only the current noisy block using cached context
222
+ logits_block = model(
223
+ x_accum[:, block_slice],
224
+ past_key_values=past_key_values,
225
+ use_cache=False
226
+ ).logits
227
+
228
+ if dream_style:
229
+ # Align logits so that each masked position has a predictor:
230
+ # prepend context-next logit, then use logits_block[:-1]
231
+ if block_length == 1:
232
+ logits_use = next_logits_context # (B, 1, V)
233
+ else:
234
+ logits_use = torch.cat(
235
+ [next_logits_context, logits_block[:, :-1, :]],
236
+ dim=1
237
+ ) # (B, Lb, V)
238
+
239
+ mask_use = mask_block_idx # (B, Lb)
240
+ x_use = x_accum[:, block_slice] # (B, Lb)
241
+
242
+ x0, transfer_idx = get_transfer_index(
243
+ logits_use, temperature, remasking, mask_use, x_use,
244
+ num_transfer_tokens=num_transfer_tokens[:, i],
245
+ threshold=threshold, neg_entropy=neg_entropy
246
+ )
247
+ cur = x_accum[:, block_slice].clone()
248
+ cur[transfer_idx] = x0[transfer_idx]
249
+ x_accum[:, block_slice] = cur
250
+
251
+ else:
252
+ # non-AR (same-position) case
253
+ x0, transfer_idx = get_transfer_index(
254
+ logits_block, temperature, remasking, mask_block_idx,
255
+ x_accum[:, block_slice],
256
+ num_transfer_tokens=num_transfer_tokens[:, i],
257
+ threshold=threshold, neg_entropy=neg_entropy
258
+ )
259
+ cur = x_accum[:, block_slice].clone()
260
+ cur[transfer_idx] = x0[transfer_idx]
261
+ x_accum[:, block_slice] = cur
262
+
263
+ if eos_token_id is not None:
264
+ block_tokens = x_accum[:, block_slice] # (B, Lb)
265
+ eos_mask = (block_tokens == eos_token_id) # (B, Lb)
266
+ any_eos = eos_mask.any(dim=1) # (B,)
267
+ if any_eos.any():
268
+ after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
269
+ mask_before = (block_tokens == mask_id) & ~after_eos
270
+ if (any_eos & ~mask_before.any(dim=1)).any():
271
+ break
272
+
273
+ if causal_context:
274
+ for layer in model_module.encoder.layers:
275
+ if hasattr(layer.self_attn, 'diffusion_lm'):
276
+ layer.self_attn.diffusion_lm=False
277
+
278
+ # after block is fully denoised, update KV cache
279
+ output = model(
280
+ x_accum[:, block_slice],
281
+ past_key_values=past_key_values,
282
+ use_cache=True,
283
+ use_causal_mask=causal_context
284
+ )
285
+ past_key_values = output.past_key_values
286
+ nfe += 1
287
+
288
+ if causal_context:
289
+ for layer in model_module.encoder.layers:
290
+ if hasattr(layer.self_attn, 'diffusion_lm'):
291
+ layer.self_attn.diffusion_lm=True
292
+ # Next block's first position = greedy/sampled next token from this causal encode
293
+ last_logit = output.logits[:, -1, :]
294
+ if temperature > 0:
295
+ probs = torch.softmax(last_logit / temperature, dim=-1)
296
+ next_token = torch.multinomial(probs, num_samples=1)
297
+ else:
298
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
299
+
300
+ if dream_style and num_block < num_blocks - 1:
301
+ # refresh context-next logit for the next block
302
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
303
+
304
+ if eos_token_id is not None:
305
+ gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
306
+ is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
307
+ has_eos = is_eos.any(dim=1) # (B,)
308
+ if has_eos.all():
309
+ first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
310
+ max_eos = first_eos_pos.max().item()
311
+ return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
312
+
313
+ return x_accum, nfe
config.json ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ada_dlm_loss_ratio": null,
3
+ "ada_perm_ratio_global": null,
4
+ "ada_perm_ratio_per_block": null,
5
+ "adaptive_mask_rate": false,
6
+ "ar_loss_weight": 1.0,
7
+ "architectures": [
8
+ "MinistralDiffEncoderModel"
9
+ ],
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "attn_implementation": "sdpa",
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
15
+ "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
16
+ },
17
+ "block_size": 32,
18
+ "bos_token_id": 1,
19
+ "diff_loss_weight": 1,
20
+ "dlm_arch": "encoder",
21
+ "dlm_loss_weight": null,
22
+ "dlm_paradigm": "bidirectional",
23
+ "dlm_type": "llada",
24
+ "dp_varying_mask_ratio": false,
25
+ "enable_self_spec": false,
26
+ "enforce_mask": false,
27
+ "eos_token_id": 11,
28
+ "global_loss_avg": false,
29
+ "head_dim": 128,
30
+ "hidden_act": "silu",
31
+ "hidden_size": 5120,
32
+ "initializer_range": 0.02,
33
+ "intermediate_size": 16384,
34
+ "mask_token_id": 100,
35
+ "max_position_embeddings": 262144,
36
+ "mlp_bias": false,
37
+ "model_type": "ministral_dlm",
38
+ "multi_sampling": null,
39
+ "num_ar_layers": 0,
40
+ "num_attention_heads": 32,
41
+ "num_diffusion_layers": 0,
42
+ "num_hidden_layers": 40,
43
+ "num_key_value_heads": 8,
44
+ "num_skip_loss_tokens": 0,
45
+ "prefix_ratio": 0.8,
46
+ "random_length_prob": 0,
47
+ "rms_norm_eps": 1e-05,
48
+ "rope_parameters": {
49
+ "beta_fast": 32.0,
50
+ "beta_slow": 1.0,
51
+ "factor": 16.0,
52
+ "llama_4_scaling_beta": 0.1,
53
+ "mscale": 1.0,
54
+ "mscale_all_dim": 1.0,
55
+ "original_max_position_embeddings": 16384,
56
+ "rope_theta": 1000000000.0,
57
+ "rope_type": "yarn",
58
+ "type": "yarn"
59
+ },
60
+ "rope_scaling": {
61
+ "beta_fast": 32.0,
62
+ "beta_slow": 1.0,
63
+ "factor": 16.0,
64
+ "llama_4_scaling_beta": 0.1,
65
+ "mscale": 1.0,
66
+ "mscale_all_dim": 1.0,
67
+ "original_max_position_embeddings": 16384,
68
+ "rope_theta": 1000000.0,
69
+ "rope_type": "yarn",
70
+ "type": "yarn"
71
+ },
72
+ "rope_theta": 1000000000.0,
73
+ "sliding_window": null,
74
+ "tie_word_embeddings": false,
75
+ "tok_mask_half_life_ratio": null,
76
+ "torch_dtype": "bfloat16",
77
+ "transformers_version": "4.55.4",
78
+ "use_cache": false,
79
+ "vocab_size": 131072
80
+ }
configuration_ministral_dlm.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Ministral DLM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class MinistralDLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
+ It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 131072):
35
+ Vocabulary size of the Ministral model.
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 14336):
39
+ Dimension of the MLP representations.
40
+ num_hidden_layers (`int`, *optional*, defaults to 34):
41
+ Number of hidden layers in the Transformer decoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 32):
43
+ Number of attention heads for each attention layer.
44
+ num_key_value_heads (`int`, *optional*, defaults to 8):
45
+ Number of key_value heads for Grouped Query Attention.
46
+ head_dim (`int`, *optional*, defaults to 128):
47
+ The attention head dimension.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function.
50
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
51
+ The maximum sequence length.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions.
58
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
59
+ Whether the model's input and output word embeddings should be tied.
60
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
61
+ The base period of the RoPE embeddings.
62
+ rope_parameters (`Dict`, *optional*):
63
+ Dictionary containing the scaling configuration for the RoPE embeddings.
64
+ Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
65
+ attention_bias (`bool`, defaults to `False`):
66
+ Whether to use a bias in the query, key, value and output projection layers.
67
+ attention_dropout (`float`, *optional*, defaults to 0.0):
68
+ The dropout ratio for the attention probabilities.
69
+ mlp_bias (`bool`, *optional*, defaults to `False`):
70
+ Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
+ sliding_window (`int`, *optional*, defaults to None):
72
+ Sliding window attention size.
73
+ mask_token_id (`int`, *optional*, defaults to -1):
74
+ Token ID for masking in diffusion.
75
+ dlm_type (`str`, *optional*, defaults to 'llada'):
76
+ Type of diffusion language model ('llada', 'dream').
77
+ random_length_prob (`float`, *optional*):
78
+ Probability of using random lengths during training.
79
+ num_ar_layers (`int`, *optional*, defaults to 0):
80
+ Number of autoregressive layers.
81
+ num_diffusion_layers (`int`, *optional*, defaults to 0):
82
+ Number of diffusion layers.
83
+ diff_loss_weight (`float`, *optional*, defaults to 1):
84
+ Weight for diffusion loss.
85
+ enforce_mask (`bool`, *optional*, defaults to False):
86
+ Whether to enforce masking.
87
+ prefix_ratio (`float`, *optional*, defaults to 0.8):
88
+ Ratio for prefix in prefix_bidirectional mode.
89
+ dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
90
+ Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
91
+ dlm_arch (`str`, *optional*, defaults to 'encoder'):
92
+ Architecture type ('encoder', 'encoder_decoder').
93
+ block_size (`int`, *optional*, defaults to 32):
94
+ Block size for block diffusion paradigms.
95
+ tok_mask_half_life_ratio (`float`, *optional*):
96
+ Half-life ratio for token masking.
97
+ adaptive_mask_rate (`bool`, *optional*, defaults to False):
98
+ Whether to use adaptive mask rate.
99
+ multi_sampling (`int`, *optional*):
100
+ Number of samples for multi-sampling.
101
+ num_skip_loss_tokens (`int`, *optional*, defaults to 0):
102
+ Number of tokens to skip in loss calculation.
103
+ dlm_loss_weight (`float`, *optional*):
104
+ Weight for diffusion LM loss.
105
+ ar_loss_weight (`float`, *optional*, defaults to 1.0):
106
+ Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
107
+ global_loss_avg (`bool`, *optional*, defaults to False):
108
+ Whether to use global loss average.
109
+ dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
110
+ Whether to use varying mask ratio for each DP rank during sampling.
111
+ ada_perm_ratio_per_block (`float`, *optional*):
112
+ Adaptive permutation ratio for each block.
113
+ ada_perm_ratio_global (`float`, *optional*):
114
+ Adaptive permutation ratio for global.
115
+ enable_self_spec (`bool`, *optional*, defaults to `False`):
116
+ Force MinistralFlexAttention for all paradigms (including bidirectional/autoregressive).
117
+ Required for self speculative generation; leave False for standard eval to use faster SDPA kernels.
118
+ """
119
+
120
+ model_type = "ministral_dlm"
121
+ keys_to_ignore_at_inference = ["past_key_values"]
122
+
123
+ # Default tensor parallel plan for base model `Ministral`
124
+ base_model_tp_plan = {
125
+ "layers.*.self_attn.q_proj": "colwise",
126
+ "layers.*.self_attn.k_proj": "colwise",
127
+ "layers.*.self_attn.v_proj": "colwise",
128
+ "layers.*.self_attn.o_proj": "rowwise",
129
+ "layers.*.mlp.gate_proj": "colwise",
130
+ "layers.*.mlp.up_proj": "colwise",
131
+ "layers.*.mlp.down_proj": "rowwise",
132
+ }
133
+ base_model_pp_plan = {
134
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
135
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
136
+ "norm": (["hidden_states"], ["hidden_states"]),
137
+ }
138
+
139
+ def __init__(
140
+ self,
141
+ vocab_size=131072,
142
+ hidden_size=4096,
143
+ intermediate_size=14336,
144
+ num_hidden_layers=34,
145
+ num_attention_heads=32,
146
+ num_key_value_heads=8,
147
+ head_dim=128,
148
+ hidden_act="silu",
149
+ max_position_embeddings=262144,
150
+ initializer_range=0.02,
151
+ rms_norm_eps=1e-05,
152
+ use_cache=True,
153
+ pad_token_id=None,
154
+ bos_token_id=1,
155
+ eos_token_id=2,
156
+ tie_word_embeddings=False,
157
+ rope_theta=1000000.0,
158
+ rope_parameters=None,
159
+ rope_scaling=None,
160
+ attention_bias=False,
161
+ attention_dropout=0.0,
162
+ mlp_bias=False,
163
+ sliding_window=None,
164
+ attn_implementation="sdpa",
165
+ mask_token_id=-1,
166
+ dlm_type='llada',
167
+ random_length_prob=None,
168
+ num_ar_layers=0,
169
+ num_diffusion_layers=0,
170
+ diff_loss_weight=1,
171
+ enforce_mask=False,
172
+ prefix_ratio=0.8,
173
+ dlm_paradigm='bidirectional',
174
+ dlm_arch='encoder',
175
+ block_size=32,
176
+ tok_mask_half_life_ratio=None,
177
+ adaptive_mask_rate=False,
178
+ multi_sampling=None,
179
+ num_skip_loss_tokens=0,
180
+ dlm_loss_weight=None,
181
+ ar_loss_weight=1.0,
182
+ global_loss_avg=False,
183
+ dp_varying_mask_ratio=False,
184
+ ada_perm_ratio_per_block=None,
185
+ ada_perm_ratio_global=None,
186
+ ada_dlm_loss_ratio=None,
187
+ enable_self_spec=False,
188
+ **kwargs,
189
+ ):
190
+ self.vocab_size = vocab_size
191
+ self.max_position_embeddings = max_position_embeddings
192
+ self.hidden_size = hidden_size
193
+ self.intermediate_size = intermediate_size
194
+ self.num_hidden_layers = num_hidden_layers
195
+ self.num_attention_heads = num_attention_heads
196
+
197
+ # for backward compatibility
198
+ if num_key_value_heads is None:
199
+ num_key_value_heads = num_attention_heads
200
+
201
+ self.num_key_value_heads = num_key_value_heads
202
+ self.head_dim = head_dim
203
+ self.hidden_act = hidden_act
204
+ self.initializer_range = initializer_range
205
+ self.rms_norm_eps = rms_norm_eps
206
+ self.use_cache = use_cache
207
+ self.rope_theta = rope_theta
208
+ self.rope_parameters = rope_parameters
209
+ self.rope_scaling = rope_scaling
210
+ self.attention_bias = attention_bias
211
+ self.attention_dropout = attention_dropout
212
+ self.mlp_bias = mlp_bias
213
+ self.sliding_window = sliding_window
214
+
215
+ rope_config_validation(self)
216
+
217
+ self.attn_implementation = attn_implementation
218
+
219
+ self.mask_token_id = mask_token_id
220
+ self.dlm_type = dlm_type
221
+ self.random_length_prob = random_length_prob
222
+ self.num_ar_layers = num_ar_layers
223
+ self.num_diffusion_layers = num_diffusion_layers
224
+ self.diff_loss_weight = diff_loss_weight
225
+ self.enforce_mask = enforce_mask
226
+ self.prefix_ratio = prefix_ratio
227
+ self.dlm_paradigm = dlm_paradigm
228
+ self.dlm_arch = dlm_arch
229
+ self.block_size = block_size
230
+ self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
231
+ self.adaptive_mask_rate = adaptive_mask_rate
232
+ self.multi_sampling = multi_sampling
233
+ self.num_skip_loss_tokens = num_skip_loss_tokens
234
+ self.dlm_loss_weight = dlm_loss_weight
235
+ self.ar_loss_weight = ar_loss_weight
236
+ self.global_loss_avg = global_loss_avg
237
+ self.dp_varying_mask_ratio = dp_varying_mask_ratio
238
+ self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
239
+ self.ada_perm_ratio_global = ada_perm_ratio_global
240
+ self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
241
+ self.enable_self_spec = enable_self_spec
242
+ super().__init__(
243
+ pad_token_id=pad_token_id,
244
+ bos_token_id=bos_token_id,
245
+ eos_token_id=eos_token_id,
246
+ tie_word_embeddings=tie_word_embeddings,
247
+ **kwargs,
248
+ )
249
+
250
+
251
+ __all__ = ["MinistralDLMConfig"]
252
+
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 11,
5
+ "transformers_version": "4.55.4",
6
+ "use_cache": false
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:272e78c39809711139deb08024b4fe8e6af83ab1316c8514bdfa35d7c880a320
3
+ size 27012190712
modeling_ministral.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from transformers.utils.generic import check_model_inputs
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
+ from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_layers import (
17
+ GenericForQuestionAnswering,
18
+ GenericForSequenceClassification,
19
+ GenericForTokenClassification,
20
+ GradientCheckpointingLayer,
21
+ )
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
+ from transformers.processing_utils import Unpack
26
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
+ # from transformers.utils.generic import maybe_autocast
28
+ from .configuration_ministral_dlm import MinistralDLMConfig
29
+
30
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
+
32
+ def rotate_half(x):
33
+ """Rotates half the hidden dims of the input."""
34
+ x1 = x[..., : x.shape[-1] // 2]
35
+ x2 = x[..., x.shape[-1] // 2 :]
36
+ return torch.cat((-x2, x1), dim=-1)
37
+
38
+ # @use_kernel_func_from_hub("rotary_pos_emb")
39
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
40
+ """Applies Rotary Position Embedding to the query and key tensors.
41
+
42
+ Args:
43
+ q (`torch.Tensor`): The query tensor.
44
+ k (`torch.Tensor`): The key tensor.
45
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
46
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
47
+ position_ids (`torch.Tensor`, *optional*):
48
+ Deprecated and unused.
49
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
50
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
51
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
52
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
53
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
54
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
55
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
56
+ Returns:
57
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
58
+ """
59
+ cos = cos.unsqueeze(unsqueeze_dim)
60
+ sin = sin.unsqueeze(unsqueeze_dim)
61
+ q_embed = (q * cos) + (rotate_half(q) * sin)
62
+ k_embed = (k * cos) + (rotate_half(k) * sin)
63
+ return q_embed, k_embed
64
+
65
+
66
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
67
+ """
68
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
69
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
70
+ """
71
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
72
+ if n_rep == 1:
73
+ return hidden_states
74
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
75
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
76
+
77
+
78
+ def eager_attention_forward(
79
+ module: nn.Module,
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor],
84
+ scaling: float,
85
+ dropout: float = 0.0,
86
+ **kwargs: Unpack[TransformersKwargs],
87
+ ):
88
+ key_states = repeat_kv(key, module.num_key_value_groups)
89
+ value_states = repeat_kv(value, module.num_key_value_groups)
90
+
91
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
92
+ if attention_mask is not None:
93
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
94
+ attn_weights = attn_weights + causal_mask
95
+
96
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
97
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
98
+ attn_output = torch.matmul(attn_weights, value_states)
99
+ attn_output = attn_output.transpose(1, 2).contiguous()
100
+
101
+ return attn_output, attn_weights
102
+
103
+
104
+ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
105
+ scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
106
+ return scaling.unsqueeze(-1)
107
+
108
+
109
+ # @use_kernelized_func(apply_rotary_pos_emb)
110
+ class Ministral3Attention(nn.Module):
111
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
112
+
113
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
114
+ super().__init__()
115
+ self.config = config
116
+ self.layer_idx = layer_idx
117
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
118
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
119
+ self.scaling = self.head_dim**-0.5
120
+ self.attention_dropout = config.attention_dropout
121
+ self.is_causal = True
122
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
123
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
124
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
125
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
126
+
127
+ self.diffusion_lm = config.diffusion_lm
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.Tensor,
132
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
133
+ attention_mask: Optional[torch.Tensor],
134
+ past_key_values: Optional[Cache] = None,
135
+ cache_position: Optional[torch.LongTensor] = None,
136
+ use_cache: Optional[bool] = False,
137
+ **kwargs: Unpack[FlashAttentionKwargs],
138
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
139
+ input_shape = hidden_states.shape[:-1]
140
+ hidden_shape = (*input_shape, -1, self.head_dim)
141
+
142
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
143
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
145
+
146
+ cos, sin = position_embeddings
147
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
148
+ query_states = query_states * _get_llama_4_attn_scale(
149
+ cache_position,
150
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
151
+ self.config.rope_parameters.get("original_max_position_embeddings"),
152
+ ).to(query_states.dtype)
153
+
154
+ if past_key_values is not None:
155
+ if use_cache:
156
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
157
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
158
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
159
+ else: ## if use_cache == False, do not update cache
160
+ old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
161
+ key_states = torch.cat([old_k, key_states], dim=-2)
162
+ value_states = torch.cat([old_v, value_states], dim=-2)
163
+
164
+ attention_interface: Callable = eager_attention_forward
165
+ if self.config._attn_implementation != "eager":
166
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
167
+
168
+ if self.diffusion_lm:
169
+ attn_output, attn_weights = attention_interface(
170
+ self,
171
+ query_states,
172
+ key_states,
173
+ value_states,
174
+ None,
175
+ dropout=0.0 if not self.training else self.attention_dropout,
176
+ scaling=self.scaling,
177
+ is_causal=False,
178
+ **kwargs,
179
+ )
180
+
181
+ else:
182
+ attn_output, attn_weights = attention_interface(
183
+ self,
184
+ query_states,
185
+ key_states,
186
+ value_states,
187
+ attention_mask,
188
+ dropout=0.0 if not self.training else self.attention_dropout,
189
+ scaling=self.scaling,
190
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
191
+ **kwargs,
192
+ )
193
+
194
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
195
+ attn_output = self.o_proj(attn_output)
196
+ return attn_output, attn_weights
197
+
198
+
199
+ class Ministral3MLP(nn.Module):
200
+ def __init__(self, config):
201
+ super().__init__()
202
+ self.config = config
203
+ self.hidden_size = config.hidden_size
204
+ self.intermediate_size = config.intermediate_size
205
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
206
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
207
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
208
+ self.act_fn = ACT2FN[config.hidden_act]
209
+
210
+ def forward(self, x):
211
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
212
+ return down_proj
213
+
214
+
215
+ @use_kernel_forward_from_hub("RMSNorm")
216
+ class Ministral3RMSNorm(nn.Module):
217
+ def __init__(self, hidden_size, eps=1e-6):
218
+ """
219
+ Ministral3RMSNorm is equivalent to T5LayerNorm
220
+ """
221
+ super().__init__()
222
+ self.weight = nn.Parameter(torch.ones(hidden_size))
223
+ self.variance_epsilon = eps
224
+
225
+ def forward(self, hidden_states):
226
+ input_dtype = hidden_states.dtype
227
+ hidden_states = hidden_states.to(torch.float32)
228
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
229
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
230
+ return self.weight * hidden_states.to(input_dtype)
231
+
232
+ def extra_repr(self):
233
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
234
+
235
+
236
+ class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
238
+ super().__init__()
239
+ self.hidden_size = config.hidden_size
240
+
241
+ if hasattr(config, 'attn_class'):
242
+ attn_class = config.attn_class
243
+ else:
244
+ attn_class = Ministral3Attention
245
+
246
+ self.self_attn = attn_class(config=config, layer_idx=layer_idx)
247
+ self.mlp = Ministral3MLP(config)
248
+ self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+ self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ position_ids: Optional[torch.LongTensor] = None,
256
+ past_key_values: Optional[Cache] = None,
257
+ use_cache: Optional[bool] = False,
258
+ cache_position: Optional[torch.LongTensor] = None,
259
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
260
+ **kwargs: Unpack[TransformersKwargs],
261
+ ) -> torch.Tensor:
262
+ residual = hidden_states
263
+ hidden_states = self.input_layernorm(hidden_states)
264
+ # Self Attention
265
+ hidden_states, _ = self.self_attn(
266
+ hidden_states=hidden_states,
267
+ attention_mask=attention_mask,
268
+ position_ids=position_ids,
269
+ past_key_values=past_key_values,
270
+ use_cache=use_cache,
271
+ cache_position=cache_position,
272
+ position_embeddings=position_embeddings,
273
+ **kwargs,
274
+ )
275
+ hidden_states = residual + hidden_states
276
+
277
+ # Fully Connected
278
+ residual = hidden_states
279
+ hidden_states = self.post_attention_layernorm(hidden_states)
280
+ hidden_states = self.mlp(hidden_states)
281
+ hidden_states = residual + hidden_states
282
+ return hidden_states
283
+
284
+
285
+ @auto_docstring
286
+ class Ministral3PreTrainedModel(PreTrainedModel):
287
+ config: MinistralDLMConfig
288
+ base_model_prefix = "model"
289
+ supports_gradient_checkpointing = True
290
+ _no_split_modules = ["Ministral3DecoderLayer"]
291
+ _skip_keys_device_placement = ["past_key_values"]
292
+ _supports_flash_attn = True
293
+ _supports_sdpa = True
294
+ _supports_flex_attn = True
295
+
296
+ _can_compile_fullgraph = True
297
+ _supports_attention_backend = True
298
+ _can_record_outputs = {
299
+ "hidden_states": Ministral3DecoderLayer,
300
+ "attentions": Ministral3Attention,
301
+ }
302
+
303
+
304
+ class Ministral3RotaryEmbedding(nn.Module):
305
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
306
+
307
+ def __init__(self, config: MinistralDLMConfig, device=None):
308
+ super().__init__()
309
+ self.max_seq_len_cached = config.max_position_embeddings
310
+ self.original_max_seq_len = config.max_position_embeddings
311
+
312
+ self.config = config
313
+
314
+ self.rope_type = self.config.rope_parameters["rope_type"]
315
+ rope_init_fn: Callable = self.compute_default_rope_parameters
316
+ if self.rope_type != "default":
317
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
318
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
319
+
320
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
321
+ self.original_inv_freq = inv_freq
322
+
323
+
324
+ @staticmethod
325
+ def compute_default_rope_parameters(
326
+ config: Optional[MinistralDLMConfig] = None,
327
+ device: Optional["torch.device"] = None,
328
+ seq_len: Optional[int] = None,
329
+ ) -> tuple["torch.Tensor", float]:
330
+ """
331
+ Computes the inverse frequencies according to the original RoPE implementation
332
+ Args:
333
+ config ([`~transformers.PreTrainedConfig`]):
334
+ The model configuration.
335
+ device (`torch.device`):
336
+ The device to use for initialization of the inverse frequencies.
337
+ seq_len (`int`, *optional*):
338
+ The current sequence length. Unused for this type of RoPE.
339
+ Returns:
340
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
341
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
342
+ """
343
+ base = config.rope_parameters["rope_theta"]
344
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
345
+
346
+ attention_factor = 1.0 # Unused in this type of RoPE
347
+
348
+ # Compute the inverse frequencies
349
+ inv_freq = 1.0 / (
350
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
351
+ )
352
+ return inv_freq, attention_factor
353
+
354
+ @torch.no_grad()
355
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
356
+ def forward(self, x, position_ids):
357
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
358
+ position_ids_expanded = position_ids[:, None, :].float()
359
+
360
+ # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
361
+ # with maybe_autocast(device_type=device_type, enabled=False): # Force float32
362
+
363
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
364
+ emb = torch.cat((freqs, freqs), dim=-1)
365
+ cos = emb.cos() * self.attention_scaling
366
+ sin = emb.sin() * self.attention_scaling
367
+
368
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
369
+
370
+
371
+ @auto_docstring
372
+ class Ministral3Model(Ministral3PreTrainedModel):
373
+ def __init__(self, config: MinistralDLMConfig):
374
+ super().__init__(config)
375
+ self.padding_idx = config.pad_token_id
376
+ self.vocab_size = config.vocab_size
377
+
378
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
379
+ self.layers = nn.ModuleList(
380
+ [Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
381
+ )
382
+ self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
+ self.rotary_emb = Ministral3RotaryEmbedding(config=config)
384
+ self.gradient_checkpointing = False
385
+
386
+ # Initialize weights and apply final processing
387
+ self.post_init()
388
+
389
+ @check_model_inputs
390
+ @auto_docstring
391
+ def forward(
392
+ self,
393
+ input_ids: Optional[torch.LongTensor] = None,
394
+ attention_mask: Optional[torch.Tensor] = None,
395
+ position_ids: Optional[torch.LongTensor] = None,
396
+ past_key_values: Optional[Cache] = None,
397
+ inputs_embeds: Optional[torch.FloatTensor] = None,
398
+ use_cache: Optional[bool] = None,
399
+ cache_position: Optional[torch.LongTensor] = None,
400
+ **kwargs: Unpack[TransformersKwargs],
401
+ ) -> BaseModelOutputWithPast:
402
+ if (input_ids is None) ^ (inputs_embeds is not None):
403
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
404
+
405
+ if inputs_embeds is None:
406
+ inputs_embeds = self.embed_tokens(input_ids)
407
+
408
+ if use_cache and past_key_values is None:
409
+ # past_key_values = DynamicCache(config=self.config)
410
+ past_key_values = DynamicCache()
411
+
412
+ if cache_position is None:
413
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
414
+ cache_position = torch.arange(
415
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
416
+ )
417
+
418
+ if position_ids is None:
419
+ position_ids = cache_position.unsqueeze(0)
420
+
421
+ if kwargs.get("use_causal_mask", False):
422
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
+ causal_mask = mask_function(
424
+ config=self.config,
425
+ input_embeds=inputs_embeds,
426
+ attention_mask=attention_mask,
427
+ cache_position=cache_position,
428
+ past_key_values=past_key_values,
429
+ position_ids=position_ids,
430
+ )
431
+
432
+ else:
433
+ causal_mask = None
434
+
435
+ hidden_states = inputs_embeds
436
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
437
+
438
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
439
+ hidden_states = decoder_layer(
440
+ hidden_states,
441
+ attention_mask=causal_mask,
442
+ position_ids=position_ids,
443
+ past_key_values=past_key_values,
444
+ use_cache=use_cache,
445
+ cache_position=cache_position,
446
+ position_embeddings=position_embeddings,
447
+ **kwargs,
448
+ )
449
+ hidden_states = self.norm(hidden_states)
450
+ return BaseModelOutputWithPast(
451
+ last_hidden_state=hidden_states,
452
+ past_key_values=past_key_values if use_cache else None,
453
+ )
454
+
455
+
456
+ @auto_docstring
457
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
458
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
459
+ _tp_plan = {"lm_head": "colwise_rep"}
460
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
461
+
462
+ def __init__(self, config):
463
+ super().__init__(config)
464
+ self.model = Ministral3Model(config)
465
+ self.vocab_size = config.vocab_size
466
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
467
+
468
+ # Initialize weights and apply final processing
469
+ self.post_init()
470
+
471
+ @can_return_tuple
472
+ @auto_docstring
473
+ def forward(
474
+ self,
475
+ input_ids: Optional[torch.LongTensor] = None,
476
+ attention_mask: Optional[torch.Tensor] = None,
477
+ position_ids: Optional[torch.LongTensor] = None,
478
+ past_key_values: Optional[Cache] = None,
479
+ inputs_embeds: Optional[torch.FloatTensor] = None,
480
+ labels: Optional[torch.LongTensor] = None,
481
+ use_cache: Optional[bool] = None,
482
+ cache_position: Optional[torch.LongTensor] = None,
483
+ logits_to_keep: Union[int, torch.Tensor] = 0,
484
+ **kwargs: Unpack[TransformersKwargs],
485
+ ) -> CausalLMOutputWithPast:
486
+ r"""
487
+ Example:
488
+
489
+ ```python
490
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
491
+
492
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
493
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
494
+
495
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
496
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
497
+
498
+ >>> # Generate
499
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
500
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
501
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
502
+ ```"""
503
+ outputs: BaseModelOutputWithPast = self.model(
504
+ input_ids=input_ids,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_values=past_key_values,
508
+ inputs_embeds=inputs_embeds,
509
+ use_cache=use_cache,
510
+ cache_position=cache_position,
511
+ **kwargs,
512
+ )
513
+
514
+ hidden_states = outputs.last_hidden_state
515
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
516
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
517
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
518
+
519
+ loss = None
520
+ if labels is not None:
521
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
522
+
523
+ return CausalLMOutputWithPast(
524
+ loss=loss,
525
+ logits=logits,
526
+ past_key_values=outputs.past_key_values,
527
+ hidden_states=outputs.hidden_states,
528
+ attentions=outputs.attentions,
529
+ )
530
+
531
+
532
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
533
+ pass
534
+
535
+
536
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
537
+ pass
538
+
539
+
540
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
541
+ pass
542
+
543
+
544
+ __all__ = [
545
+ "Ministral3ForCausalLM",
546
+ "Ministral3ForQuestionAnswering",
547
+ "Ministral3Model",
548
+ "Ministral3PreTrainedModel",
549
+ "Ministral3ForSequenceClassification",
550
+ "Ministral3ForTokenClassification",
551
+ ]
modeling_ministral_dlm.py ADDED
@@ -0,0 +1,1844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional, Tuple, Union
4
+ import random
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
+
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.generation import GenerationMixin
27
+
28
+ import math
29
+
30
+ from .chat_utils import generate_with_prefix_cache_block_diff
31
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
+ from .configuration_ministral_dlm import MinistralDLMConfig
33
+
34
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
35
+
36
+ @dataclass
37
+ class MinistralDiffOutputWithPast(ModelOutput):
38
+ loss: torch.FloatTensor | None = None
39
+ logits: torch.FloatTensor | None = None
40
+ causal_logits: torch.FloatTensor | None = None
41
+ past_key_values: Cache | None = None
42
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
43
+ attentions: tuple[torch.FloatTensor, ...] | None = None
44
+
45
+
46
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
47
+ # @torch.compile(mode="default")
48
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
49
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
50
+ def fused_flex_attention(q, k, v, block_mask=None):
51
+ return flex_attention(q, k, v, block_mask=block_mask)
52
+
53
+
54
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
55
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
56
+ if hasattr(past_key_values, 'crop'):
57
+ past_key_values.crop(max_length)
58
+ else:
59
+ for layer_idx in range(len(past_key_values)):
60
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
61
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
62
+ past_key_values._seen_tokens = max_length
63
+
64
+
65
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
66
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
67
+ for layer_idx in range(len(past_key_values)):
68
+ if hasattr(past_key_values, 'layers'):
69
+ layer_cache = past_key_values.layers[layer_idx]
70
+ k, v = layer_cache.keys, layer_cache.values
71
+ else:
72
+ k = past_key_values.key_cache[layer_idx]
73
+ v = past_key_values.value_cache[layer_idx]
74
+
75
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
76
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
77
+ new_k = torch.cat([clean_k, draft_k], dim=2)
78
+ new_v = torch.cat([clean_v, draft_v], dim=2)
79
+
80
+ if hasattr(past_key_values, 'layers'):
81
+ layer_cache.keys = new_k
82
+ layer_cache.values = new_v
83
+ else:
84
+ past_key_values.key_cache[layer_idx] = new_k
85
+ past_key_values.value_cache[layer_idx] = new_v
86
+
87
+ past_key_values._seen_tokens = clean_len + block_length
88
+
89
+
90
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
91
+ class MinistralFlexAttention(Ministral3Attention):
92
+ def __init__(self, *args, **kwargs):
93
+ super().__init__(*args, **kwargs)
94
+
95
+ self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
96
+ self.block_size_orig = self.config.block_size
97
+
98
+ if self.config.dlm_paradigm == 'bidirectional':
99
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
100
+ elif self.config.dlm_paradigm == 'autoregressive':
101
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
102
+ elif self.config.dlm_paradigm == 'block_diff':
103
+ self.block_diff_mask = None
104
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
105
+ self.sbd_block_diff_mask = None
106
+ else:
107
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
108
+
109
+ self.block_size = self.block_size_orig
110
+ self.mode = self.config.dlm_paradigm
111
+ self._quadratic_block_mask = {}
112
+
113
+ import torch._dynamo.config as dcfg
114
+ dcfg.cache_size_limit = 512
115
+
116
+
117
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
118
+ if block_length not in self._quadratic_block_mask:
119
+ draft_len = block_length * (block_length + 1)
120
+
121
+ def quadratic(b, h, q_idx, kv_idx):
122
+ first_clean = torch.logical_and(
123
+ kv_idx % (block_length + 1) == 0,
124
+ kv_idx < draft_len,
125
+ )
126
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
127
+ block_q = q_idx // (block_length + 1)
128
+ block_kv = kv_idx // (block_length + 1)
129
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
130
+ same_block_except_first = torch.logical_and(
131
+ same_block,
132
+ q_idx % (block_length + 1) != 0,
133
+ )
134
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
135
+ clean_part = kv_idx >= draft_len
136
+ return torch.logical_or(draft_part, clean_part)
137
+
138
+ block_mask = create_block_mask(
139
+ quadratic,
140
+ B=None,
141
+ H=None,
142
+ Q_LEN=draft_len,
143
+ KV_LEN=draft_len + self.config.max_position_embeddings,
144
+ device="cuda",
145
+ )
146
+
147
+ self._quadratic_block_mask[block_length] = block_mask
148
+
149
+ return self._quadratic_block_mask[block_length]
150
+
151
+
152
+ def set_attention_mode(self, mode, block_size=None):
153
+ self.mode = mode
154
+ self.block_size = block_size
155
+
156
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
157
+
158
+ def bidirectional_mask(b, h, q, kv):
159
+ return (q >= kv) | (q < kv)
160
+
161
+ def autoregressive_mask(b, h, q, kv):
162
+ return (q >= kv)
163
+
164
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
165
+ x0_flag_q = (q_idx >= n)
166
+ x0_flag_kv = (kv_idx >= n)
167
+
168
+ # Compute block indices
169
+ block_q = torch.where(x0_flag_q == 1,
170
+ (q_idx - n) // block_size,
171
+ q_idx // block_size)
172
+ block_kv = torch.where(x0_flag_kv == 1,
173
+ (kv_idx - n) // block_size,
174
+ kv_idx // block_size)
175
+
176
+ # **1. Block Diagonal Mask (M_BD) **
177
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
178
+
179
+ # **2. Offset Block-Causal Mask (M_OBC) **
180
+ offset_block_causal = (
181
+ (block_q > block_kv)
182
+ & (x0_flag_kv == 1)
183
+ & (x0_flag_q == 0)
184
+ )
185
+
186
+ # **3. Block-Causal Mask (M_BC) **
187
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
188
+
189
+ # **4. Combine Masks **
190
+ return block_diagonal | offset_block_causal | block_causal
191
+
192
+
193
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
194
+ x0_flag_q = (q_idx >= n)
195
+ x0_flag_kv = (kv_idx >= n)
196
+
197
+ # Compute block indices
198
+ block_q = torch.where(x0_flag_q == 1,
199
+ (q_idx - n) // block_size,
200
+ q_idx // block_size)
201
+ block_kv = torch.where(x0_flag_kv == 1,
202
+ (kv_idx - n) // block_size,
203
+ kv_idx // block_size)
204
+
205
+ # **1. Block Diagonal Mask (M_BD) **
206
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
207
+
208
+ # **2. Offset Block-Causal Mask (M_OBC) **
209
+ offset_block_causal = (
210
+ (block_q > block_kv)
211
+ & (x0_flag_kv == 1)
212
+ & (x0_flag_q == 0)
213
+ )
214
+
215
+ # **3. Fully Causal Mask (M_BC) **
216
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
217
+
218
+ # **4. Combine Masks **
219
+ return block_diagonal | offset_block_causal | fully_causal
220
+
221
+ if mode == 'bidirectional':
222
+ attn_mask = bidirectional_mask
223
+ elif mode == 'autoregressive':
224
+ attn_mask = autoregressive_mask
225
+ elif mode == 'block_diff':
226
+ assert block_size is not None
227
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
228
+ elif mode == 'sbd_block_diff':
229
+ assert block_size is not None
230
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
231
+ else:
232
+ raise ValueError(f"Unknown attention mode: {mode}")
233
+
234
+ if q_len is not None:
235
+ Q_LEN = q_len
236
+ else:
237
+ if mode in ['block_diff', 'sbd_block_diff']:
238
+ Q_LEN = self.max_seq_length * 2
239
+ else:
240
+ Q_LEN = self.max_seq_length
241
+
242
+ block_mask = create_block_mask(
243
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
244
+ )
245
+
246
+ return block_mask
247
+
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
253
+ attention_mask: Optional[torch.Tensor],
254
+ past_key_values: Optional[Cache] = None,
255
+ cache_position: Optional[torch.LongTensor] = None,
256
+ is_training: bool = True,
257
+ **kwargs: Unpack[FlashAttentionKwargs],
258
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
259
+ bsz, q_len, _ = hidden_states.size()
260
+ input_shape = hidden_states.shape[:-1]
261
+ hidden_shape = (*input_shape, -1, self.head_dim)
262
+
263
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
264
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
265
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
266
+
267
+ cos, sin = position_embeddings
268
+
269
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
270
+ # Split query and key states in half along sequence length dimension
271
+ q1, q2 = query_states.chunk(2, dim=2)
272
+ k1, k2 = key_states.chunk(2, dim=2)
273
+
274
+ # Apply RoPE independently to each half
275
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
276
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
277
+
278
+ # Recombine the halves
279
+ query_states = torch.cat([q1, q2], dim=2)
280
+ key_states = torch.cat([k1, k2], dim=2)
281
+ else:
282
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
283
+
284
+ query_states = query_states * _get_llama_4_attn_scale(
285
+ cache_position,
286
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
287
+ self.config.rope_parameters.get("original_max_position_embeddings"),
288
+ ).to(query_states.dtype)
289
+
290
+ if past_key_values is not None:
291
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
292
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
293
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
294
+
295
+ self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
296
+ if self_spec_inference_mode is not None:
297
+ if self_spec_inference_mode == "quadratic":
298
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
299
+ if block_length is None:
300
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
301
+ if past_key_values is not None:
302
+ seq_len = key_states.shape[2]
303
+ draft_len = block_length * (block_length + 1)
304
+
305
+ clean_keys = key_states[:, :, :-draft_len]
306
+ draft_keys = key_states[:, :, -draft_len:]
307
+ clean_values = value_states[:, :, :-draft_len]
308
+ draft_values = value_states[:, :, -draft_len:]
309
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
310
+ value_states = torch.cat([draft_values, clean_values], dim=2)
311
+
312
+ block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
313
+ block_length=block_length
314
+ )
315
+ block_mask.seq_lengths = (draft_len, seq_len)
316
+ else:
317
+ seq_len = query_states.shape[2]
318
+ draft_len = block_length * (block_length + 1)
319
+ clean_len = seq_len - draft_len
320
+
321
+ def _causal_mask(b, h, q_idx, kv_idx):
322
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
323
+
324
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
325
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
326
+ first_clean = torch.logical_and(
327
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
328
+ )
329
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
330
+ return torch.logical_or(full_clean, first_clean)
331
+
332
+ def _draft_mask(b, h, q_idx, kv_idx):
333
+ block_q = (q_idx - clean_len) // (block_length + 1)
334
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
335
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
336
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
337
+ same_block_except_first = torch.logical_and(
338
+ same_block,
339
+ (q_idx - clean_len) % (block_length + 1) != 0,
340
+ )
341
+ return torch.logical_and(block_q == block_kv, same_block_except_first)
342
+
343
+ mask = or_masks(_causal_mask, _draft2clean_mask)
344
+ mask = or_masks(mask, _draft_mask)
345
+
346
+ block_mask = create_block_mask(
347
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
348
+ )
349
+
350
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
351
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
352
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
353
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
354
+ attn_output = self.o_proj(attn_output)
355
+ return attn_output, None
356
+
357
+ elif self_spec_inference_mode == "default":
358
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
359
+ if block_length is None:
360
+ raise ValueError("SBD default decoding requires block_length in config.")
361
+ seq_len = query_states.shape[2]
362
+ prefix_len = seq_len - block_length
363
+
364
+ def _clean_q_mask(b, h, q_idx, kv_idx):
365
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
366
+
367
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
368
+ return q_idx >= prefix_len
369
+
370
+ block_mask = create_block_mask(
371
+ or_masks(_clean_q_mask, _noisy_q_mask),
372
+ B=None,
373
+ H=None,
374
+ Q_LEN=seq_len,
375
+ KV_LEN=seq_len,
376
+ )
377
+
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
381
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
382
+ attn_output = self.o_proj(attn_output)
383
+ return attn_output, None
384
+
385
+ else:
386
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
387
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
388
+
389
+ if self.mode == 'bidirectional':
390
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
391
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
392
+ else:
393
+ block_mask = self.bidirectional_mask
394
+
395
+ elif self.mode == 'autoregressive':
396
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
397
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
398
+ else:
399
+ block_mask = self.autoregressive_mask
400
+
401
+ elif self.mode == 'block_diff':
402
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
403
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
404
+ else:
405
+ block_mask = self.block_diff_mask
406
+ elif self.mode == 'sbd_block_diff':
407
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
408
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
409
+ else:
410
+ block_mask = self.sbd_block_diff_mask
411
+ else:
412
+ raise ValueError(f"Unknown attention mode: {self.mode}")
413
+
414
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
415
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
416
+
417
+ attn_output = self.o_proj(attn_output)
418
+
419
+ return attn_output, None
420
+
421
+
422
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
423
+ """Return a Bool mask of length len(log_w) with exactly k True."""
424
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
425
+ topk = torch.topk(log_w + g, k).indices
426
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
427
+ mask[topk] = True
428
+ return mask
429
+
430
+
431
+ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
432
+ """
433
+ A single model with:
434
+ - a bidirectional encoder + diffusion‐LM head over A
435
+ - a causal decoder + LM head over B, conditioned on F_A
436
+ """
437
+
438
+ def __init__(self, config: MinistralDLMConfig):
439
+ super().__init__(config)
440
+
441
+ self.mask_token_id = config.mask_token_id
442
+
443
+ diffusion_config = copy.deepcopy(config)
444
+ diffusion_config.diffusion_lm = True
445
+
446
+ use_flex = getattr(config, 'enable_self_spec', False)
447
+
448
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
449
+ diffusion_config.attn_class = MinistralFlexAttention
450
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
451
+ diffusion_config.attn_class = MinistralFlexAttention if use_flex else Ministral3Attention
452
+ if config.dlm_paradigm == 'autoregressive':
453
+ diffusion_config.diffusion_lm = False
454
+ else:
455
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
456
+
457
+ self.encoder = Ministral3Model(diffusion_config)
458
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
459
+ self.vocab_size = config.vocab_size
460
+
461
+ self.current_iter_ratio = None
462
+
463
+ self.post_init()
464
+
465
+
466
+ def get_input_embeddings(self):
467
+ return self.encoder.embed_tokens
468
+
469
+ def set_input_embeddings(self, value):
470
+ self.encoder.embed_tokens = value
471
+
472
+ def get_output_embeddings(self):
473
+ return self.diffusion_head
474
+
475
+ def set_output_embeddings(self, new_embeddings):
476
+ self.diffusion_head = new_embeddings
477
+
478
+
479
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
480
+ b, l = input_ids.shape
481
+ device = input_ids.device
482
+
483
+ if self.config.dp_varying_mask_ratio:
484
+ # Enable different random seeds for each DP rank during sampling
485
+ import torch.distributed as dist
486
+ dp_rank = 0
487
+ if dist.is_initialized():
488
+ try:
489
+ dp_rank = dist.get_rank()
490
+ except Exception:
491
+ dp_rank = 0
492
+ # Use a local generator to avoid affecting global RNG state
493
+ generator = torch.Generator(device=device)
494
+ generator.manual_seed(torch.seed() + dp_rank)
495
+ else:
496
+ generator = None
497
+
498
+ if self.config.adaptive_mask_rate:
499
+ assert block_size is not None
500
+
501
+ # --- simple linear window mapping ---
502
+ bs_min = getattr(self.config, "t_bs_min", 16)
503
+ bs_max = getattr(self.config, "t_bs_max", 128)
504
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
505
+
506
+ # fraction in [0,1] (unclamped first)
507
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
508
+ # upper bound decreases linearly from 1.0 -> 0.5
509
+ u_max = 1.0 - w * frac
510
+ # clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
511
+ u_max = max(0.6, min(1.0, u_max))
512
+ u_min = u_max - w # ensures width = w
513
+
514
+ # sample t ~ Uniform(u_min, u_max)
515
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
516
+ else:
517
+ t = torch.rand(b, device=device, generator=generator)
518
+
519
+ p_mask = (1 - eps) * t + eps # shape: (b,)
520
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
521
+
522
+ masked_indices = torch.rand((b, l), device=device) < p_mask
523
+
524
+ if loss_mask is not None:
525
+ masked_indices[loss_mask == 0] = 0
526
+
527
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
528
+
529
+ return noisy_batch, masked_indices, p_mask
530
+
531
+
532
+ def forward_process_exp(
533
+ self,
534
+ input_ids: torch.Tensor,
535
+ eps: float = 1e-3,
536
+ block_size: int | None = None,
537
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
538
+ loss_mask: Optional[torch.Tensor] = None,
539
+ ):
540
+ """
541
+ Two-stage corruption with optional per-block sampling.
542
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
543
+ • Stage 2: sample exactly k positions with weights
544
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
545
+ uniform when m→1).
546
+ If `block_size` is given, the procedure is run *independently*
547
+ inside each contiguous block of that length (last block may be shorter).
548
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
549
+ Args
550
+ ----
551
+ input_ids : (B, L) LongTensor
552
+ eps : minimum corruption ratio
553
+ block_size: if not None, operate block-wise with per-block m sampling
554
+ half_life_ratio : controls steepness when m→0
555
+ """
556
+ B, L = input_ids.shape
557
+ device = input_ids.device
558
+ dtype = torch.float32
559
+
560
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
561
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
562
+
563
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
564
+ for b in range(B):
565
+ if block_size is None:
566
+ # ---------- Per-batch sampling (original behavior) ----------
567
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
568
+ k_tot = int(round(m * L))
569
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
570
+
571
+ # Fill p_mask for this batch
572
+ p_mask[b, :] = m
573
+
574
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
575
+
576
+ # ------- single pool over the whole sentence -------------
577
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
578
+
579
+ pos = torch.arange(L, device=device, dtype=dtype)
580
+ log_w = (lam_base * slope * pos).clone()
581
+
582
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
583
+
584
+ else:
585
+ # ---------- Per-block sampling ----------
586
+ num_blocks = math.ceil(L / block_size)
587
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
588
+
589
+ for blk in range(num_blocks):
590
+ start = blk * block_size
591
+ end = min((blk + 1) * block_size, L)
592
+ blk_len = end - start
593
+
594
+ # Sample m per block
595
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
596
+
597
+ # Fill p_mask for this block
598
+ p_mask[b, start:end] = m_blk
599
+
600
+ # per-block budget
601
+ k_blk = int(round(m_blk * blk_len))
602
+ k_blk = max(0, min(k_blk, blk_len))
603
+ if k_blk == 0:
604
+ continue
605
+
606
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
607
+
608
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
609
+ log_w = lam_base * slope * pos
610
+
611
+ blk_mask = gumbel_topk(log_w, k_blk)
612
+ masked_indices[b, start:end] = blk_mask
613
+
614
+ if loss_mask is not None:
615
+ masked_indices[loss_mask == 0] = 0
616
+
617
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
618
+ return noisy_batch, masked_indices, p_mask
619
+
620
+
621
+ def forward(
622
+ self,
623
+ input_ids: torch.LongTensor,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ position_ids: Optional[torch.LongTensor] = None,
626
+ labels: Optional[torch.LongTensor] = None,
627
+ split_len: Optional[int] = None,
628
+ past_key_values: Optional[Cache] = None,
629
+ block_size: Optional[int] = None,
630
+ block_diff_ppl: bool = False,
631
+ eps: float = 1e-3,
632
+ is_teacher: bool = False,
633
+ masked_indices: Optional[torch.Tensor] = None,
634
+ p_mask: Optional[torch.Tensor] = None,
635
+ teacher_logits: Optional[torch.Tensor] = None,
636
+ masked_indices_teacher: Optional[torch.Tensor] = None,
637
+ loss_mask: Optional[torch.Tensor] = None,
638
+ ce_loss_weight: float = 1.0,
639
+ output_last_hidden_states_only: bool = False,
640
+ skip_loss: bool = False,
641
+ **kwargs,
642
+ ) -> CausalLMOutputWithPast:
643
+
644
+ batch_size, seq_len = input_ids.shape
645
+
646
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
647
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
648
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
649
+ input_ids = input_ids[:, :random_length]
650
+ labels = labels[:, :random_length]
651
+
652
+ if attention_mask is not None:
653
+ attention_mask = attention_mask[:, :random_length]
654
+ if position_ids is not None:
655
+ position_ids = position_ids[:, :random_length]
656
+ if loss_mask is not None:
657
+ loss_mask = loss_mask[:, :random_length]
658
+
659
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
660
+ if labels is not None and block_size is None:
661
+ if torch.rand(1) < self.config.random_length_prob:
662
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
663
+ else:
664
+ block_size = self.config.block_size
665
+
666
+ else:
667
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
668
+
669
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
670
+ if masked_indices is not None:
671
+ # assert p_mask is not None
672
+
673
+ if loss_mask is not None:
674
+ masked_indices[loss_mask == 0] = 0
675
+
676
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
677
+
678
+ else:
679
+ if self.config.tok_mask_half_life_ratio is not None:
680
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
681
+ else:
682
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
683
+
684
+ else:
685
+ noisy_inputs = input_ids
686
+ masked_indices = None
687
+ p_mask = None
688
+
689
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
690
+ for layer in self.encoder.layers:
691
+ if hasattr(layer.self_attn, 'set_attention_mode'):
692
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
693
+
694
+ input_ids_len = noisy_inputs.shape[1]
695
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
696
+ if position_ids is None:
697
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
698
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
699
+
700
+ if block_diff_ppl:
701
+ if position_ids is None:
702
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
703
+
704
+ enc_out = self.encoder(
705
+ past_key_values=past_key_values,
706
+ input_ids=noisy_inputs,
707
+ attention_mask=attention_mask,
708
+ position_ids=position_ids,
709
+ is_training=(labels is not None) or (block_diff_ppl),
710
+ **kwargs,
711
+ )
712
+
713
+ if output_last_hidden_states_only:
714
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
715
+
716
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
717
+ causal_logits = None
718
+
719
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
720
+ if self.config.dlm_paradigm == 'sbd_block_diff':
721
+ causal_logits = logits[:, input_ids_len:]
722
+ else:
723
+ causal_logits = None
724
+
725
+ logits = logits[:, :input_ids_len]
726
+
727
+ loss = None
728
+ if labels is not None and not skip_loss:
729
+ if self.config.dlm_paradigm == 'autoregressive':
730
+ shift_logits = logits[..., :-1, :].contiguous()
731
+ shift_labels = labels[..., 1:].contiguous()
732
+
733
+ if loss_mask is None:
734
+ loss_fct = CrossEntropyLoss()
735
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
736
+ shift_labels = shift_labels.view(-1)
737
+ loss = loss_fct(shift_logits, shift_labels)
738
+
739
+ else:
740
+ loss_mask = loss_mask[..., 1:].contiguous()
741
+
742
+ loss_fct = CrossEntropyLoss(reduction='none')
743
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
744
+ shift_labels = shift_labels.view(-1)
745
+ shift_labels = shift_labels.to(shift_logits.device)
746
+
747
+ token_losses = loss_fct(shift_logits, shift_labels)
748
+
749
+ flat_loss_mask = loss_mask.reshape(-1)
750
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
751
+
752
+ else:
753
+ # Handle DREAM vs LLADA style losses
754
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
755
+ logits = logits[..., :-1, :].contiguous()
756
+ labels = labels[..., 1:].contiguous()
757
+ masked_indices = masked_indices[:, 1:]
758
+ p_mask = p_mask[:, 1:]
759
+
760
+ if self.config.ada_perm_ratio_per_block is not None:
761
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
762
+ block_size = self.config.block_size
763
+ batch_size, seq_len = masked_indices.shape
764
+ num_blocks = seq_len // block_size
765
+
766
+ # Get the max logit (confidence) for each position
767
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
768
+
769
+ # Create a mask for tokens to include in loss
770
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
771
+
772
+ for blk in range(num_blocks):
773
+ start = blk * block_size
774
+ end = min((blk + 1) * block_size, seq_len)
775
+
776
+ # Get masked indices within this block
777
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
778
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
779
+
780
+ for b in range(batch_size):
781
+ # Get positions that are masked in this block for this batch
782
+ masked_positions = torch.where(block_masked[b])[0]
783
+ num_masked = len(masked_positions)
784
+
785
+ if num_masked > 0:
786
+ # Number of tokens to keep (top by confidence)
787
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
788
+
789
+ # Get confidence values for masked positions
790
+ masked_confidence = block_confidence[b, masked_positions]
791
+
792
+ # Get indices of top-k confident tokens
793
+ _, topk_indices = torch.topk(masked_confidence, k)
794
+ selected_positions = masked_positions[topk_indices]
795
+
796
+ # Mark these positions in the selected mask
797
+ selected_mask[b, start + selected_positions] = True
798
+
799
+ # Calculate loss only for selected positions
800
+ token_loss = torch.nn.functional.cross_entropy(
801
+ logits[selected_mask],
802
+ labels[selected_mask],
803
+ reduction='none'
804
+ ) / p_mask[selected_mask]
805
+
806
+ num_mask_tokens = selected_mask.sum()
807
+
808
+ else:
809
+ # Calculate token-wise cross entropy loss for masked positions in B
810
+ token_loss = torch.nn.functional.cross_entropy(
811
+ logits[masked_indices],
812
+ labels[masked_indices],
813
+ reduction='none'
814
+ ) / p_mask[masked_indices]
815
+
816
+ num_mask_tokens = masked_indices.sum()
817
+
818
+ if self.config.global_loss_avg:
819
+ loss = token_loss.sum()
820
+ else:
821
+ loss = token_loss.sum() / num_mask_tokens
822
+
823
+ if self.config.ada_dlm_loss_ratio is not None:
824
+ assert self.current_iter_ratio is not None
825
+ assert self.config.dlm_loss_weight is not None
826
+
827
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
828
+ loss = dlm_loss_weight * loss
829
+
830
+ elif self.config.dlm_loss_weight is not None:
831
+ loss = self.config.dlm_loss_weight * loss
832
+
833
+ if self.config.dlm_paradigm == 'sbd_block_diff':
834
+ causal_logits = causal_logits[..., :-1, :].contiguous()
835
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
836
+
837
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
838
+ causal_labels = labels.view(-1)
839
+ else:
840
+ causal_labels = labels[..., 1:].contiguous().view(-1)
841
+
842
+ if self.config.global_loss_avg:
843
+ loss_fct = CrossEntropyLoss(reduction='sum')
844
+ ar_loss = loss_fct(causal_logits, causal_labels)
845
+
846
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
847
+ self.loss_ar = ar_loss.detach().item() / seq_len
848
+
849
+ loss = loss + self.config.ar_loss_weight * ar_loss
850
+ else:
851
+ loss_fct = CrossEntropyLoss()
852
+ ar_loss = loss_fct(causal_logits, causal_labels)
853
+
854
+ self.loss_diffusion = loss.detach().item()
855
+ self.loss_ar = ar_loss.detach().item()
856
+
857
+ loss = loss + self.config.ar_loss_weight * ar_loss
858
+
859
+ if self.config.global_loss_avg:
860
+ if self.config.dlm_paradigm == 'sbd_block_diff':
861
+ loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
862
+ else:
863
+ loss = (loss, num_mask_tokens)
864
+
865
+ return MinistralDiffOutputWithPast(
866
+ loss=loss if not is_teacher else logits,
867
+ logits=logits,
868
+ causal_logits=causal_logits,
869
+ past_key_values=enc_out.past_key_values,
870
+ hidden_states=None,
871
+ attentions=None,
872
+ )
873
+
874
+
875
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None, max_thinking_tokens=None, end_think_token_id=None):
876
+ if eos_token_id is None:
877
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
878
+
879
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
880
+ model=self,
881
+ prompt=prompt_ids,
882
+ gen_length=max_new_tokens,
883
+ steps=steps,
884
+ block_length=block_length,
885
+ remasking="low_confidence",
886
+ temperature=temperature,
887
+ mask_id=self.mask_token_id,
888
+ threshold=threshold,
889
+ shift_logits=shift_logits,
890
+ neg_entropy=False,
891
+ causal_context=causal_context,
892
+ eos_token_id=eos_token_id,
893
+ max_thinking_tokens=max_thinking_tokens,
894
+ end_think_token_id=end_think_token_id,
895
+ )
896
+
897
+ return out_ids, nfe
898
+
899
+
900
+ @torch.no_grad()
901
+ def sbd_inference_diffusion_quadratic(
902
+ self,
903
+ clean_input_ids: Optional[torch.Tensor],
904
+ draft_input_ids: torch.Tensor,
905
+ block_length: int,
906
+ draft_only: bool = False,
907
+ past_key_values: Optional[Cache] = None,
908
+ use_cache: bool = False,
909
+ ):
910
+ enc_config = self.encoder.config
911
+ enc_config.use_sbd_objective = True
912
+ enc_config.block_length = block_length
913
+
914
+ if draft_only:
915
+ assert clean_input_ids is not None
916
+
917
+ if use_cache and past_key_values is None:
918
+ past_key_values = DynamicCache()
919
+
920
+ enc_config.self_spec_inference_mode = "default"
921
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
922
+ outputs = self.encoder(
923
+ input_ids=input_ids,
924
+ position_ids=None,
925
+ past_key_values=past_key_values,
926
+ use_cache=use_cache,
927
+ is_training=False,
928
+ )
929
+
930
+ hidden_states = outputs.last_hidden_state
931
+ logits = self.diffusion_head(hidden_states)
932
+
933
+ past_key_values = getattr(outputs, "past_key_values", None)
934
+ if use_cache and past_key_values is not None:
935
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
936
+
937
+ return logits, past_key_values
938
+ else:
939
+ enc_config.self_spec_inference_mode = "quadratic"
940
+
941
+ draft_len = block_length * (block_length + 1)
942
+ draft_input_ids = torch.cat(
943
+ [
944
+ draft_input_ids.view(-1, block_length, 1),
945
+ torch.full(
946
+ (draft_input_ids.shape[0], block_length, block_length),
947
+ fill_value=self.config.mask_token_id,
948
+ device=draft_input_ids.device,
949
+ ),
950
+ ],
951
+ dim=-1,
952
+ ).view(-1, draft_len)
953
+
954
+ if use_cache:
955
+ assert past_key_values is not None, (
956
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
957
+ )
958
+ assert clean_input_ids is None, (
959
+ "Clean input ids should already be in cache, thus none should be provided."
960
+ )
961
+ clean_len = past_key_values.get_seq_length()
962
+ input_ids = draft_input_ids
963
+ else:
964
+ clean_len = clean_input_ids.shape[1]
965
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
966
+
967
+ per_block_position_ids = torch.arange(
968
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
969
+ )[None,].repeat(block_length, 1)
970
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
971
+
972
+ if use_cache:
973
+ position_ids = per_block_position_ids.view(-1)[None,]
974
+ else:
975
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
976
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
977
+
978
+ outputs = self.encoder(
979
+ input_ids=input_ids,
980
+ position_ids=position_ids,
981
+ past_key_values=past_key_values,
982
+ use_cache=use_cache,
983
+ is_training=False,
984
+ )
985
+
986
+ hidden_states = outputs.last_hidden_state
987
+ logits = self.diffusion_head(hidden_states)
988
+ past_key_values = getattr(outputs, "past_key_values", None)
989
+
990
+ if use_cache and past_key_values is not None:
991
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
992
+
993
+ return logits, past_key_values
994
+
995
+
996
+ @torch.no_grad()
997
+ def ar_generate(
998
+ self,
999
+ prompt_ids: torch.Tensor,
1000
+ max_new_tokens: int = 128,
1001
+ temperature: float = 0.0,
1002
+ eos_token_id: Optional[int] = None,
1003
+ max_thinking_tokens: Optional[int] = None,
1004
+ end_think_token_id: Optional[int] = None,
1005
+ ) -> tuple:
1006
+ """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).
1007
+
1008
+ Bypasses MinistralDiffEncoderModel.forward() to avoid diffusion-specific
1009
+ code paths. Calls self.encoder (Ministral3Model) with explicit cache_position,
1010
+ position_ids, and use_cache so the KV cache and causal masking behave
1011
+ identically to MistralForCausalLM / vLLM.
1012
+
1013
+ Returns:
1014
+ (output_ids, nfe) where output_ids includes the prompt.
1015
+ """
1016
+ for layer in self.encoder.layers:
1017
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1018
+ layer.self_attn.diffusion_lm = False
1019
+
1020
+ if eos_token_id is None:
1021
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
1022
+
1023
+ device = prompt_ids.device
1024
+ batch_size, prompt_len = prompt_ids.shape
1025
+
1026
+ past_key_values = DynamicCache()
1027
+ cache_position = torch.arange(prompt_len, device=device)
1028
+ position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
1029
+
1030
+ enc_out = self.encoder(
1031
+ input_ids=prompt_ids,
1032
+ position_ids=position_ids,
1033
+ past_key_values=past_key_values,
1034
+ use_cache=True,
1035
+ cache_position=cache_position,
1036
+ )
1037
+ past_key_values = enc_out.past_key_values
1038
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1039
+
1040
+ generated_tokens = []
1041
+ nfe = 0
1042
+
1043
+ for step in range(max_new_tokens):
1044
+ nfe += 1
1045
+
1046
+ if temperature > 0:
1047
+ probs = torch.softmax(next_logit / temperature, dim=-1)
1048
+ next_token = torch.multinomial(probs, num_samples=1)
1049
+ else:
1050
+ next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
1051
+
1052
+ # ---- thinking budget enforcement ----
1053
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1054
+ if step >= max_thinking_tokens:
1055
+ if generated_tokens:
1056
+ gen_tensor = torch.cat(generated_tokens, dim=1)
1057
+ has_end_think = (gen_tensor == end_think_token_id).any(dim=1)
1058
+ else:
1059
+ has_end_think = torch.zeros(batch_size, dtype=torch.bool, device=device)
1060
+ for b in range(batch_size):
1061
+ if not has_end_think[b]:
1062
+ next_token[b] = end_think_token_id
1063
+
1064
+ generated_tokens.append(next_token)
1065
+
1066
+ if eos_token_id is not None and (next_token == eos_token_id).all():
1067
+ break
1068
+
1069
+ if step < max_new_tokens - 1:
1070
+ cur_pos = prompt_len + step
1071
+ step_cache_pos = torch.tensor([cur_pos], device=device)
1072
+ step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)
1073
+
1074
+ enc_out = self.encoder(
1075
+ input_ids=next_token,
1076
+ position_ids=step_pos_ids,
1077
+ past_key_values=past_key_values,
1078
+ use_cache=True,
1079
+ cache_position=step_cache_pos,
1080
+ )
1081
+ past_key_values = enc_out.past_key_values
1082
+ next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1083
+
1084
+ all_generated = torch.cat(generated_tokens, dim=1)
1085
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1086
+ return output_ids, nfe
1087
+
1088
+
1089
+ @torch.no_grad()
1090
+ def self_spec_generate(
1091
+ self,
1092
+ prompt_ids: torch.Tensor,
1093
+ max_new_tokens: int = 128,
1094
+ steps: int = 128,
1095
+ block_length: int = 16,
1096
+ ar_mix_weight: Optional[float] = None,
1097
+ temperature: float = 0.0,
1098
+ mask_token_id: Optional[int] = None,
1099
+ eos_token_id: Optional[int] = None,
1100
+ max_thinking_tokens: Optional[int] = None,
1101
+ end_think_token_id: Optional[int] = None,
1102
+ ):
1103
+ self.config.use_sbd_objective = True
1104
+ self.config.dlm_paradigm = "sbd"
1105
+
1106
+ if prompt_ids.shape[0] != 1:
1107
+ raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
1108
+
1109
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1110
+ if eos_token_id is None:
1111
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1112
+
1113
+ x = torch.full(
1114
+ (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
1115
+ token_mask_id,
1116
+ dtype=torch.long,
1117
+ device=prompt_ids.device,
1118
+ )
1119
+ x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
1120
+
1121
+ if max_new_tokens % block_length != 0:
1122
+ raise ValueError("max_new_tokens must be divisible by block_length")
1123
+ num_blocks = max_new_tokens // block_length
1124
+ if steps % num_blocks != 0:
1125
+ raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
1126
+
1127
+ prompt_len = prompt_ids.shape[1]
1128
+ nfe = 0
1129
+ nfe += 1
1130
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1131
+ clean_input_ids=x[:, :prompt_len],
1132
+ draft_input_ids=x[:, prompt_len : prompt_len + block_length],
1133
+ block_length=block_length,
1134
+ draft_only=True,
1135
+ use_cache=True,
1136
+ )
1137
+
1138
+ logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
1139
+ logits_proposal[:, 1] = logits_proposal[:, 0]
1140
+ logits_proposal = logits_proposal[:, 1:]
1141
+ x0_proposal = torch.argmax(logits_proposal, dim=-1)
1142
+ x[:, prompt_len : prompt_len + block_length] = x0_proposal
1143
+
1144
+ total_accept_token = 0
1145
+ while True:
1146
+ nfe += 1
1147
+ block_start = prompt_len + total_accept_token
1148
+ block_end = block_start + block_length
1149
+ draft_input_ids = x[:, block_start:block_end]
1150
+
1151
+ logits, past_key_values = self.sbd_inference_diffusion_quadratic(
1152
+ clean_input_ids=None,
1153
+ draft_input_ids=draft_input_ids,
1154
+ block_length=block_length,
1155
+ draft_only=False,
1156
+ past_key_values=past_key_values,
1157
+ use_cache=True,
1158
+ )
1159
+
1160
+ useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
1161
+ if ar_mix_weight is None:
1162
+ useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
1163
+ else:
1164
+ if not (0.0 <= ar_mix_weight <= 1.0):
1165
+ raise ValueError("ar_mix_weight must be between 0 and 1")
1166
+ mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
1167
+ useful_token_logits[:, :, 0] = mix_logits
1168
+ useful_token_logits[:, :, 1] = mix_logits
1169
+
1170
+ if temperature > 0:
1171
+ useful_token_logits = useful_token_logits / temperature
1172
+
1173
+ useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
1174
+ new_draft_input_ids = useful_token_pred[:, 0, 1:]
1175
+ accept_cnt = 1
1176
+
1177
+ while accept_cnt < block_length:
1178
+ if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
1179
+ break
1180
+ new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
1181
+ accept_cnt += 1
1182
+
1183
+ x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
1184
+
1185
+ # EoS early stopping: all accepted tokens are finalized left-to-right,
1186
+ # so if any is EoS we can truncate and return immediately.
1187
+ if eos_token_id is not None:
1188
+ accepted = x[0, block_start : block_start + accept_cnt]
1189
+ eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
1190
+ if len(eos_positions) > 0:
1191
+ first_eos_rel = eos_positions[0].item()
1192
+ total_accept_token += first_eos_rel + 1
1193
+ output_end = prompt_len + total_accept_token
1194
+ return x[:, :output_end], nfe
1195
+
1196
+ x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
1197
+ past_key_values.crop(block_start + accept_cnt)
1198
+
1199
+ # ---- thinking budget enforcement ----
1200
+ # Insert end_think as the first token of the next draft block,
1201
+ # shifting all subsequent tokens right by 1 (discarding the last).
1202
+ # The first draft token is always accepted unconditionally, so
1203
+ # end_think is guaranteed to be finalized in the next iteration
1204
+ # without needing to re-encode or touch the KV cache.
1205
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1206
+ tokens_so_far = total_accept_token + accept_cnt
1207
+ if tokens_so_far > max_thinking_tokens:
1208
+ gen_so_far = x[0, prompt_len : prompt_len + tokens_so_far]
1209
+ has_end_think = (gen_so_far == end_think_token_id).any()
1210
+ if not has_end_think:
1211
+ insert_pos = block_start + accept_cnt
1212
+ x[0, insert_pos + 1:] = x[0, insert_pos:-1].clone()
1213
+ x[0, insert_pos] = end_think_token_id
1214
+
1215
+ total_accept_token += accept_cnt
1216
+
1217
+ if total_accept_token >= max_new_tokens:
1218
+ break
1219
+
1220
+ return x[:, : -(block_length * 2)], nfe
1221
+
1222
+
1223
+ @torch.no_grad()
1224
+ def linear_spec_generate(
1225
+ self,
1226
+ prompt_ids: torch.Tensor,
1227
+ max_new_tokens: int = 128,
1228
+ block_length: int = 32,
1229
+ temperature: float = 0.0,
1230
+ mask_token_id: Optional[int] = None,
1231
+ eos_token_id: Optional[int] = None,
1232
+ max_thinking_tokens: Optional[int] = None,
1233
+ end_think_token_id: Optional[int] = None,
1234
+ threshold: float = 0.0,
1235
+ ):
1236
+ """Linear speculative decoding: diffusion draft + AR verification.
1237
+
1238
+ Each step:
1239
+ 1. Draft: forward [last_accepted, mask, ...] with bidirectional attention
1240
+ (diffusion_lm=True, use_cache=False). Shift AR logits to get
1241
+ per-position predictions; apply confidence filtering.
1242
+ 2. Verify: forward the drafted block with causal attention
1243
+ (diffusion_lm=False, use_cache=True, use_causal_mask=True).
1244
+ Accept consecutive AR-matching tokens plus one bonus token.
1245
+
1246
+ Args:
1247
+ prompt_ids: Input token IDs of shape (1, prompt_len).
1248
+ max_new_tokens: Maximum number of tokens to generate.
1249
+ block_length: Number of tokens per draft/verify block.
1250
+ temperature: Sampling temperature (0 = greedy).
1251
+ mask_token_id: Override for config.mask_token_id.
1252
+ eos_token_id: Override for config.eos_token_id.
1253
+ max_thinking_tokens: Budget for thinking tokens before forcing end_think.
1254
+ end_think_token_id: Token ID inserted when thinking budget is exceeded.
1255
+ threshold: Confidence threshold for accepting draft predictions.
1256
+
1257
+ Returns:
1258
+ (output_ids, nfe): output_ids includes the prompt; nfe is the number
1259
+ of forward evaluations (matching self_spec_generate interface).
1260
+ """
1261
+ if prompt_ids.shape[0] != 1:
1262
+ raise ValueError("Linear speculative decoding requires batch_size == 1")
1263
+
1264
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1265
+ if eos_token_id is None:
1266
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1267
+
1268
+ device = prompt_ids.device
1269
+ prompt_len = prompt_ids.shape[1]
1270
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1271
+
1272
+ def _set_diffusion_lm(val: bool):
1273
+ for layer in self.encoder.layers:
1274
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1275
+ layer.self_attn.diffusion_lm = val
1276
+
1277
+ # ===== Prefill (causal) =====
1278
+ _set_diffusion_lm(False)
1279
+
1280
+ enc_out = self.encoder(
1281
+ input_ids=prompt_ids,
1282
+ past_key_values=DynamicCache(),
1283
+ use_cache=True,
1284
+ use_causal_mask=True,
1285
+ )
1286
+ past_key_values = enc_out.past_key_values
1287
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1288
+ nfe = 1
1289
+
1290
+ if temperature > 0:
1291
+ probs = torch.softmax(last_logit / temperature, dim=-1)
1292
+ next_token = torch.multinomial(probs, num_samples=1)
1293
+ else:
1294
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1295
+
1296
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1297
+ output_ids = torch.cat([prompt_ids, next_token], dim=1)
1298
+ return output_ids, nfe
1299
+
1300
+ generated = [next_token]
1301
+ total_gen = 1
1302
+
1303
+ # ===== Main loop =====
1304
+ while total_gen < max_new_tokens:
1305
+ cache_len = past_key_values.get_seq_length()
1306
+
1307
+ block = torch.full(
1308
+ (1, block_length), token_mask_id, dtype=torch.long, device=device
1309
+ )
1310
+ block[0, 0] = next_token.item()
1311
+
1312
+ # -------- Draft (bidirectional, don't update cache) --------
1313
+ _set_diffusion_lm(True)
1314
+ while True:
1315
+ is_mask = block == token_mask_id
1316
+ if not is_mask.any():
1317
+ break
1318
+
1319
+ enc_out = self.encoder(
1320
+ input_ids=block,
1321
+ past_key_values=past_key_values,
1322
+ use_cache=False,
1323
+ )
1324
+ nfe += 1
1325
+
1326
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1327
+ if dream_style:
1328
+ # DREAM: logit[i] predicts position i+1 → shift to self-prediction
1329
+ draft_logits = torch.cat(
1330
+ [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1
1331
+ )
1332
+ # LLaDA: logit[i] already predicts position i → no shift needed
1333
+
1334
+ if temperature > 0:
1335
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1336
+ draft_tokens = torch.multinomial(
1337
+ draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
1338
+ ).view(1, block_length)
1339
+ else:
1340
+ draft_tokens = draft_logits.argmax(dim=-1)
1341
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1342
+
1343
+ if threshold > 0:
1344
+ draft_conf = torch.gather(
1345
+ draft_probs, -1, draft_tokens.unsqueeze(-1)
1346
+ ).squeeze(-1)
1347
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1348
+ unmask = draft_conf >= threshold
1349
+
1350
+ # Ensure each iteration makes progress even when every masked
1351
+ # position falls below the confidence threshold.
1352
+ if not unmask.any():
1353
+ best_idx = draft_conf.view(-1).argmax()
1354
+ unmask = torch.zeros_like(is_mask, dtype=torch.bool)
1355
+ unmask.view(-1)[best_idx] = True
1356
+
1357
+ block[unmask] = draft_tokens[unmask]
1358
+ else:
1359
+ block[is_mask] = draft_tokens[is_mask]
1360
+ break
1361
+
1362
+ # -------- Verify (causal, update cache) --------
1363
+ _set_diffusion_lm(False)
1364
+ enc_out = self.encoder(
1365
+ input_ids=block,
1366
+ past_key_values=past_key_values,
1367
+ use_cache=True,
1368
+ use_causal_mask=True,
1369
+ )
1370
+ past_key_values = enc_out.past_key_values
1371
+ nfe += 1
1372
+
1373
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1374
+ if temperature > 0:
1375
+ verify_probs = torch.softmax(verify_logits / temperature, dim=-1)
1376
+ ar_tokens = torch.multinomial(
1377
+ verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
1378
+ ).view(1, block_length)
1379
+ else:
1380
+ ar_tokens = verify_logits.argmax(dim=-1)
1381
+
1382
+ accepted = 0
1383
+ for i in range(block_length - 1):
1384
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1385
+ accepted += 1
1386
+ else:
1387
+ break
1388
+ accepted += 1 # bonus token from AR verification
1389
+
1390
+ accepted_toks = ar_tokens[:, :accepted]
1391
+ generated.append(accepted_toks)
1392
+ total_gen += accepted
1393
+
1394
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1395
+
1396
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1397
+
1398
+ # -------- EOS check --------
1399
+ if eos_token_id is not None:
1400
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1401
+ if len(eos_pos) > 0:
1402
+ first_eos = eos_pos[0].item()
1403
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1404
+ total_gen = total_gen - accepted + first_eos + 1
1405
+ break
1406
+
1407
+ # -------- Thinking budget enforcement --------
1408
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1409
+ if total_gen > max_thinking_tokens:
1410
+ all_gen = torch.cat(generated, dim=1)
1411
+ if not (all_gen == end_think_token_id).any():
1412
+ next_token = torch.tensor(
1413
+ [[end_think_token_id]], device=device
1414
+ )
1415
+
1416
+ if total_gen >= max_new_tokens:
1417
+ break
1418
+
1419
+ all_generated = torch.cat(generated, dim=1)
1420
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1421
+
1422
+ return output_ids, nfe
1423
+
1424
+
1425
+ @torch.no_grad()
1426
+ def linear_spec_generate_mp(
1427
+ self,
1428
+ prompt_ids: torch.Tensor,
1429
+ max_new_tokens: int = 512,
1430
+ block_length: int = 32,
1431
+ temperature: float = 0.0,
1432
+ mask_token_id: Optional[int] = None,
1433
+ eos_token_id: Optional[int] = None,
1434
+ max_paths: int = 16,
1435
+ uncertain_threshold: float = 0.7,
1436
+ top_k_candidates: int = 2,
1437
+ threshold: float = 0.0,
1438
+ max_thinking_tokens: Optional[int] = None,
1439
+ end_think_token_id: Optional[int] = None,
1440
+ ):
1441
+ """Linear speculative decoding with multi-path tree verification.
1442
+
1443
+ Self-contained method — no external file dependencies beyond the model itself.
1444
+
1445
+ Each iteration costs 2 NFE (1 draft + 1 verify):
1446
+ 1. Draft: single-step bidirectional diffusion fills a block of masks.
1447
+ 2. Verify: tree-structured AR verification with multiple candidate paths.
1448
+
1449
+ Multi-path verification identifies low-confidence draft positions and
1450
+ explores top-k alternative tokens. All candidate paths share a trie
1451
+ prefix and are verified in one forward pass via a 4D tree-ancestry
1452
+ attention mask (~40 tokens), picking the path with the longest
1453
+ accepted prefix.
1454
+
1455
+ Benchmark results (NeMo Skills prompt, enable_thinking=False):
1456
+ GSM8K bl=32: +17.1% UW-TPF vs vanilla (acc 93.9%)
1457
+ MBPP bl=64: +17.8% UW-TPF vs vanilla (pass@1 78.2%)
1458
+
1459
+ Args:
1460
+ prompt_ids: (1, prompt_len) input token IDs.
1461
+ max_new_tokens: Maximum tokens to generate.
1462
+ block_length: Draft block size. Use 32 for math, 64 for code.
1463
+ temperature: Sampling temperature (0.0 = greedy).
1464
+ eos_token_id: Stop token ID.
1465
+ max_paths: Tree verification budget. 16 = up to 4 uncertain
1466
+ positions x 2 candidates each.
1467
+ uncertain_threshold: Confidence below which a position is
1468
+ considered uncertain and expanded with alternatives.
1469
+ top_k_candidates: Number of alternative tokens to try at each
1470
+ uncertain position.
1471
+
1472
+ Returns:
1473
+ output_ids: (1, prompt_len + generated_len) full sequence.
1474
+ nfe: Total number of forward evaluations.
1475
+ """
1476
+ from itertools import product as _product
1477
+
1478
+ if prompt_ids.shape[0] != 1:
1479
+ raise ValueError("Requires batch_size == 1")
1480
+
1481
+ device = prompt_ids.device
1482
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1483
+ if eos_token_id is None:
1484
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1485
+
1486
+ def _set_dlm(val: bool):
1487
+ for layer in self.encoder.layers:
1488
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1489
+ layer.self_attn.diffusion_lm = val
1490
+
1491
+ def _crop_cache(kv, length):
1492
+ for li in range(len(kv)):
1493
+ kv.key_cache[li] = kv.key_cache[li][:, :, :length]
1494
+ kv.value_cache[li] = kv.value_cache[li][:, :, :length]
1495
+ kv._seen_tokens = length
1496
+
1497
+ # ----- tree verify helpers (inlined) -----
1498
+
1499
+ def _mp_verify(block, draft_probs, draft_conf, past_kv, cache_len):
1500
+ """Multi-path verify via batch-stacking (flash-attention compatible).
1501
+
1502
+ Unlike tree attention (4D mask), batch-stacking expands the KV cache
1503
+ batch dimension and runs all candidate paths as separate batch entries.
1504
+ This keeps flash attention + GQA enabled, avoiding OOM from the 4D
1505
+ mask path which disables both.
1506
+
1507
+ Returns (accepted_toks, n_accepted, past_kv, next_tok) or None.
1508
+ """
1509
+ bl = block.shape[1]
1510
+
1511
+ # Identify uncertain positions
1512
+ is_filled = block[0] != token_mask_id
1513
+ pos_conf = torch.zeros(bl, device=device)
1514
+ pos_conf[0] = float('inf')
1515
+ for p in range(1, bl):
1516
+ if is_filled[p]:
1517
+ c = draft_conf[0, p].item()
1518
+ pos_conf[p] = c if c != float('-inf') else float('inf')
1519
+ else:
1520
+ pos_conf[p] = float('-inf')
1521
+
1522
+ unc_mask = (pos_conf < uncertain_threshold) & (pos_conf > float('-inf'))
1523
+ unc_pos = unc_mask.nonzero(as_tuple=True)[0].tolist()
1524
+ if not unc_pos:
1525
+ return None
1526
+
1527
+ import math as _math
1528
+ max_unc = min(len(unc_pos), max(1, int(_math.log2(max_paths))))
1529
+ unc_pos = sorted(unc_pos)[:max_unc]
1530
+
1531
+ # Build candidate blocks
1532
+ topk_at = {}
1533
+ for p in unc_pos:
1534
+ _, ids = draft_probs[0, p].topk(top_k_candidates)
1535
+ topk_at[p] = ids.tolist()
1536
+
1537
+ combos = list(_product(*(topk_at[p] for p in sorted(topk_at))))[:max_paths]
1538
+ num_paths = len(combos)
1539
+ if num_paths <= 1:
1540
+ return None
1541
+
1542
+ candidate_blocks = block.expand(num_paths, -1).clone()
1543
+ pos_list = sorted(topk_at.keys())
1544
+ for pi, combo in enumerate(combos):
1545
+ for ci, p in enumerate(pos_list):
1546
+ candidate_blocks[pi, p] = combo[ci]
1547
+
1548
+ # Expand KV cache batch dimension (shared, no copy)
1549
+ for li in range(len(past_kv.key_cache)):
1550
+ past_kv.key_cache[li] = past_kv.key_cache[li].expand(num_paths, -1, -1, -1)
1551
+ past_kv.value_cache[li] = past_kv.value_cache[li].expand(num_paths, -1, -1, -1)
1552
+
1553
+ # Batched causal verify — uses flash attention + GQA
1554
+ _set_dlm(False)
1555
+ enc_out = self.encoder(
1556
+ input_ids=candidate_blocks,
1557
+ past_key_values=past_kv,
1558
+ use_cache=True,
1559
+ use_causal_mask=True,
1560
+ )
1561
+ past_kv = enc_out.past_key_values
1562
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1563
+
1564
+ if temperature > 0:
1565
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1566
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(num_paths, bl)
1567
+ else:
1568
+ ar_tokens = vlogits.argmax(dim=-1)
1569
+
1570
+ # Find best path (longest accepted prefix)
1571
+ best_acc, best_pidx = 0, 0
1572
+ for pi in range(num_paths):
1573
+ acc = 0
1574
+ for i in range(bl - 1):
1575
+ if ar_tokens[pi, i].item() == candidate_blocks[pi, i + 1].item():
1576
+ acc += 1
1577
+ else:
1578
+ break
1579
+ acc += 1
1580
+ if acc > best_acc:
1581
+ best_acc, best_pidx = acc, pi
1582
+
1583
+ accepted_toks = ar_tokens[best_pidx:best_pidx+1, :best_acc]
1584
+
1585
+ # Extract winning path's KV cache slice
1586
+ for li in range(len(past_kv.key_cache)):
1587
+ past_kv.key_cache[li] = past_kv.key_cache[li][best_pidx:best_pidx+1].contiguous()
1588
+ past_kv.value_cache[li] = past_kv.value_cache[li][best_pidx:best_pidx+1].contiguous()
1589
+ _crop_cache(past_kv, cache_len + best_acc)
1590
+
1591
+ return accepted_toks, best_acc, past_kv, accepted_toks[:, -1:]
1592
+
1593
+ # ── Prefill (causal) ──
1594
+ _set_dlm(False)
1595
+ enc_out = self.encoder(
1596
+ input_ids=prompt_ids, past_key_values=DynamicCache(),
1597
+ use_cache=True, use_causal_mask=True,
1598
+ )
1599
+ past_key_values = enc_out.past_key_values
1600
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1601
+ nfe = 1
1602
+
1603
+ if temperature > 0:
1604
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), 1)
1605
+ else:
1606
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1607
+
1608
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1609
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1610
+
1611
+ generated = [next_token]
1612
+ total_gen = 1
1613
+
1614
+ # ── Main draft-verify loop ──
1615
+ while total_gen < max_new_tokens:
1616
+ cache_len = past_key_values.get_seq_length()
1617
+
1618
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1619
+ block[0, 0] = next_token.item()
1620
+
1621
+ # Draft: single-step bidirectional diffusion (1 NFE)
1622
+ _set_dlm(True)
1623
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1624
+ nfe += 1
1625
+
1626
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1627
+ if temperature > 0:
1628
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1629
+ draft_tokens = torch.multinomial(
1630
+ draft_probs.view(-1, draft_probs.shape[-1]), 1
1631
+ ).view(1, block_length)
1632
+ else:
1633
+ draft_tokens = draft_logits.argmax(dim=-1)
1634
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1635
+
1636
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1637
+ is_mask = block == token_mask_id
1638
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1639
+ block[is_mask] = draft_tokens[is_mask]
1640
+
1641
+ # Verify: multi-path batch-stacking (1 NFE, flash-attention compatible)
1642
+ result = _mp_verify(block, draft_probs, draft_conf, past_key_values, cache_len)
1643
+
1644
+ if result is not None:
1645
+ accepted_toks, accepted, past_key_values, next_token = result
1646
+ nfe += 1
1647
+ else:
1648
+ # No uncertain positions — single-path causal verify
1649
+ _set_dlm(False)
1650
+ enc_out = self.encoder(
1651
+ input_ids=block, past_key_values=past_key_values,
1652
+ use_cache=True, use_causal_mask=True,
1653
+ )
1654
+ past_key_values = enc_out.past_key_values
1655
+ nfe += 1
1656
+
1657
+ vlogits = self.diffusion_head(enc_out.last_hidden_state)
1658
+ if temperature > 0:
1659
+ vp = torch.softmax(vlogits / temperature, dim=-1)
1660
+ ar_tokens = torch.multinomial(vp.view(-1, vp.shape[-1]), 1).view(1, block_length)
1661
+ else:
1662
+ ar_tokens = vlogits.argmax(dim=-1)
1663
+
1664
+ accepted = 0
1665
+ for i in range(block_length - 1):
1666
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1667
+ accepted += 1
1668
+ else:
1669
+ break
1670
+ accepted += 1
1671
+
1672
+ accepted_toks = ar_tokens[:, :accepted]
1673
+ _crop_cache(past_key_values, cache_len + accepted)
1674
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1675
+
1676
+ generated.append(accepted_toks)
1677
+ total_gen += accepted
1678
+
1679
+ if eos_token_id is not None:
1680
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1681
+ if len(eos_pos) > 0:
1682
+ first_eos = eos_pos[0].item()
1683
+ generated[-1] = accepted_toks[:, :first_eos + 1]
1684
+ total_gen = total_gen - accepted + first_eos + 1
1685
+ break
1686
+
1687
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1688
+ if total_gen > max_thinking_tokens:
1689
+ all_gen = torch.cat(generated, dim=1)
1690
+ if not (all_gen == end_think_token_id).any():
1691
+ next_token = torch.tensor(
1692
+ [[end_think_token_id]], device=device
1693
+ )
1694
+
1695
+ if total_gen >= max_new_tokens:
1696
+ break
1697
+
1698
+ all_generated = torch.cat(generated, dim=1)
1699
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1700
+ return output_ids, nfe
1701
+
1702
+
1703
+ @torch.no_grad()
1704
+ def linear_spec_generate_lora(
1705
+ self,
1706
+ prompt_ids: torch.Tensor,
1707
+ max_new_tokens: int = 128,
1708
+ block_length: int = 32,
1709
+ temperature: float = 0.0,
1710
+ mask_token_id: Optional[int] = None,
1711
+ eos_token_id: Optional[int] = None,
1712
+ threshold: float = 0.0,
1713
+ rebuild_kv: str = 'none',
1714
+ max_thinking_tokens: Optional[int] = None,
1715
+ end_think_token_id: Optional[int] = None,
1716
+ ):
1717
+ """Linear speculative decoding: diffusion draft + AR verify.
1718
+ LoRA adapter toggling: ON for draft (bidirectional), OFF for verify (causal).
1719
+ Returns (output_ids, nfe).
1720
+ """
1721
+ if prompt_ids.shape[0] != 1:
1722
+ raise ValueError("linear_spec_generate requires batch_size == 1")
1723
+
1724
+ token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
1725
+ if eos_token_id is None:
1726
+ eos_token_id = getattr(self.config, "eos_token_id", None)
1727
+
1728
+ device = prompt_ids.device
1729
+ dream_style = getattr(self.config, 'dlm_type', 'llada') == 'dream'
1730
+
1731
+ def _set_diffusion_lm(val: bool):
1732
+ for layer in self.encoder.layers:
1733
+ if hasattr(layer.self_attn, 'diffusion_lm'):
1734
+ layer.self_attn.diffusion_lm = val
1735
+
1736
+ def _toggle_adapters(model, enable: bool):
1737
+ for module in model.modules():
1738
+ if hasattr(module, '_disable_adapters'):
1739
+ module._disable_adapters = not enable
1740
+
1741
+ # Prefill (causal, LoRA OFF)
1742
+ _set_diffusion_lm(False)
1743
+ _toggle_adapters(self, False)
1744
+ enc_out = self.encoder(
1745
+ input_ids=prompt_ids,
1746
+ past_key_values=DynamicCache(),
1747
+ use_cache=True,
1748
+ use_causal_mask=True,
1749
+ )
1750
+ past_key_values = enc_out.past_key_values
1751
+ last_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)
1752
+ nfe = 1
1753
+
1754
+ if temperature > 0:
1755
+ next_token = torch.multinomial(torch.softmax(last_logit / temperature, dim=-1), num_samples=1)
1756
+ else:
1757
+ next_token = torch.argmax(last_logit, dim=-1, keepdim=True)
1758
+
1759
+ if eos_token_id is not None and next_token.item() == eos_token_id:
1760
+ return torch.cat([prompt_ids, next_token], dim=1), nfe
1761
+
1762
+ generated = [next_token]
1763
+ total_gen = 1
1764
+
1765
+ while total_gen < max_new_tokens:
1766
+ cache_len = past_key_values.get_seq_length()
1767
+
1768
+ block = torch.full((1, block_length), token_mask_id, dtype=torch.long, device=device)
1769
+ block[0, 0] = next_token.item()
1770
+
1771
+ # Draft (bidirectional, LoRA ON)
1772
+ _set_diffusion_lm(True)
1773
+ _toggle_adapters(self, True)
1774
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=False)
1775
+ nfe += 1
1776
+
1777
+ draft_logits = self.diffusion_head(enc_out.last_hidden_state)
1778
+ if dream_style:
1779
+ draft_logits = torch.cat([draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1)
1780
+
1781
+ if temperature > 0:
1782
+ draft_probs = torch.softmax(draft_logits / temperature, dim=-1)
1783
+ draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1).view(1, block_length)
1784
+ else:
1785
+ draft_tokens = draft_logits.argmax(dim=-1)
1786
+ draft_probs = torch.softmax(draft_logits, dim=-1)
1787
+
1788
+ draft_conf = torch.gather(draft_probs, -1, draft_tokens.unsqueeze(-1)).squeeze(-1)
1789
+ is_mask = block == token_mask_id
1790
+ draft_conf = torch.where(is_mask, draft_conf, -torch.inf)
1791
+ unmask = draft_conf > threshold
1792
+ if unmask.sum() > 0:
1793
+ block[unmask] = draft_tokens[unmask]
1794
+
1795
+ # Verify (causal, LoRA OFF)
1796
+ _set_diffusion_lm(False)
1797
+ _toggle_adapters(self, False)
1798
+ enc_out = self.encoder(input_ids=block, past_key_values=past_key_values, use_cache=True, use_causal_mask=True)
1799
+ past_key_values = enc_out.past_key_values
1800
+ nfe += 1
1801
+
1802
+ verify_logits = self.diffusion_head(enc_out.last_hidden_state)
1803
+ if temperature > 0:
1804
+ ar_tokens = torch.multinomial(torch.softmax(verify_logits / temperature, dim=-1).view(-1, verify_logits.shape[-1]), num_samples=1).view(1, block_length)
1805
+ else:
1806
+ ar_tokens = verify_logits.argmax(dim=-1)
1807
+
1808
+ accepted = 0
1809
+ for i in range(block_length - 1):
1810
+ if ar_tokens[0, i].item() == block[0, i + 1].item():
1811
+ accepted += 1
1812
+ else:
1813
+ break
1814
+ accepted += 1 # bonus token
1815
+
1816
+ accepted_toks = ar_tokens[:, :accepted]
1817
+ generated.append(accepted_toks)
1818
+ total_gen += accepted
1819
+
1820
+ _crop_dynamic_cache(past_key_values, cache_len + accepted)
1821
+ next_token = ar_tokens[:, accepted - 1 : accepted]
1822
+
1823
+ # EOS check
1824
+ if eos_token_id is not None:
1825
+ eos_pos = (accepted_toks[0] == eos_token_id).nonzero(as_tuple=True)[0]
1826
+ if len(eos_pos) > 0:
1827
+ first_eos = eos_pos[0].item()
1828
+ generated[-1] = accepted_toks[:, : first_eos + 1]
1829
+ total_gen = total_gen - accepted + first_eos + 1
1830
+ break
1831
+
1832
+ # Thinking budget enforcement
1833
+ if end_think_token_id is not None and max_thinking_tokens is not None:
1834
+ if total_gen > max_thinking_tokens:
1835
+ all_gen = torch.cat(generated, dim=1)
1836
+ if not (all_gen == end_think_token_id).any():
1837
+ next_token = torch.tensor([[end_think_token_id]], device=device)
1838
+
1839
+ if total_gen >= max_new_tokens:
1840
+ break
1841
+
1842
+ all_generated = torch.cat(generated, dim=1)
1843
+ output_ids = torch.cat([prompt_ids, all_generated], dim=1)
1844
+ return output_ids, nfe