YongganFu committed · Commit 01c01ff · verified · 1 Parent(s): f04b980

Upload model

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
chat_utils.py ADDED
@@ -0,0 +1,267 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+
+ def add_gumbel_noise(logits, temperature):
+     '''
+     The Gumbel max is a method for sampling categorical distributions.
+     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+     Thus, we use float64.
+     '''
+     if temperature == 0:
+         return logits
+     logits = logits.to(torch.float64)
+     noise = torch.rand_like(logits, dtype=torch.float64)
+     gumbel_noise = (- torch.log(noise)) ** temperature
+     return logits.exp() / gumbel_noise
+
+
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
+     logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+     x0 = torch.argmax(logits_with_noise, dim=-1)
+
+     if remasking == 'low_confidence':
+         # p = F.softmax(logits.to(torch.float64), dim=-1)
+         p = F.softmax(logits, dim=-1)
+         x0_p = torch.squeeze(
+             torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
+     elif remasking == 'top_p_margin':
+         # Compute probabilities
+         p = F.softmax(logits, dim=-1) # (B, L, V)
+         # Top-2 per position
+         top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
+         margin = top2[..., 0] - top2[..., 1] # (B, L)
+
+         # Normalize margin to [0,1] over MASKED positions per row
+         plus_inf = torch.full_like(margin, float('inf'))
+         minus_inf = torch.full_like(margin, float('-inf'))
+         masked_for_min = torch.where(mask_index, margin, plus_inf)
+         masked_for_max = torch.where(mask_index, margin, minus_inf)
+         row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
+         row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
+         denom = (row_max - row_min)
+
+         # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
+         normalized = torch.zeros_like(margin)
+         nonzero = denom > 0
+         normalized = torch.where(
+             mask_index & nonzero,
+             (margin - row_min) / (denom + 1e-12),
+             normalized
+         )
+         normalized = torch.where(
+             mask_index & (~nonzero),
+             torch.ones_like(normalized),
+             normalized
+         )
+         x0_p = normalized # ∈ [0,1] on masked positions
+     elif remasking == 'random':
+         x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+     else:
+         raise NotImplementedError(remasking)
+
+     # Calculate negative entropy if requested
+     if neg_entropy:
+         # p = F.softmax(logits.to(torch.float64), dim=-1)
+         p = F.softmax(logits, dim=-1)
+         epsilon = 1e-10
+         log_probs = torch.log(p + epsilon)
+         confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
+     else:
+         confidence_scores = x0_p
+
+     x0 = torch.where(mask_index, x0, x)
+     confidence = torch.where(mask_index, confidence_scores, -np.inf)
+
+     transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+     if threshold is not None:
+         num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
+     # print(f'confidence: {confidence}')
+     for j in range(confidence.shape[0]):
+         _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
+         transfer_index[j, select_index] = True
+         if threshold is not None:
+             for k in range(1, num_transfer_tokens[j]):
+                 if confidence[j, select_index[k]] < threshold:
+                     transfer_index[j, select_index[k]] = False
+     return x0, transfer_index
+
+
+ def get_num_transfer_tokens(mask_index, steps: int):
+     mask_num = mask_index.sum(dim=1, keepdim=True)
+     base = mask_num // steps
+     remainder = mask_num % steps
+     num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+     for i in range(mask_num.size(0)):
+         num_transfer_tokens[i, : int(remainder[i])] += 1
+     return num_transfer_tokens
+
+
+ @torch.no_grad()
+ def generate_with_prefix_cache_block_diff(
+     model,
+     prompt,
+     steps=128,
+     gen_length=128,
+     block_length=128,
+     temperature=0.,
+     remasking='low_confidence',
+     mask_id=126336,
+     threshold=None,
+     factor=None,
+     shift_logits=False,
+     neg_entropy=False,
+     causal_context=False,
+     eos_token_id=None,
+ ):
+     dream_style=shift_logits
+     # Initialize the accumulator
+     x_accum = prompt.clone()
+
+     assert gen_length % block_length == 0
+     num_blocks = gen_length // block_length
+
+     assert steps % num_blocks == 0
+     steps_per_block = steps // num_blocks
+
+     nfe = 0
+
+     if causal_context:
+         model_module = model.module if hasattr(model, "module") else model
+         for layer in model_module.encoder.layers:
+             if hasattr(layer.self_attn, 'diffusion_lm'):
+                 layer.self_attn.diffusion_lm=False
+
+     # Compute KV cache for the prompt initially
+     output = model(prompt, use_cache=True, use_causal_mask=causal_context)
+     past_key_values = output.past_key_values
+
+     if causal_context:
+         for layer in model_module.encoder.layers:
+             if hasattr(layer.self_attn, 'diffusion_lm'):
+                 layer.self_attn.diffusion_lm=True
+
+     # For dream_style: store the "next token logit" of the context
+     next_logits_context = None
+     if dream_style:
+         next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
+
+     for num_block in range(num_blocks):
+         # Create a new block with mask tokens (no seeding)
+         mask_block = torch.ones(
+             (prompt.shape[0], block_length),
+             dtype=prompt.dtype,
+             device=prompt.device
+         ) * mask_id
+
+         # Append the block of masks
+         x_accum = torch.cat([x_accum, mask_block], dim=1)
+         current_block_start = prompt.size(1) + num_block * block_length
+         block_slice = slice(current_block_start, current_block_start + block_length)
+
+         # Build the initial mask for this block
+         mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
+
+         # Precompute the transfer schedule for this block
+         if dream_style:
+             # still denoise *all* positions (0..Lb-1), since none are seeded
+             schedule_mask = mask_block_idx0
+         else:
+             schedule_mask = mask_block_idx0
+
+         num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
+
+         # Denoise the current block
+         for i in range(steps_per_block):
+             mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
+             if mask_block_idx.sum() == 0:
+                 break
+
+             nfe += 1
+
+             # Forward only the current noisy block using cached context
+             logits_block = model(
+                 x_accum[:, block_slice],
+                 past_key_values=past_key_values,
+                 use_cache=False
+             ).logits
+
+             if dream_style:
+                 # Align logits so that each masked position has a predictor:
+                 # prepend context-next logit, then use logits_block[:-1]
+                 if block_length == 1:
+                     logits_use = next_logits_context # (B, 1, V)
+                 else:
+                     logits_use = torch.cat(
+                         [next_logits_context, logits_block[:, :-1, :]],
+                         dim=1
+                     ) # (B, Lb, V)
+
+                 mask_use = mask_block_idx # (B, Lb)
+                 x_use = x_accum[:, block_slice] # (B, Lb)
+
+                 x0, transfer_idx = get_transfer_index(
+                     logits_use, temperature, remasking, mask_use, x_use,
+                     num_transfer_tokens=num_transfer_tokens[:, i],
+                     threshold=threshold, neg_entropy=neg_entropy
+                 )
+                 cur = x_accum[:, block_slice].clone()
+                 cur[transfer_idx] = x0[transfer_idx]
+                 x_accum[:, block_slice] = cur
+
+             else:
+                 # non-AR (same-position) case
+                 x0, transfer_idx = get_transfer_index(
+                     logits_block, temperature, remasking, mask_block_idx,
+                     x_accum[:, block_slice],
+                     num_transfer_tokens=num_transfer_tokens[:, i],
+                     threshold=threshold, neg_entropy=neg_entropy
+                 )
+                 cur = x_accum[:, block_slice].clone()
+                 cur[transfer_idx] = x0[transfer_idx]
+                 x_accum[:, block_slice] = cur
+
+             if eos_token_id is not None:
+                 block_tokens = x_accum[:, block_slice] # (B, Lb)
+                 eos_mask = (block_tokens == eos_token_id) # (B, Lb)
+                 any_eos = eos_mask.any(dim=1) # (B,)
+                 if any_eos.any():
+                     after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
+                     mask_before = (block_tokens == mask_id) & ~after_eos
+                     if (any_eos & ~mask_before.any(dim=1)).any():
+                         break
+
+         if causal_context:
+             for layer in model_module.encoder.layers:
+                 if hasattr(layer.self_attn, 'diffusion_lm'):
+                     layer.self_attn.diffusion_lm=False
+
+         # after block is fully denoised, update KV cache
+         output = model(
+             x_accum[:, block_slice],
+             past_key_values=past_key_values,
+             use_cache=True,
+             use_causal_mask=causal_context
+         )
+         past_key_values = output.past_key_values
+
+         if causal_context:
+             for layer in model_module.encoder.layers:
+                 if hasattr(layer.self_attn, 'diffusion_lm'):
+                     layer.self_attn.diffusion_lm=True
+
+         if dream_style and num_block < num_blocks - 1:
+             # refresh context-next logit for the next block
+             next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
+
+         if eos_token_id is not None:
+             gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
+             is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
+             has_eos = is_eos.any(dim=1) # (B,)
+             if has_eos.all():
+                 first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
+                 max_eos = first_eos_pos.max().item()
+                 return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
+
+     return x_accum, nfe
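
Reviewer sketch (not part of the commit): the budget arithmetic in `generate_with_prefix_cache_block_diff` and `get_num_transfer_tokens` can be checked without a model. The plain-Python mirror below is a minimal illustration for a single batch row; `transfer_schedule` and `decoding_budget` are hypothetical helper names introduced here, and the sketch assumes `threshold=None`, since a non-None threshold makes `get_transfer_index` replace the per-step count with the number of still-masked positions.

```python
# Illustrative sketch: plain-Python mirror of the budget arithmetic used by
# generate_with_prefix_cache_block_diff and get_num_transfer_tokens.
# `transfer_schedule` and `decoding_budget` are hypothetical helper names.


def transfer_schedule(mask_num: int, steps: int) -> list[int]:
    """Per-step unmask counts for one row that starts with `mask_num` masks."""
    if steps <= 0:
        raise ValueError("steps must be positive")
    base, remainder = divmod(mask_num, steps)
    # The first `remainder` steps each commit one extra token.
    return [base + 1] * remainder + [base] * (steps - remainder)


def decoding_budget(gen_length: int, block_length: int, steps: int) -> tuple[int, int, list[int]]:
    """Return (num_blocks, steps_per_block, per-step schedule for one block)."""
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be a multiple of block_length")
    num_blocks = gen_length // block_length
    if steps % num_blocks != 0:
        raise ValueError("steps must be a multiple of the number of blocks")
    steps_per_block = steps // num_blocks
    # Every position of a fresh block is masked, so mask_num == block_length.
    return num_blocks, steps_per_block, transfer_schedule(block_length, steps_per_block)


if __name__ == "__main__":
    print(transfer_schedule(10, 4))      # [3, 3, 2, 2]
    print(decoding_budget(128, 32, 64))  # 4 blocks, 16 steps each, 2 tokens per step
```

With the function's defaults (`steps=128`, `gen_length=128`, `block_length=128`) this gives one block and one token per step. Note also that the default `mask_id=126336` differs from the `mask_token_id` of 100 in this commit's config.json, so callers presumably need to pass `mask_id` explicitly.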
config.json ADDED
@@ -0,0 +1,79 @@
+ {
+   "ada_dlm_loss_ratio": null,
+   "ada_perm_ratio_global": null,
+   "ada_perm_ratio_per_block": null,
+   "adaptive_mask_rate": false,
+   "ar_loss_weight": 1.0,
+   "architectures": [
+     "MinistralDiffEncoderModel"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_implementation": "sdpa",
+   "auto_map": {
+     "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
+     "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
+   },
+   "block_size": 32,
+   "bos_token_id": 1,
+   "diff_loss_weight": 1,
+   "dlm_arch": "encoder",
+   "dlm_loss_weight": null,
+   "dlm_paradigm": "bidirectional",
+   "dlm_type": "llada",
+   "dp_varying_mask_ratio": false,
+   "enforce_mask": false,
+   "eos_token_id": 2,
+   "global_loss_avg": false,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 3072,
+   "initializer_range": 0.02,
+   "intermediate_size": 9216,
+   "mask_token_id": 100,
+   "max_position_embeddings": 262144,
+   "mlp_bias": false,
+   "model_type": "ministral_dlm",
+   "multi_sampling": null,
+   "num_ar_layers": 0,
+   "num_attention_heads": 32,
+   "num_diffusion_layers": 0,
+   "num_hidden_layers": 26,
+   "num_key_value_heads": 8,
+   "num_skip_loss_tokens": 0,
+   "prefix_ratio": 0.8,
+   "random_length_prob": 0,
+   "rms_norm_eps": 1e-05,
+   "rope_parameters": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 16.0,
+     "llama_4_scaling_beta": 0.1,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 16384,
+     "rope_theta": 1000000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_scaling": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 16.0,
+     "llama_4_scaling_beta": 0.1,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 16384,
+     "rope_theta": 1000000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "tok_mask_half_life_ratio": null,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.55.4",
+   "use_cache": false,
+   "vocab_size": 131072
+ }
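
Reviewer sketch (not part of the commit): a few quantities follow from the config.json values by simple arithmetic. The snippet below copies the relevant fields by hand and derives them. It assumes the common convention that the query and key/value projection widths are `num_attention_heads * head_dim` and `num_key_value_heads * head_dim`; the modeling code in this diff is not shown in full, so treat those widths as an assumption. `derived_shapes` is a hypothetical helper name.

```python
# Illustrative sketch: values copied from config.json in this commit.
# The projection-width convention is an assumption, not confirmed by the diff.
CONFIG = {
    "hidden_size": 3072,
    "head_dim": 128,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "max_position_embeddings": 262144,
    "yarn_factor": 16.0,  # rope_parameters.factor
    "original_max_position_embeddings": 16384,
}


def derived_shapes(cfg: dict) -> dict:
    """Derive head-related widths and the YaRN-extended context length."""
    return {
        # Assumed query projection width (heads x head_dim).
        "q_width": cfg["num_attention_heads"] * cfg["head_dim"],
        # Assumed key/value projection width under grouped-query attention.
        "kv_width": cfg["num_key_value_heads"] * cfg["head_dim"],
        # Query heads that share one key/value head.
        "gqa_group_size": cfg["num_attention_heads"] // cfg["num_key_value_heads"],
        # YaRN scaling factor times the original context length.
        "yarn_context": int(cfg["yarn_factor"] * cfg["original_max_position_embeddings"]),
    }


if __name__ == "__main__":
    shapes = derived_shapes(CONFIG)
    print(shapes)
    # The extended context matches max_position_embeddings in the config.
    print(shapes["yarn_context"] == CONFIG["max_position_embeddings"])
```

Under that assumption the query width (4096) is larger than `hidden_size` (3072), which is possible only because `head_dim` is set explicitly rather than derived as `hidden_size / num_attention_heads`.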
configuration_ministral_dlm.py ADDED
@@ -0,0 +1,247 @@
+ # coding=utf-8
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Ministral DLM model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class MinistralDLMConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
+     It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 131072):
+             Vocabulary size of the Ministral model.
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 14336):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 34):
+             Number of hidden layers in the Transformer decoder.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             Number of key_value heads for Grouped Query Attention.
+         head_dim (`int`, *optional*, defaults to 128):
+             The attention head dimension.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function.
+         max_position_embeddings (`int`, *optional*, defaults to 262144):
+             The maximum sequence length.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         rope_theta (`float`, *optional*, defaults to 1000000.0):
+             The base period of the RoPE embeddings.
+         rope_parameters (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings.
+             Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
+         attention_bias (`bool`, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in up_proj, down_proj and gate_proj layers.
+         sliding_window (`int`, *optional*, defaults to None):
+             Sliding window attention size.
+         mask_token_id (`int`, *optional*, defaults to -1):
+             Token ID for masking in diffusion.
+         dlm_type (`str`, *optional*, defaults to 'llada'):
+             Type of diffusion language model ('llada', 'dream').
+         random_length_prob (`float`, *optional*):
+             Probability of using random lengths during training.
+         num_ar_layers (`int`, *optional*, defaults to 0):
+             Number of autoregressive layers.
+         num_diffusion_layers (`int`, *optional*, defaults to 0):
+             Number of diffusion layers.
+         diff_loss_weight (`float`, *optional*, defaults to 1):
+             Weight for diffusion loss.
+         enforce_mask (`bool`, *optional*, defaults to False):
+             Whether to enforce masking.
+         prefix_ratio (`float`, *optional*, defaults to 0.8):
+             Ratio for prefix in prefix_bidirectional mode.
+         dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
+             Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
+         dlm_arch (`str`, *optional*, defaults to 'encoder'):
+             Architecture type ('encoder', 'encoder_decoder').
+         block_size (`int`, *optional*, defaults to 32):
+             Block size for block diffusion paradigms.
+         tok_mask_half_life_ratio (`float`, *optional*):
+             Half-life ratio for token masking.
+         adaptive_mask_rate (`bool`, *optional*, defaults to False):
+             Whether to use adaptive mask rate.
+         multi_sampling (`int`, *optional*):
+             Number of samples for multi-sampling.
+         num_skip_loss_tokens (`int`, *optional*, defaults to 0):
+             Number of tokens to skip in loss calculation.
+         dlm_loss_weight (`float`, *optional*):
+             Weight for diffusion LM loss.
+         ar_loss_weight (`float`, *optional*, defaults to 1.0):
+             Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
+         global_loss_avg (`bool`, *optional*, defaults to False):
+             Whether to use global loss average.
+         dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
+             Whether to use varying mask ratio for each DP rank during sampling.
+         ada_perm_ratio_per_block (`float`, *optional*):
+             Adaptive permutation ratio for each block.
+         ada_perm_ratio_global (`float`, *optional*):
+             Adaptive permutation ratio for global.
+     """
+
+     model_type = "ministral_dlm"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     # Default tensor parallel plan for base model `Ministral`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=131072,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=34,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         head_dim=128,
+         hidden_act="silu",
+         max_position_embeddings=262144,
+         initializer_range=0.02,
+         rms_norm_eps=1e-05,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         rope_theta=1000000.0,
+         rope_parameters=None,
+         rope_scaling=None,
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         sliding_window=None,
+         attn_implementation="sdpa",
+         mask_token_id=-1,
+         dlm_type='llada',
+         random_length_prob=None,
+         num_ar_layers=0,
+         num_diffusion_layers=0,
+         diff_loss_weight=1,
+         enforce_mask=False,
+         prefix_ratio=0.8,
+         dlm_paradigm='bidirectional',
+         dlm_arch='encoder',
+         block_size=32,
+         tok_mask_half_life_ratio=None,
+         adaptive_mask_rate=False,
+         multi_sampling=None,
+         num_skip_loss_tokens=0,
+         dlm_loss_weight=None,
+         ar_loss_weight=1.0,
+         global_loss_avg=False,
+         dp_varying_mask_ratio=False,
+         ada_perm_ratio_per_block=None,
+         ada_perm_ratio_global=None,
+         ada_dlm_loss_ratio=None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = head_dim
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_parameters = rope_parameters
+         self.rope_scaling = rope_scaling
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.mlp_bias = mlp_bias
+         self.sliding_window = sliding_window
+
+         rope_config_validation(self)
+
+         self.attn_implementation = attn_implementation
+
+         self.mask_token_id = mask_token_id
+         self.dlm_type = dlm_type
+         self.random_length_prob = random_length_prob
+         self.num_ar_layers = num_ar_layers
+         self.num_diffusion_layers = num_diffusion_layers
+         self.diff_loss_weight = diff_loss_weight
+         self.enforce_mask = enforce_mask
+         self.prefix_ratio = prefix_ratio
+         self.dlm_paradigm = dlm_paradigm
+         self.dlm_arch = dlm_arch
+         self.block_size = block_size
+         self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
+         self.adaptive_mask_rate = adaptive_mask_rate
+         self.multi_sampling = multi_sampling
+         self.num_skip_loss_tokens = num_skip_loss_tokens
+         self.dlm_loss_weight = dlm_loss_weight
+         self.ar_loss_weight = ar_loss_weight
+         self.global_loss_avg = global_loss_avg
+         self.dp_varying_mask_ratio = dp_varying_mask_ratio
+         self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
+         self.ada_perm_ratio_global = ada_perm_ratio_global
+         self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["MinistralDLMConfig"]
+
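
Reviewer sketch (not part of the commit): the docstring of `MinistralDLMConfig` lists the accepted strings for `dlm_type`, `dlm_paradigm`, and `dlm_arch`, but `__init__` stores them without checking. A caller who wants early failure on a typo could run something like the hypothetical `check_dlm_choices` helper below first. It assumes the docstring lists are exhaustive, which the diff does not confirm.

```python
# Illustrative sketch: a hypothetical pre-flight check, not part of
# configuration_ministral_dlm.py. The allowed values are copied from the
# MinistralDLMConfig docstring and are assumed to be exhaustive.
DOCUMENTED_CHOICES: dict[str, frozenset[str]] = {
    "dlm_type": frozenset({"llada", "dream"}),
    "dlm_paradigm": frozenset({
        "bidirectional",
        "autoregressive",
        "prefix_bidirectional",
        "efficient_block_diff",
        "block_diff",
        "sbd_block_diff",
    }),
    "dlm_arch": frozenset({"encoder", "encoder_decoder"}),
}


def check_dlm_choices(**fields: str) -> None:
    """Raise ValueError if a field holds a value the docstring does not list."""
    for name, value in fields.items():
        if name not in DOCUMENTED_CHOICES:
            raise KeyError(f"no documented choices for {name!r}")
        allowed = DOCUMENTED_CHOICES[name]
        if value not in allowed:
            raise ValueError(f"{name}={value!r} is not one of {sorted(allowed)}")


if __name__ == "__main__":
    # The values shipped in this commit's config.json pass the check.
    check_dlm_choices(dlm_type="llada", dlm_paradigm="bidirectional", dlm_arch="encoder")
    print("documented choices OK")
```

Keeping the check outside the config class leaves the committed file untouched, so it can be dropped if the maintainers later accept values the docstring does not mention.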
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.55.4",
+   "use_cache": false
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6510fd1ca747463dc8077e1b144f9bc3a1d64c45331d9fff945a00d914cf7cfc
+ size 7663347000
modeling_ministral.py ADDED
@@ -0,0 +1,551 @@
+ from collections.abc import Callable
+ from typing import Optional, Union
+
+ import torch
+ from torch import nn
+
+ from transformers.utils.generic import check_model_inputs
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+ from transformers.integrations import use_kernel_forward_from_hub
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_layers import (
+     GenericForQuestionAnswering,
+     GenericForSequenceClassification,
+     GenericForTokenClassification,
+     GradientCheckpointingLayer,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+ # from transformers.utils.generic import maybe_autocast
+ from .configuration_ministral_dlm import MinistralDLMConfig
+
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+ # @use_kernel_func_from_hub("rotary_pos_emb")
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+ q (`torch.Tensor`): The query tensor.
44
+ k (`torch.Tensor`): The key tensor.
45
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
46
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
47
+ position_ids (`torch.Tensor`, *optional*):
48
+ Deprecated and unused.
49
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
50
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
51
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
52
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
53
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
54
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
55
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
56
+ Returns:
57
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
58
+ """
59
+ cos = cos.unsqueeze(unsqueeze_dim)
60
+ sin = sin.unsqueeze(unsqueeze_dim)
61
+ q_embed = (q * cos) + (rotate_half(q) * sin)
62
+ k_embed = (k * cos) + (rotate_half(k) * sin)
63
+ return q_embed, k_embed
64
+
65
+
66
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
67
+ """
68
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
69
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
70
+ """
71
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
72
+ if n_rep == 1:
73
+ return hidden_states
74
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
75
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
76
+
77
+
78
+ def eager_attention_forward(
79
+ module: nn.Module,
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor],
84
+ scaling: float,
85
+ dropout: float = 0.0,
86
+ **kwargs: Unpack[TransformersKwargs],
87
+ ):
88
+ key_states = repeat_kv(key, module.num_key_value_groups)
89
+ value_states = repeat_kv(value, module.num_key_value_groups)
90
+
91
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
92
+ if attention_mask is not None:
93
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
94
+ attn_weights = attn_weights + causal_mask
95
+
96
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
97
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
98
+ attn_output = torch.matmul(attn_weights, value_states)
99
+ attn_output = attn_output.transpose(1, 2).contiguous()
100
+
101
+ return attn_output, attn_weights
102
+
103
+
104
+ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
105
+ scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
106
+ return scaling.unsqueeze(-1)
107
+
108
+
109
+ # @use_kernelized_func(apply_rotary_pos_emb)
110
+ class Ministral3Attention(nn.Module):
111
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
112
+
113
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
114
+ super().__init__()
115
+ self.config = config
116
+ self.layer_idx = layer_idx
117
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
118
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
119
+ self.scaling = self.head_dim**-0.5
120
+ self.attention_dropout = config.attention_dropout
121
+ self.is_causal = True
122
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
123
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
124
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
125
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
126
+
127
+ self.diffusion_lm = config.diffusion_lm
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.Tensor,
132
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
133
+ attention_mask: Optional[torch.Tensor],
134
+ past_key_values: Optional[Cache] = None,
135
+ cache_position: Optional[torch.LongTensor] = None,
136
+ use_cache: Optional[bool] = False,
137
+ **kwargs: Unpack[FlashAttentionKwargs],
138
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
139
+ input_shape = hidden_states.shape[:-1]
140
+ hidden_shape = (*input_shape, -1, self.head_dim)
141
+
142
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
143
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
145
+
146
+ cos, sin = position_embeddings
147
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
148
+ query_states = query_states * _get_llama_4_attn_scale(
149
+ cache_position,
150
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
151
+ self.config.rope_parameters.get("original_max_position_embeddings"),
152
+ ).to(query_states.dtype)
153
+
154
+ if past_key_values is not None:
155
+ if use_cache:
156
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
157
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
158
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
159
+ else: ## if use_cache == False, do not update cache
160
+ old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
161
+ key_states = torch.cat([old_k, key_states], dim=-2)
162
+ value_states = torch.cat([old_v, value_states], dim=-2)
163
+
164
+ attention_interface: Callable = eager_attention_forward
165
+ if self.config._attn_implementation != "eager":
166
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
167
+
168
+ if self.diffusion_lm:
169
+ attn_output, attn_weights = attention_interface(
170
+ self,
171
+ query_states,
172
+ key_states,
173
+ value_states,
174
+ None,
175
+ dropout=0.0 if not self.training else self.attention_dropout,
176
+ scaling=self.scaling,
177
+ is_causal=False,
178
+ **kwargs,
179
+ )
180
+
181
+ else:
182
+ attn_output, attn_weights = attention_interface(
183
+ self,
184
+ query_states,
185
+ key_states,
186
+ value_states,
187
+ attention_mask,
188
+ dropout=0.0 if not self.training else self.attention_dropout,
189
+ scaling=self.scaling,
190
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
191
+ **kwargs,
192
+ )
193
+
194
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
195
+ attn_output = self.o_proj(attn_output)
196
+ return attn_output, attn_weights
197
+
198
+
199
+ class Ministral3MLP(nn.Module):
200
+ def __init__(self, config):
201
+ super().__init__()
202
+ self.config = config
203
+ self.hidden_size = config.hidden_size
204
+ self.intermediate_size = config.intermediate_size
205
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
206
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
207
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
208
+ self.act_fn = ACT2FN[config.hidden_act]
209
+
210
+ def forward(self, x):
211
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
212
+ return down_proj
213
+
214
+
215
+ @use_kernel_forward_from_hub("RMSNorm")
216
+ class Ministral3RMSNorm(nn.Module):
217
+ def __init__(self, hidden_size, eps=1e-6):
218
+ """
219
+ Ministral3RMSNorm is equivalent to T5LayerNorm
220
+ """
221
+ super().__init__()
222
+ self.weight = nn.Parameter(torch.ones(hidden_size))
223
+ self.variance_epsilon = eps
224
+
225
+ def forward(self, hidden_states):
226
+ input_dtype = hidden_states.dtype
227
+ hidden_states = hidden_states.to(torch.float32)
228
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
229
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
230
+ return self.weight * hidden_states.to(input_dtype)
231
+
232
+ def extra_repr(self):
233
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
234
+
235
+
236
+ class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
238
+ super().__init__()
239
+ self.hidden_size = config.hidden_size
240
+
241
+ if hasattr(config, 'attn_class'):
242
+ attn_class = config.attn_class
243
+ else:
244
+ attn_class = Ministral3Attention
245
+
246
+ self.self_attn = attn_class(config=config, layer_idx=layer_idx)
247
+ self.mlp = Ministral3MLP(config)
248
+ self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+ self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ position_ids: Optional[torch.LongTensor] = None,
256
+ past_key_values: Optional[Cache] = None,
257
+ use_cache: Optional[bool] = False,
258
+ cache_position: Optional[torch.LongTensor] = None,
259
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
260
+ **kwargs: Unpack[TransformersKwargs],
261
+ ) -> torch.Tensor:
262
+ residual = hidden_states
263
+ hidden_states = self.input_layernorm(hidden_states)
264
+ # Self Attention
265
+ hidden_states, _ = self.self_attn(
266
+ hidden_states=hidden_states,
267
+ attention_mask=attention_mask,
268
+ position_ids=position_ids,
269
+ past_key_values=past_key_values,
270
+ use_cache=use_cache,
271
+ cache_position=cache_position,
272
+ position_embeddings=position_embeddings,
273
+ **kwargs,
274
+ )
275
+ hidden_states = residual + hidden_states
276
+
277
+ # Fully Connected
278
+ residual = hidden_states
279
+ hidden_states = self.post_attention_layernorm(hidden_states)
280
+ hidden_states = self.mlp(hidden_states)
281
+ hidden_states = residual + hidden_states
282
+ return hidden_states
283
+
284
+
285
+ @auto_docstring
286
+ class Ministral3PreTrainedModel(PreTrainedModel):
287
+ config: MinistralDLMConfig
288
+ base_model_prefix = "model"
289
+ supports_gradient_checkpointing = True
290
+ _no_split_modules = ["Ministral3DecoderLayer"]
291
+ _skip_keys_device_placement = ["past_key_values"]
292
+ _supports_flash_attn = True
293
+ _supports_sdpa = True
294
+ _supports_flex_attn = True
295
+
296
+ _can_compile_fullgraph = True
297
+ _supports_attention_backend = True
298
+ _can_record_outputs = {
299
+ "hidden_states": Ministral3DecoderLayer,
300
+ "attentions": Ministral3Attention,
301
+ }
302
+
303
+
304
+ class Ministral3RotaryEmbedding(nn.Module):
305
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
306
+
307
+ def __init__(self, config: MinistralDLMConfig, device=None):
308
+ super().__init__()
309
+ self.max_seq_len_cached = config.max_position_embeddings
310
+ self.original_max_seq_len = config.max_position_embeddings
311
+
312
+ self.config = config
313
+
314
+ self.rope_type = self.config.rope_parameters["rope_type"]
315
+ rope_init_fn: Callable = self.compute_default_rope_parameters
316
+ if self.rope_type != "default":
317
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
318
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
319
+
320
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
321
+ self.original_inv_freq = inv_freq
322
+
323
+
324
+ @staticmethod
325
+ def compute_default_rope_parameters(
326
+ config: Optional[MinistralDLMConfig] = None,
327
+ device: Optional["torch.device"] = None,
328
+ seq_len: Optional[int] = None,
329
+ ) -> tuple["torch.Tensor", float]:
330
+ """
331
+ Computes the inverse frequencies according to the original RoPE implementation
332
+ Args:
333
+ config ([`~transformers.PreTrainedConfig`]):
334
+ The model configuration.
335
+ device (`torch.device`):
336
+ The device to use for initialization of the inverse frequencies.
337
+ seq_len (`int`, *optional*):
338
+ The current sequence length. Unused for this type of RoPE.
339
+ Returns:
340
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
341
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
342
+ """
343
+ base = config.rope_parameters["rope_theta"]
344
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
345
+
346
+ attention_factor = 1.0 # Unused in this type of RoPE
347
+
348
+ # Compute the inverse frequencies
349
+ inv_freq = 1.0 / (
350
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
351
+ )
352
+ return inv_freq, attention_factor
353
+
354
+ @torch.no_grad()
355
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
356
+ def forward(self, x, position_ids):
357
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
358
+ position_ids_expanded = position_ids[:, None, :].float()
359
+
360
+ # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
361
+ # with maybe_autocast(device_type=device_type, enabled=False): # Force float32
362
+
363
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
364
+ emb = torch.cat((freqs, freqs), dim=-1)
365
+ cos = emb.cos() * self.attention_scaling
366
+ sin = emb.sin() * self.attention_scaling
367
+
368
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
369
+
370
+
371
+ @auto_docstring
372
+ class Ministral3Model(Ministral3PreTrainedModel):
373
+ def __init__(self, config: MinistralDLMConfig):
374
+ super().__init__(config)
375
+ self.padding_idx = config.pad_token_id
376
+ self.vocab_size = config.vocab_size
377
+
378
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
379
+ self.layers = nn.ModuleList(
380
+ [Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
381
+ )
382
+ self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
+ self.rotary_emb = Ministral3RotaryEmbedding(config=config)
384
+ self.gradient_checkpointing = False
385
+
386
+ # Initialize weights and apply final processing
387
+ self.post_init()
388
+
389
+ @check_model_inputs
390
+ @auto_docstring
391
+ def forward(
392
+ self,
393
+ input_ids: Optional[torch.LongTensor] = None,
394
+ attention_mask: Optional[torch.Tensor] = None,
395
+ position_ids: Optional[torch.LongTensor] = None,
396
+ past_key_values: Optional[Cache] = None,
397
+ inputs_embeds: Optional[torch.FloatTensor] = None,
398
+ use_cache: Optional[bool] = None,
399
+ cache_position: Optional[torch.LongTensor] = None,
400
+ **kwargs: Unpack[TransformersKwargs],
401
+ ) -> BaseModelOutputWithPast:
402
+ if (input_ids is None) ^ (inputs_embeds is not None):
403
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
404
+
405
+ if inputs_embeds is None:
406
+ inputs_embeds = self.embed_tokens(input_ids)
407
+
408
+ if use_cache and past_key_values is None:
409
+ # past_key_values = DynamicCache(config=self.config)
410
+ past_key_values = DynamicCache()
411
+
412
+ if cache_position is None:
413
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
414
+ cache_position = torch.arange(
415
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
416
+ )
417
+
418
+ if position_ids is None:
419
+ position_ids = cache_position.unsqueeze(0)
420
+
421
+ if kwargs.get("use_causal_mask", False):
422
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
+ causal_mask = mask_function(
424
+ config=self.config,
425
+ input_embeds=inputs_embeds,
426
+ attention_mask=attention_mask,
427
+ cache_position=cache_position,
428
+ past_key_values=past_key_values,
429
+ position_ids=position_ids,
430
+ )
431
+
432
+ else:
433
+ causal_mask = None
434
+
435
+ hidden_states = inputs_embeds
436
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
437
+
438
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
439
+ hidden_states = decoder_layer(
440
+ hidden_states,
441
+ attention_mask=causal_mask,
442
+ position_ids=position_ids,
443
+ past_key_values=past_key_values,
444
+ use_cache=use_cache,
445
+ cache_position=cache_position,
446
+ position_embeddings=position_embeddings,
447
+ **kwargs,
448
+ )
449
+ hidden_states = self.norm(hidden_states)
450
+ return BaseModelOutputWithPast(
451
+ last_hidden_state=hidden_states,
452
+ past_key_values=past_key_values if use_cache else None,
453
+ )
454
+
455
+
456
+ @auto_docstring
457
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
458
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
459
+ _tp_plan = {"lm_head": "colwise_rep"}
460
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
461
+
462
+ def __init__(self, config):
463
+ super().__init__(config)
464
+ self.model = Ministral3Model(config)
465
+ self.vocab_size = config.vocab_size
466
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
467
+
468
+ # Initialize weights and apply final processing
469
+ self.post_init()
470
+
471
+ @can_return_tuple
472
+ @auto_docstring
473
+ def forward(
474
+ self,
475
+ input_ids: Optional[torch.LongTensor] = None,
476
+ attention_mask: Optional[torch.Tensor] = None,
477
+ position_ids: Optional[torch.LongTensor] = None,
478
+ past_key_values: Optional[Cache] = None,
479
+ inputs_embeds: Optional[torch.FloatTensor] = None,
480
+ labels: Optional[torch.LongTensor] = None,
481
+ use_cache: Optional[bool] = None,
482
+ cache_position: Optional[torch.LongTensor] = None,
483
+ logits_to_keep: Union[int, torch.Tensor] = 0,
484
+ **kwargs: Unpack[TransformersKwargs],
485
+ ) -> CausalLMOutputWithPast:
486
+ r"""
487
+ Example:
488
+
489
+ ```python
490
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
491
+
492
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
493
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
494
+
495
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
496
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
497
+
498
+ >>> # Generate
499
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
500
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
501
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
502
+ ```"""
503
+ outputs: BaseModelOutputWithPast = self.model(
504
+ input_ids=input_ids,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_values=past_key_values,
508
+ inputs_embeds=inputs_embeds,
509
+ use_cache=use_cache,
510
+ cache_position=cache_position,
511
+ **kwargs,
512
+ )
513
+
514
+ hidden_states = outputs.last_hidden_state
515
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
516
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
517
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
518
+
519
+ loss = None
520
+ if labels is not None:
521
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
522
+
523
+ return CausalLMOutputWithPast(
524
+ loss=loss,
525
+ logits=logits,
526
+ past_key_values=outputs.past_key_values,
527
+ hidden_states=outputs.hidden_states,
528
+ attentions=outputs.attentions,
529
+ )
530
+
531
+
532
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
533
+ pass
534
+
535
+
536
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
537
+ pass
538
+
539
+
540
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
541
+ pass
542
+
543
+
544
+ __all__ = [
545
+ "Ministral3ForCausalLM",
546
+ "Ministral3ForQuestionAnswering",
547
+ "Ministral3Model",
548
+ "Ministral3PreTrainedModel",
549
+ "Ministral3ForSequenceClassification",
550
+ "Ministral3ForTokenClassification",
551
+ ]
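The `rotate_half` and `apply_rotary_pos_emb` helpers in this file implement rotary position embeddings by pairing element `i` of a head vector with element `i + head_dim // 2` and rotating each pair by a position-dependent angle. The following is a minimal pure-Python sketch of that arithmetic on a single vector, so the two key properties (norm preservation and dependence on relative position only) can be checked without PyTorch. The `rope` and `dot` helpers and the default `theta=10000.0` are assumptions for illustration; the model reads its base from `config.rope_parameters["rope_theta"]`.

```python
import math


def rotate_half(x):
    """List version of the file's rotate_half: (x1, x2) -> (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope(x, position, theta=10000.0):
    """Apply rotary position embedding to one head vector at one position."""
    dim = len(x)
    # Same formula as compute_default_rope_parameters: 1 / theta^(i / dim), i = 0, 2, ...
    inv_freq = [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]
    # Duplicating the angles mirrors torch.cat((freqs, freqs), dim=-1).
    angles = [position * f for f in inv_freq] * 2
    rotated = rotate_half(x)
    # q_embed = q * cos + rotate_half(q) * sin, element by element.
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]

# Rotation leaves the vector norm unchanged.
norm_gap = abs(dot(rope(q, 5), rope(q, 5)) - dot(q, q))
# The q.k score depends only on the position offset (2 in both cases).
offset_gap = abs(dot(rope(q, 5), rope(k, 3)) - dot(rope(q, 12), rope(k, 10)))
print(norm_gap < 1e-9, offset_gap < 1e-9)
```

Because the score depends only on the offset between query and key positions, cached keys that were rotated at their original positions remain valid when new queries arrive later, which is what the `past_key_values` branch of `Ministral3Attention.forward` relies on.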
modeling_ministral_dlm.py ADDED
@@ -0,0 +1,1112 @@
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional, Tuple, Union
4
+ import random
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
+
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.generation import GenerationMixin
27
+
28
+ import math
29
+
30
+ from .chat_utils import generate_with_prefix_cache_block_diff
31
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
+ from .configuration_ministral_dlm import MinistralDLMConfig
33
+
34
+
35
+ @dataclass
36
+ class MinistralDiffOutputWithPast(ModelOutput):
37
+ loss: torch.FloatTensor | None = None
38
+ logits: torch.FloatTensor | None = None
39
+ causal_logits: torch.FloatTensor | None = None
40
+ past_key_values: Cache | None = None
41
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
42
+ attentions: tuple[torch.FloatTensor, ...] | None = None
43
+
44
+
45
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
46
+ # @torch.compile(mode="default")
47
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
48
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
49
+ def fused_flex_attention(q, k, v, block_mask=None):
50
+ return flex_attention(q, k, v, block_mask=block_mask)
51
+
52
+
53
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
54
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
55
+ if hasattr(past_key_values, 'crop'):
56
+ past_key_values.crop(max_length)
57
+ else:
58
+ for layer_idx in range(len(past_key_values)):
59
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
60
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
61
+ past_key_values._seen_tokens = max_length
62
+
63
+
64
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
65
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
66
+ for layer_idx in range(len(past_key_values)):
67
+ if hasattr(past_key_values, 'layers'):
68
+ layer_cache = past_key_values.layers[layer_idx]
69
+ k, v = layer_cache.keys, layer_cache.values
70
+ else:
71
+ k = past_key_values.key_cache[layer_idx]
72
+ v = past_key_values.value_cache[layer_idx]
73
+
74
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
75
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
76
+ new_k = torch.cat([clean_k, draft_k], dim=2)
77
+ new_v = torch.cat([clean_v, draft_v], dim=2)
78
+
79
+ if hasattr(past_key_values, 'layers'):
80
+ layer_cache.keys = new_k
81
+ layer_cache.values = new_v
82
+ else:
83
+ past_key_values.key_cache[layer_idx] = new_k
84
+ past_key_values.value_cache[layer_idx] = new_v
85
+
86
+ past_key_values._seen_tokens = clean_len + block_length
87
+
88
+
89
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
90
+ class MinistralFlexAttention(Ministral3Attention):
91
+ def __init__(self, *args, **kwargs):
92
+ super().__init__(*args, **kwargs)
93
+
94
+ self.block_size_orig = self.config.block_size
95
+
96
+ if self.config.dlm_paradigm == 'bidirectional':
97
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
98
+ elif self.config.dlm_paradigm == 'autoregressive':
99
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
100
+ elif self.config.dlm_paradigm == 'block_diff':
101
+ self.block_diff_mask = None
102
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
103
+ self.sbd_block_diff_mask = None
104
+ else:
105
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
106
+
107
+ self.block_size = self.block_size_orig
108
+ self.mode = self.config.dlm_paradigm
109
+ self._quadratic_block_mask = {}
110
+
111
+ import torch._dynamo.config as dcfg
112
+ dcfg.cache_size_limit = 512
113
+
114
+
115
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
116
+ if block_length not in self._quadratic_block_mask:
117
+ draft_len = block_length * (block_length + 1)
118
+
119
+ def quadratic(b, h, q_idx, kv_idx):
120
+ first_clean = torch.logical_and(
121
+ kv_idx % (block_length + 1) == 0,
122
+ kv_idx < draft_len,
123
+ )
124
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
125
+ block_q = q_idx // (block_length + 1)
126
+ block_kv = kv_idx // (block_length + 1)
127
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
128
+ same_block_except_first = torch.logical_and(
129
+ same_block,
130
+ q_idx % (block_length + 1) != 0,
131
+ )
132
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
133
+ clean_part = kv_idx >= draft_len
134
+ return torch.logical_or(draft_part, clean_part)
135
+
136
+ block_mask = create_block_mask(
137
+ quadratic,
138
+ B=None,
139
+ H=None,
140
+ Q_LEN=draft_len,
141
+ KV_LEN=draft_len + self.config.max_position_embeddings,
142
+ device="cuda",
143
+ )
144
+
145
+ self._quadratic_block_mask[block_length] = block_mask
146
+
147
+ return self._quadratic_block_mask[block_length]
148
+
149
+
150
+ def set_attention_mode(self, mode, block_size=None):
151
+ self.mode = mode
152
+ self.block_size = block_size
153
+
154
+ def compute_block_mask(self, mode, q_len, block_size=None):
155
+
156
+ def bidirectional_mask(b, h, q, kv):
157
+ return (q >= kv) | (q < kv)
158
+
159
+ def autoregressive_mask(b, h, q, kv):
160
+ return (q >= kv)
161
+
162
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
163
+ """
164
+ Constructs the specialized block diffusion attention mask for training
165
+ composed of three masks:
166
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
167
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
168
+ - **Block Causal Mask (M_BC)**: Attention to update x0
169
+ Args:
170
+ b, h: Batch and head indices (ignored for mask logic).
171
+ q_idx, kv_idx: Query and Key indices.
172
+ n: Length of the noised (xt) half; indices >= n belong to the clean (x0) half.
173
+ block_size: Defines the block structure.
174
+ Returns:
175
+ A boolean attention mask.
176
+ """
177
+
178
+ # Indicate whether token belongs to xt or x0
179
+ x0_flag_q = (q_idx >= n)
180
+ x0_flag_kv = (kv_idx >= n)
181
+
182
+ # Compute block indices
183
+ block_q = torch.where(x0_flag_q == 1,
184
+ (q_idx - n) // block_size,
185
+ q_idx // block_size)
186
+ block_kv = torch.where(x0_flag_kv == 1,
187
+ (kv_idx - n) // block_size,
188
+ kv_idx // block_size)
189
+
190
+ # **1. Block Diagonal Mask (M_BD) **
191
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
192
+
193
+ # **2. Offset Block-Causal Mask (M_OBC) **
194
+ offset_block_causal = (
195
+ (block_q > block_kv)
196
+ & (x0_flag_kv == 1)
197
+ & (x0_flag_q == 0)
198
+ )
199
+
200
+ # **3. Block-Causal Mask (M_BC) **
201
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
202
+
203
+ # **4. Combine Masks **
204
+ return block_diagonal | offset_block_causal | block_causal
205
+
206
+
207
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
208
+ x0_flag_q = (q_idx >= n)
209
+ x0_flag_kv = (kv_idx >= n)
210
+
211
+ # Compute block indices
212
+ block_q = torch.where(x0_flag_q == 1,
213
+ (q_idx - n) // block_size,
214
+ q_idx // block_size)
215
+ block_kv = torch.where(x0_flag_kv == 1,
216
+ (kv_idx - n) // block_size,
217
+ kv_idx // block_size)
218
+
219
+ # **1. Block Diagonal Mask (M_BD) **
220
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
221
+
222
+ # **2. Offset Block-Causal Mask (M_OBC) **
223
+ offset_block_causal = (
224
+ (block_q > block_kv)
225
+ & (x0_flag_kv == 1)
226
+ & (x0_flag_q == 0)
227
+ )
228
+
229
+ # **3. Fully Causal Mask (token-level causal over x0, used in place of M_BC) **
230
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
231
+
232
+ # **4. Combine Masks **
233
+ return block_diagonal | offset_block_causal | fully_causal
234
+
235
+ if mode == 'bidirectional':
236
+ attn_mask = bidirectional_mask
237
+ elif mode == 'autoregressive':
238
+ attn_mask = autoregressive_mask
239
+ elif mode == 'block_diff':
240
+ assert block_size is not None
241
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
242
+ elif mode == 'sbd_block_diff':
243
+ assert block_size is not None
244
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, q_len//2)
245
+ else:
246
+ raise ValueError(f"Unknown attention mode: {mode}")
247
+
248
+ block_mask = create_block_mask(
249
+ attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
250
+ )
251
+
252
+ return block_mask
253
+
254
+
255
+ def forward(
256
+ self,
257
+ hidden_states: torch.Tensor,
258
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
259
+ attention_mask: Optional[torch.Tensor],
260
+ past_key_values: Optional[Cache] = None,
261
+ cache_position: Optional[torch.LongTensor] = None,
262
+ is_training: bool = True,
263
+ **kwargs: Unpack[FlashAttentionKwargs],
264
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
265
+ bsz, q_len, _ = hidden_states.size()
266
+ input_shape = hidden_states.shape[:-1]
267
+ hidden_shape = (*input_shape, -1, self.head_dim)
268
+
269
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
270
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
271
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
272
+
273
+ cos, sin = position_embeddings
274
+
275
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
276
+ # Split query and key states in half along sequence length dimension
277
+ q1, q2 = query_states.chunk(2, dim=2)
278
+ k1, k2 = key_states.chunk(2, dim=2)
279
+
280
+ # Apply RoPE independently to each half
281
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
282
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
283
+
284
+ # Recombine the halves
285
+ query_states = torch.cat([q1, q2], dim=2)
286
+ key_states = torch.cat([k1, k2], dim=2)
287
+ else:
288
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
289
+
290
+ query_states = query_states * _get_llama_4_attn_scale(
291
+ cache_position,
292
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
293
+ self.config.rope_parameters.get("original_max_position_embeddings"),
294
+ ).to(query_states.dtype)
295
+
296
+ if past_key_values is not None:
297
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
298
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
299
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
300
+
301
+ tidar_inference_mode = getattr(self.config, "tidar_inference_mode", None)
302
+ if tidar_inference_mode is not None:
303
+ if tidar_inference_mode == "quadratic":
304
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
305
+ if block_length is None:
306
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
307
+ if past_key_values is not None:
308
+ seq_len = key_states.shape[2]
309
+ draft_len = block_length * (block_length + 1)
310
+
311
+ clean_keys = key_states[:, :, :-draft_len]
312
+ draft_keys = key_states[:, :, -draft_len:]
313
+ clean_values = value_states[:, :, :-draft_len]
314
+ draft_values = value_states[:, :, -draft_len:]
315
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
316
+ value_states = torch.cat([draft_values, clean_values], dim=2)
317
+
318
+ block_mask: BlockMask = self._get_sbd_inference_quadratic_decoding_block_mask(
319
+ block_length=block_length
320
+ )
321
+ block_mask.seq_lengths = (draft_len, seq_len)
322
+ else:
323
+ seq_len = query_states.shape[2]
324
+ draft_len = block_length * (block_length + 1)
325
+ clean_len = seq_len - draft_len
326
+
327
+ def _causal_mask(b, h, q_idx, kv_idx):
328
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
329
+
330
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
331
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
332
+ first_clean = torch.logical_and(
333
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
334
+ )
335
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
336
+ return torch.logical_or(full_clean, first_clean)
337
+
338
+ def _draft_mask(b, h, q_idx, kv_idx):
339
+ block_q = (q_idx - clean_len) // (block_length + 1)
340
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
341
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
342
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
343
+ same_block_except_first = torch.logical_and(
344
+ same_block,
345
+ (q_idx - clean_len) % (block_length + 1) != 0,
346
+ )
347
+ return torch.logical_and(block_q == block_kv, same_block_except_first)
348
+
349
+ mask = or_masks(_causal_mask, _draft2clean_mask)
350
+ mask = or_masks(mask, _draft_mask)
351
+
352
+ block_mask = create_block_mask(
353
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
354
+ )
355
+
356
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
357
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
358
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
359
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
360
+ attn_output = self.o_proj(attn_output)
361
+ return attn_output, None
362
+
363
+ elif tidar_inference_mode == "default":
364
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
365
+ if block_length is None:
366
+ raise ValueError("SBD default decoding requires block_length in config.")
367
+ seq_len = query_states.shape[2]
368
+ prefix_len = seq_len - block_length
369
+
370
+ def _clean_q_mask(b, h, q_idx, kv_idx):
371
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
372
+
373
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
374
+ return q_idx >= prefix_len
375
+
376
+ block_mask = create_block_mask(
377
+ or_masks(_clean_q_mask, _noisy_q_mask),
378
+ B=None,
379
+ H=None,
380
+ Q_LEN=seq_len,
381
+ KV_LEN=seq_len,
382
+ )
383
+
384
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
385
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
386
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
387
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
388
+ attn_output = self.o_proj(attn_output)
389
+ return attn_output, None
390
+
391
+ else:
392
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
393
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
394
+
395
+ if self.mode == 'bidirectional':
396
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
397
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
398
+ else:
399
+ block_mask = self.bidirectional_mask
400
+
401
+ elif self.mode == 'autoregressive':
402
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
403
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
404
+ else:
405
+ block_mask = self.autoregressive_mask
406
+
407
+ elif self.mode == 'block_diff':
408
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
409
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
410
+ else:
411
+ block_mask = self.block_diff_mask
412
+ elif self.mode == 'sbd_block_diff':
413
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
414
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
415
+ else:
416
+ block_mask = self.sbd_block_diff_mask
417
+ else:
418
+ raise ValueError(f"Unknown attention mode: {self.mode}")
419
+
420
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
421
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
422
+
423
+ attn_output = self.o_proj(attn_output)
424
+
425
+ return attn_output, None
426
+
427
+
428
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
429
+ """Return a bool mask shaped like the 1-D `log_w` with exactly k True entries, sampled without replacement via the Gumbel top-k trick."""
430
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
431
+ topk = torch.topk(log_w + g, k).indices
432
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
433
+ mask[topk] = True
434
+ return mask
435
+
436
+
437
+ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
438
+ """
439
+ A single model with:
440
+ - a bidirectional encoder + diffusion‐LM head over A
441
+ - a causal decoder + LM head over B, conditioned on F_A
442
+ """
443
+
444
+ def __init__(self, config: MinistralDLMConfig):
445
+ super().__init__(config)
446
+
447
+ self.mask_token_id = config.mask_token_id
448
+
449
+ diffusion_config = copy.deepcopy(config)
450
+ diffusion_config.diffusion_lm = True
451
+
452
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
453
+ diffusion_config.attn_class = MinistralFlexAttention
454
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
455
+ diffusion_config.attn_class = Ministral3Attention
456
+
457
+ if config.dlm_paradigm == 'autoregressive':
458
+ diffusion_config.diffusion_lm = False
459
+ else:
460
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
461
+
462
+ self.encoder = Ministral3Model(diffusion_config)
463
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
464
+ self.vocab_size = config.vocab_size
465
+
466
+ self.current_iter_ratio = None
467
+
468
+ self.post_init()
469
+
470
+
471
+ def get_input_embeddings(self):
472
+ return self.encoder.embed_tokens
473
+
474
+ def set_input_embeddings(self, value):
475
+ self.encoder.embed_tokens = value
476
+
477
+ def get_output_embeddings(self):
478
+ return self.diffusion_head
479
+
480
+ def set_output_embeddings(self, new_embeddings):
481
+ self.diffusion_head = new_embeddings
482
+
483
+
484
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
485
+ b, l = input_ids.shape
486
+ device = input_ids.device
487
+
488
+ if self.config.dp_varying_mask_ratio:
489
+ # Enable different random seeds for each DP rank during sampling
490
+ import torch.distributed as dist
491
+ dp_rank = 0
492
+ if dist.is_initialized():
493
+ try:
494
+ dp_rank = dist.get_rank()
495
+ except Exception:
496
+ dp_rank = 0
497
+ # Use a local generator for sampling t (note: torch.seed() below still reseeds the global RNG)
498
+ generator = torch.Generator(device=device)
499
+ generator.manual_seed(torch.seed() + dp_rank)
500
+ else:
501
+ generator = None
502
+
503
+ if self.config.adaptive_mask_rate:
504
+ assert block_size is not None
505
+
506
+ # --- simple linear window mapping ---
507
+ bs_min = getattr(self.config, "t_bs_min", 16)
508
+ bs_max = getattr(self.config, "t_bs_max", 128)
509
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
510
+
511
+ # fraction in [0,1] (unclamped first)
512
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
513
+ # upper bound decreases linearly from 1.0 (bs = bs_min) to 1.0 - w (bs = bs_max)
514
+ u_max = 1.0 - w * frac
515
+ # clamp to [0.6, 1.0]; this also bounds u_max when bs falls outside [bs_min, bs_max]
516
+ u_max = max(0.6, min(1.0, u_max))
517
+ u_min = u_max - w # ensures width = w
518
+
519
+ # sample t ~ Uniform(u_min, u_max)
520
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
521
+ else:
522
+ t = torch.rand(b, device=device, generator=generator)
523
+
524
+ p_mask = (1 - eps) * t + eps # shape: (b,)
525
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
526
+
527
+ masked_indices = torch.rand((b, l), device=device) < p_mask
528
+
529
+ if loss_mask is not None:
530
+ masked_indices[loss_mask == 0] = 0
531
+
532
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
533
+
534
+ return noisy_batch, masked_indices, p_mask
535
+
536
+
537
+ def forward_process_exp(
538
+ self,
539
+ input_ids: torch.Tensor,
540
+ eps: float = 1e-3,
541
+ block_size: int | None = None,
542
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
543
+ loss_mask: Optional[torch.Tensor] = None,
544
+ ):
545
+ """
546
+ Two-stage corruption with optional per-block sampling.
547
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
548
+ • Stage 2: sample exactly k positions with weights
549
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
550
+ uniform when m→1).
551
+ If `block_size` is given, the procedure is run *independently*
552
+ inside each contiguous block of that length (last block may be shorter).
553
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
554
+ Args
555
+ ----
556
+ input_ids : (B, L) LongTensor
557
+ eps : minimum corruption ratio
558
+ block_size: if not None, operate block-wise with per-block m sampling
559
+ half_life_ratio : controls steepness when m→0
560
+ """
561
+ B, L = input_ids.shape
562
+ device = input_ids.device
563
+ dtype = torch.float32
564
+
565
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
566
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
567
+
568
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
569
+ for b in range(B):
570
+ if block_size is None:
571
+ # ---------- Per-sequence sampling over the whole sequence (original behavior) ----------
572
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
573
+ k_tot = int(round(m * L))
574
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
575
+
576
+ # Fill p_mask for this sequence
577
+ p_mask[b, :] = m
578
+
579
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
580
+
581
+ # ------- single pool over the whole sentence -------------
582
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
583
+
584
+ pos = torch.arange(L, device=device, dtype=dtype)
585
+ log_w = (lam_base * slope * pos).clone()
586
+
587
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
588
+
589
+ else:
590
+ # ---------- Per-block sampling ----------
591
+ num_blocks = math.ceil(L / block_size)
592
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
593
+
594
+ for blk in range(num_blocks):
595
+ start = blk * block_size
596
+ end = min((blk + 1) * block_size, L)
597
+ blk_len = end - start
598
+
599
+ # Sample m per block
600
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
601
+
602
+ # Fill p_mask for this block
603
+ p_mask[b, start:end] = m_blk
604
+
605
+ # per-block budget
606
+ k_blk = int(round(m_blk * blk_len))
607
+ k_blk = max(0, min(k_blk, blk_len))
608
+ if k_blk == 0:
609
+ continue
610
+
611
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
612
+
613
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
614
+ log_w = lam_base * slope * pos
615
+
616
+ blk_mask = gumbel_topk(log_w, k_blk)
617
+ masked_indices[b, start:end] = blk_mask
618
+
619
+ if loss_mask is not None:
620
+ masked_indices[loss_mask == 0] = 0
621
+
622
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
623
+ return noisy_batch, masked_indices, p_mask
624
+
625
+
626
+ def forward(
627
+ self,
628
+ input_ids: torch.LongTensor,
629
+ attention_mask: Optional[torch.Tensor] = None,
630
+ position_ids: Optional[torch.LongTensor] = None,
631
+ labels: Optional[torch.LongTensor] = None,
632
+ split_len: Optional[int] = None,
633
+ past_key_values: Optional[Cache] = None,
634
+ block_size: Optional[int] = None,
635
+ block_diff_ppl: bool = False,
636
+ eps: float = 1e-3,
637
+ is_teacher: bool = False,
638
+ masked_indices: Optional[torch.Tensor] = None,
639
+ p_mask: Optional[torch.Tensor] = None,
640
+ teacher_logits: Optional[torch.Tensor] = None,
641
+ masked_indices_teacher: Optional[torch.Tensor] = None,
642
+ loss_mask: Optional[torch.Tensor] = None,
643
+ ce_loss_weight: float = 1.0,
644
+ output_last_hidden_states_only: bool = False,
645
+ skip_loss: bool = False,
646
+ **kwargs,
647
+ ) -> CausalLMOutputWithPast:
648
+
649
+ batch_size, seq_len = input_ids.shape
650
+
651
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
652
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
653
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
654
+ input_ids = input_ids[:, :random_length]
655
+ labels = labels[:, :random_length]
656
+
657
+ if attention_mask is not None:
658
+ attention_mask = attention_mask[:, :random_length]
659
+ if position_ids is not None:
660
+ position_ids = position_ids[:, :random_length]
661
+ if loss_mask is not None:
662
+ loss_mask = loss_mask[:, :random_length]
663
+
664
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
665
+ if labels is not None and block_size is None:
666
+ if torch.rand(1) < self.config.random_length_prob:
667
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## multiples of 4 in [4, 28] (randint's upper bound is exclusive)
668
+ else:
669
+ block_size = self.config.block_size
670
+
671
+ else:
672
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
673
+
674
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
675
+ if masked_indices is not None:
676
+ # assert p_mask is not None
677
+
678
+ if loss_mask is not None:
679
+ masked_indices[loss_mask == 0] = 0
680
+
681
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
682
+
683
+ else:
684
+ if self.config.tok_mask_half_life_ratio is not None:
685
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
686
+ else:
687
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
688
+
689
+ else:
690
+ noisy_inputs = input_ids
691
+ masked_indices = None
692
+ p_mask = None
693
+
694
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
695
+ for layer in self.encoder.layers:
696
+ if hasattr(layer.self_attn, 'set_attention_mode'):
697
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
698
+
699
+ input_ids_len = noisy_inputs.shape[1]
700
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
701
+ if position_ids is None:
702
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
703
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
704
+
705
+ if block_diff_ppl:
706
+ if position_ids is None:
707
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
708
+
709
+ enc_out = self.encoder(
710
+ past_key_values=past_key_values,
711
+ input_ids=noisy_inputs,
712
+ attention_mask=attention_mask,
713
+ position_ids=position_ids,
714
+ is_training=(labels is not None) or (block_diff_ppl),
715
+ **kwargs,
716
+ )
717
+
718
+ if output_last_hidden_states_only:
719
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
720
+
721
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, seq, vocab)
722
+ causal_logits = None
723
+
724
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
725
+ if self.config.dlm_paradigm == 'sbd_block_diff':
726
+ causal_logits = logits[:, input_ids_len:]
727
+ else:
728
+ causal_logits = None
729
+
730
+ logits = logits[:, :input_ids_len]
731
+
732
+ loss = None
733
+ if labels is not None and not skip_loss:
734
+ if self.config.dlm_paradigm == 'autoregressive':
735
+ shift_logits = logits[..., :-1, :].contiguous()
736
+ shift_labels = labels[..., 1:].contiguous()
737
+
738
+ if loss_mask is None:
739
+ loss_fct = CrossEntropyLoss()
740
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
741
+ shift_labels = shift_labels.view(-1)
742
+ loss = loss_fct(shift_logits, shift_labels)
743
+
744
+ else:
745
+ loss_mask = loss_mask[..., 1:].contiguous()
746
+
747
+ loss_fct = CrossEntropyLoss(reduction='none')
748
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
749
+ shift_labels = shift_labels.view(-1)
750
+ shift_labels = shift_labels.to(shift_logits.device)
751
+
752
+ token_losses = loss_fct(shift_logits, shift_labels)
753
+
754
+ flat_loss_mask = loss_mask.reshape(-1)
755
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
756
+
757
+ else:
758
+ # Handle DREAM vs LLADA style losses
759
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
760
+ logits = logits[..., :-1, :].contiguous()
761
+ labels = labels[..., 1:].contiguous()
762
+ masked_indices = masked_indices[:, 1:]
763
+ p_mask = p_mask[:, 1:]
764
+
765
+ if self.config.ada_perm_ratio_per_block is not None:
766
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
767
+ block_size = self.config.block_size
768
+ batch_size, seq_len = masked_indices.shape
769
+ num_blocks = seq_len // block_size
770
+
771
+ # Get the max logit (confidence) for each position
772
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
773
+
774
+ # Create a mask for tokens to include in loss
775
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
776
+
777
+ for blk in range(num_blocks):
778
+ start = blk * block_size
779
+ end = min((blk + 1) * block_size, seq_len)
780
+
781
+ # Get masked indices within this block
782
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
783
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
784
+
785
+ for b in range(batch_size):
786
+ # Get positions that are masked in this block for this batch
787
+ masked_positions = torch.where(block_masked[b])[0]
788
+ num_masked = len(masked_positions)
789
+
790
+ if num_masked > 0:
791
+ # Number of tokens to keep (top by confidence)
792
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
793
+
794
+ # Get confidence values for masked positions
795
+ masked_confidence = block_confidence[b, masked_positions]
796
+
797
+ # Get indices of top-k confident tokens
798
+ _, topk_indices = torch.topk(masked_confidence, k)
799
+ selected_positions = masked_positions[topk_indices]
800
+
801
+ # Mark these positions in the selected mask
802
+ selected_mask[b, start + selected_positions] = True
803
+
804
+ # Calculate loss only for selected positions
805
+ token_loss = torch.nn.functional.cross_entropy(
806
+ logits[selected_mask],
807
+ labels[selected_mask],
808
+ reduction='none'
809
+ ) / p_mask[selected_mask]
810
+
811
+ num_mask_tokens = selected_mask.sum()
812
+
813
+ else:
814
+ # Calculate token-wise cross entropy loss for masked positions, weighted by 1 / p_mask
815
+ token_loss = torch.nn.functional.cross_entropy(
816
+ logits[masked_indices],
817
+ labels[masked_indices],
818
+ reduction='none'
819
+ ) / p_mask[masked_indices]
820
+
821
+ num_mask_tokens = masked_indices.sum()
822
+
823
+ if self.config.global_loss_avg:
824
+ loss = token_loss.sum()
825
+ else:
826
+ loss = token_loss.sum() / num_mask_tokens
827
+
828
+ if self.config.ada_dlm_loss_ratio is not None:
829
+ assert self.current_iter_ratio is not None
830
+ assert self.config.dlm_loss_weight is not None
831
+
832
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
833
+ loss = dlm_loss_weight * loss
834
+
835
+ elif self.config.dlm_loss_weight is not None:
836
+ loss = self.config.dlm_loss_weight * loss
837
+
838
+ if self.config.dlm_paradigm == 'sbd_block_diff':
839
+ causal_logits = causal_logits[..., :-1, :].contiguous()
840
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
841
+
842
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
843
+ causal_labels = labels.view(-1)
844
+ else:
845
+ causal_labels = labels[..., 1:].contiguous().view(-1)
846
+
847
+ if self.config.global_loss_avg:
848
+ loss_fct = CrossEntropyLoss(reduction='sum')
849
+ ar_loss = loss_fct(causal_logits, causal_labels)
850
+
851
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
852
+ self.loss_ar = ar_loss.detach().item() / seq_len
853
+
854
+ loss = loss + self.config.ar_loss_weight * ar_loss
855
+ else:
856
+ loss_fct = CrossEntropyLoss()
857
+ ar_loss = loss_fct(causal_logits, causal_labels)
858
+
859
+ self.loss_diffusion = loss.detach().item()
860
+ self.loss_ar = ar_loss.detach().item()
861
+
862
+ loss = loss + self.config.ar_loss_weight * ar_loss
863
+
864
+ if self.config.global_loss_avg:
865
+ if self.config.dlm_paradigm == 'sbd_block_diff':
866
+ loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
867
+ else:
868
+ loss = (loss, num_mask_tokens)
869
+
870
+ return MinistralDiffOutputWithPast(
871
+ loss=loss if not is_teacher else logits,
872
+ logits=logits,
873
+ causal_logits=causal_logits,
874
+ past_key_values=enc_out.past_key_values,
875
+ hidden_states=None,
876
+ attentions=None,
877
+ )
878
+
879
+
880
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None):
881
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
882
+ model=self,
883
+ prompt=prompt_ids,
884
+ gen_length=max_new_tokens,
885
+ steps=steps,
886
+ block_length=block_length,
887
+ remasking="low_confidence",
888
+ temperature=temperature,
889
+ mask_id=self.mask_token_id,
890
+ threshold=threshold,
891
+ shift_logits=shift_logits,
892
+ neg_entropy=False,
893
+ causal_context=causal_context,
894
+ eos_token_id=eos_token_id,
895
+ )
896
+
897
+ return out_ids, nfe
898
+
899
+
900
+ @torch.no_grad()
901
+ def sbd_inference_diffusion_quadratic(
902
+ self,
903
+ clean_input_ids: Optional[torch.Tensor],
904
+ draft_input_ids: torch.Tensor,
905
+ block_length: int,
906
+ draft_only: bool = False,
907
+ past_key_values: Optional[Cache] = None,
908
+ use_cache: bool = False,
909
+ ):
910
+ """SBD quadratic inference (injected by build_hf_tidar_repo)."""
911
+ enc_config = self.encoder.config
912
+ enc_config.use_sbd_objective = True
913
+ enc_config.block_length = block_length
914
+
915
+ if draft_only:
916
+ assert clean_input_ids is not None
917
+
918
+ if use_cache and past_key_values is None:
919
+ past_key_values = DynamicCache()
920
+
921
+ enc_config.tidar_inference_mode = "default"
922
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
923
+ outputs = self.encoder(
924
+ input_ids=input_ids,
925
+ position_ids=None,
926
+ past_key_values=past_key_values,
927
+ use_cache=use_cache,
928
+ is_training=False,
929
+ )
930
+
931
+ hidden_states = outputs.last_hidden_state
932
+ logits = self.diffusion_head(hidden_states)
933
+
934
+ past_key_values = getattr(outputs, "past_key_values", None)
935
+ if use_cache and past_key_values is not None:
936
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
937
+
938
+ return logits, past_key_values
939
+ else:
940
+ enc_config.tidar_inference_mode = "quadratic"
941
+
942
+ draft_len = block_length * (block_length + 1)
943
+ draft_input_ids = torch.cat(
944
+ [
945
+ draft_input_ids.view(-1, block_length, 1),
946
+ torch.full(
947
+ (draft_input_ids.shape[0], block_length, block_length),
948
+ fill_value=self.config.mask_token_id,
949
+ device=draft_input_ids.device,
950
+ ),
951
+ ],
952
+ dim=-1,
953
+ ).view(-1, draft_len)
954
+
955
+ if use_cache:
956
+ assert past_key_values is not None, (
957
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
958
+ )
959
+ assert clean_input_ids is None, (
960
+ "Clean input ids should already be in cache, thus none should be provided."
961
+ )
962
+ clean_len = past_key_values.get_seq_length()
963
+ input_ids = draft_input_ids
964
+ else:
965
+ clean_len = clean_input_ids.shape[1]
966
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
967
+
968
+ per_block_position_ids = torch.arange(
969
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
970
+ )[None,].repeat(block_length, 1)
971
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
972
+
973
+ if use_cache:
974
+ position_ids = per_block_position_ids.view(-1)[None,]
975
+ else:
976
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
977
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
978
+
979
+ outputs = self.encoder(
980
+ input_ids=input_ids,
981
+ position_ids=position_ids,
982
+ past_key_values=past_key_values,
983
+ use_cache=use_cache,
984
+ is_training=False,
985
+ )
986
+
987
+ hidden_states = outputs.last_hidden_state
988
+ logits = self.diffusion_head(hidden_states)
989
+ past_key_values = getattr(outputs, "past_key_values", None)
990
+
991
+ if use_cache and past_key_values is not None:
992
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
993
+
994
+ return logits, past_key_values
995
+
996
+ @torch.no_grad()
997
+ def tidar_generate(
998
+ self,
999
+ prompt_ids: torch.Tensor,
1000
+ max_new_tokens: int = 128,
1001
+ steps: int = 128,
1002
+ block_length: int = 16,
1003
+ threshold: Optional[float] = None,
1004
+ temperature: float = 0.0,
1005
+ mask_token_id: Optional[int] = None,
1006
+ eos_token_id: Optional[int] = None,
1007
+ ):
1008
+ """TiDAR quadratic speculative decoding (injected by build_hf_tidar_repo)."""
+         self.config.use_sbd_objective = True
+         self.config.dlm_paradigm = "sbd"
+
+         if prompt_ids.shape[0] != 1:
+             raise ValueError("TiDAR quadratic decoding currently requires batch_size == 1")
+
+         token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
+         if eos_token_id is None:
+             eos_token_id = getattr(self.config, "eos_token_id", None)
+
+         x = torch.full(
+             (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
+             token_mask_id,
+             dtype=torch.long,
+             device=prompt_ids.device,
+         )
+         x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
+
+         if max_new_tokens % block_length != 0:
+             raise ValueError("max_new_tokens must be divisible by block_length")
+         num_blocks = max_new_tokens // block_length
+         if steps % num_blocks != 0:
+             raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
+
+         prompt_len = prompt_ids.shape[1]
+         nfe = 0
+         nfe += 1
+         # First call: run the prompt plus one fully masked block to get the initial draft.
+         logits, past_key_values = self.sbd_inference_diffusion_quadratic(
+             clean_input_ids=x[:, :prompt_len],
+             draft_input_ids=x[:, prompt_len : prompt_len + block_length],
+             block_length=block_length,
+             draft_only=True,
+             use_cache=True,
+         )
+
+         # Draft token 0 uses the logits at the last prompt position;
+         # draft token i >= 1 uses the logits at position prompt_len + i.
+         logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
+         logits_proposal[:, 1] = logits_proposal[:, 0]
+         logits_proposal = logits_proposal[:, 1:]
+         x0_proposal = torch.argmax(logits_proposal, dim=-1)
+         x[:, prompt_len : prompt_len + block_length] = x0_proposal
+
+         total_accept_token = 0
+         while True:
+             nfe += 1
+             block_start = prompt_len + total_accept_token
+             block_end = block_start + block_length
+             draft_input_ids = x[:, block_start:block_end]
+
+             logits, past_key_values = self.sbd_inference_diffusion_quadratic(
+                 clean_input_ids=None,
+                 draft_input_ids=draft_input_ids,
+                 block_length=block_length,
+                 draft_only=False,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+
+             useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
+             if threshold is None:
+                 useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
+             else:
+                 if not (0.0 <= threshold <= 1.0):
+                     raise ValueError("threshold must be between 0 and 1")
+                 mix_logits = useful_token_logits[:, :, 0] * threshold + useful_token_logits[:, :, 1] * (1 - threshold)
+                 useful_token_logits[:, :, 0] = mix_logits
+                 useful_token_logits[:, :, 1] = mix_logits
+
+             # NOTE: dividing by a positive temperature does not change the argmax
+             # taken below, so token selection stays greedy either way.
+             if temperature > 0:
+                 useful_token_logits = useful_token_logits / temperature
+
+             useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
+             new_draft_input_ids = useful_token_pred[:, 0, 1:]
+             accept_cnt = 1
+
+             # draft_input_ids[:, 0] is always accepted; draft_input_ids[:, i] is accepted
+             # only while it equals useful_token_pred[:, i - 1, 0].
+             while accept_cnt < block_length:
+                 if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
+                     break
+                 new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
+                 accept_cnt += 1
+
+             x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
+
+             # EoS early stopping: all accepted tokens are finalized left-to-right,
+             # so if any is EoS we can truncate and return immediately.
+             if eos_token_id is not None:
+                 accepted = x[0, block_start : block_start + accept_cnt]
+                 eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
+                 if len(eos_positions) > 0:
+                     first_eos_rel = eos_positions[0].item()
+                     total_accept_token += first_eos_rel + 1
+                     output_end = prompt_len + total_accept_token
+                     return x[:, :output_end], nfe
+
+             x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
+             past_key_values.crop(block_start + accept_cnt)
+             total_accept_token += accept_cnt
+
+             if total_accept_token >= max_new_tokens:
+                 break
+
+         return x[:, : -(block_length * 2)], nfe
+
+
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
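The accept-or-reject step inside the `while True` loop of `tidar_generate` can be followed more easily on plain lists. The sketch below is a simplified, torch-free rendering of that inner loop for batch size 1. The function name `accept_draft` and the list-of-lists layout for `pred` are assumptions made for illustration, with `pred[r]` standing in for `useful_token_pred[0, r]`.

```python
from typing import List, Tuple


def accept_draft(draft: List[int], pred: List[List[int]]) -> Tuple[int, List[int]]:
    """Return (number of accepted draft tokens, next draft block).

    pred[r][0] is compared against draft[r + 1]; pred[r][1:] is the draft
    block to use next if exactly r + 1 tokens end up accepted.
    """
    block_length = len(draft)
    accept_cnt = 1                 # the first draft token is always kept
    next_draft = pred[0][1:]
    while accept_cnt < block_length:
        if pred[accept_cnt - 1][0] != draft[accept_cnt]:
            break                  # first mismatch: stop accepting
        next_draft = pred[accept_cnt][1:]
        accept_cnt += 1
    return accept_cnt, next_draft


if __name__ == "__main__":
    draft = [10, 11, 12]
    pred = [
        [11, 20, 21, 22],  # agrees with draft[1]
        [99, 30, 31, 32],  # disagrees with draft[2]
        [13, 40, 41, 42],
    ]
    print(accept_draft(draft, pred))
```

In the example, the second draft token is confirmed and the third is rejected, so two tokens are accepted and the next draft block comes from row 1 of `pred`.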