YongganFu commited on
Commit
a195040
·
verified ·
1 Parent(s): 3fa953f

Upload model

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
chat_utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def add_gumbel_noise(logits, temperature):
7
+ '''
8
+ The Gumbel max is a method for sampling categorical distributions.
9
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
10
+ Thus, we use float64.
11
+ '''
12
+ if temperature == 0:
13
+ return logits
14
+ logits = logits.to(torch.float64)
15
+ noise = torch.rand_like(logits, dtype=torch.float64)
16
+ gumbel_noise = (- torch.log(noise)) ** temperature
17
+ return logits.exp() / gumbel_noise
18
+
19
+
20
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
21
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
22
+ x0 = torch.argmax(logits_with_noise, dim=-1)
23
+
24
+ if remasking == 'low_confidence':
25
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
26
+ p = F.softmax(logits, dim=-1)
27
+ x0_p = torch.squeeze(
28
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
29
+ elif remasking == 'top_p_margin':
30
+ # Compute probabilities
31
+ p = F.softmax(logits, dim=-1) # (B, L, V)
32
+ # Top-2 per position
33
+ top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
34
+ margin = top2[..., 0] - top2[..., 1] # (B, L)
35
+
36
+ # Normalize margin to [0,1] over MASKED positions per row
37
+ plus_inf = torch.full_like(margin, float('inf'))
38
+ minus_inf = torch.full_like(margin, float('-inf'))
39
+ masked_for_min = torch.where(mask_index, margin, plus_inf)
40
+ masked_for_max = torch.where(mask_index, margin, minus_inf)
41
+ row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
42
+ row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
43
+ denom = (row_max - row_min)
44
+
45
+ # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
46
+ normalized = torch.zeros_like(margin)
47
+ nonzero = denom > 0
48
+ normalized = torch.where(
49
+ mask_index & nonzero,
50
+ (margin - row_min) / (denom + 1e-12),
51
+ normalized
52
+ )
53
+ normalized = torch.where(
54
+ mask_index & (~nonzero),
55
+ torch.ones_like(normalized),
56
+ normalized
57
+ )
58
+ x0_p = normalized # ∈ [0,1] on masked positions
59
+ elif remasking == 'random':
60
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
61
+ else:
62
+ raise NotImplementedError(remasking)
63
+
64
+ # Calculate negative entropy if requested
65
+ if neg_entropy:
66
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
67
+ p = F.softmax(logits, dim=-1)
68
+ epsilon = 1e-10
69
+ log_probs = torch.log(p + epsilon)
70
+ confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
71
+ else:
72
+ confidence_scores = x0_p
73
+
74
+ x0 = torch.where(mask_index, x0, x)
75
+ confidence = torch.where(mask_index, confidence_scores, -np.inf)
76
+
77
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
78
+ if threshold is not None:
79
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
80
+ # print(f'confidence: {confidence}')
81
+ for j in range(confidence.shape[0]):
82
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
83
+ transfer_index[j, select_index] = True
84
+ if threshold is not None:
85
+ for k in range(1, num_transfer_tokens[j]):
86
+ if confidence[j, select_index[k]] < threshold:
87
+ transfer_index[j, select_index[k]] = False
88
+ return x0, transfer_index
89
+
90
+
91
+ def get_num_transfer_tokens(mask_index, steps: int):
92
+ mask_num = mask_index.sum(dim=1, keepdim=True)
93
+ base = mask_num // steps
94
+ remainder = mask_num % steps
95
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
96
+ for i in range(mask_num.size(0)):
97
+ num_transfer_tokens[i, : int(remainder[i])] += 1
98
+ return num_transfer_tokens
99
+
100
+
101
+ @torch.no_grad()
102
+ def generate_with_prefix_cache_block_diff(
103
+ model,
104
+ prompt,
105
+ steps=128,
106
+ gen_length=128,
107
+ block_length=128,
108
+ temperature=0.,
109
+ remasking='low_confidence',
110
+ mask_id=126336,
111
+ threshold=None,
112
+ factor=None,
113
+ shift_logits=False,
114
+ neg_entropy=False,
115
+ causal_context=False,
116
+ ):
117
+ dream_style=shift_logits
118
+ # Initialize the accumulator
119
+ x_accum = prompt.clone()
120
+
121
+ assert gen_length % block_length == 0
122
+ num_blocks = gen_length // block_length
123
+
124
+ assert steps % num_blocks == 0
125
+ steps_per_block = steps // num_blocks
126
+
127
+ nfe = 0
128
+
129
+ if causal_context:
130
+ model_module = model.module if hasattr(model, "module") else model
131
+ for layer in model_module.encoder.layers:
132
+ if hasattr(layer.self_attn, 'diffusion_lm'):
133
+ layer.self_attn.diffusion_lm=False
134
+
135
+ # Compute KV cache for the prompt initially
136
+ output = model(prompt, use_cache=True)
137
+ past_key_values = output.past_key_values
138
+
139
+ if causal_context:
140
+ for layer in model_module.encoder.layers:
141
+ if hasattr(layer.self_attn, 'diffusion_lm'):
142
+ layer.self_attn.diffusion_lm=True
143
+
144
+ # For dream_style: store the "next token logit" of the context
145
+ next_logits_context = None
146
+ if dream_style:
147
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
148
+
149
+ for num_block in range(num_blocks):
150
+ # Create a new block with mask tokens (no seeding)
151
+ mask_block = torch.ones(
152
+ (prompt.shape[0], block_length),
153
+ dtype=prompt.dtype,
154
+ device=prompt.device
155
+ ) * mask_id
156
+
157
+ # Append the block of masks
158
+ x_accum = torch.cat([x_accum, mask_block], dim=1)
159
+ current_block_start = prompt.size(1) + num_block * block_length
160
+ block_slice = slice(current_block_start, current_block_start + block_length)
161
+
162
+ # Build the initial mask for this block
163
+ mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
164
+
165
+ # Precompute the transfer schedule for this block
166
+ if dream_style:
167
+ # still denoise *all* positions (0..Lb-1), since none are seeded
168
+ schedule_mask = mask_block_idx0
169
+ else:
170
+ schedule_mask = mask_block_idx0
171
+
172
+ num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
173
+
174
+ # Denoise the current block
175
+ for i in range(steps_per_block):
176
+ mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
177
+ if mask_block_idx.sum() == 0:
178
+ break
179
+
180
+ nfe += 1
181
+
182
+ # Forward only the current noisy block using cached context
183
+ logits_block = model(
184
+ x_accum[:, block_slice],
185
+ past_key_values=past_key_values,
186
+ use_cache=False
187
+ ).logits
188
+
189
+ if dream_style:
190
+ # Align logits so that each masked position has a predictor:
191
+ # prepend context-next logit, then use logits_block[:-1]
192
+ if block_length == 1:
193
+ logits_use = next_logits_context # (B, 1, V)
194
+ else:
195
+ logits_use = torch.cat(
196
+ [next_logits_context, logits_block[:, :-1, :]],
197
+ dim=1
198
+ ) # (B, Lb, V)
199
+
200
+ mask_use = mask_block_idx # (B, Lb)
201
+ x_use = x_accum[:, block_slice] # (B, Lb)
202
+
203
+ x0, transfer_idx = get_transfer_index(
204
+ logits_use, temperature, remasking, mask_use, x_use,
205
+ num_transfer_tokens=num_transfer_tokens[:, i],
206
+ threshold=threshold, neg_entropy=neg_entropy
207
+ )
208
+ cur = x_accum[:, block_slice].clone()
209
+ cur[transfer_idx] = x0[transfer_idx]
210
+ x_accum[:, block_slice] = cur
211
+
212
+ else:
213
+ # non-AR (same-position) case
214
+ x0, transfer_idx = get_transfer_index(
215
+ logits_block, temperature, remasking, mask_block_idx,
216
+ x_accum[:, block_slice],
217
+ num_transfer_tokens=num_transfer_tokens[:, i],
218
+ threshold=threshold, neg_entropy=neg_entropy
219
+ )
220
+ cur = x_accum[:, block_slice].clone()
221
+ cur[transfer_idx] = x0[transfer_idx]
222
+ x_accum[:, block_slice] = cur
223
+
224
+ if causal_context:
225
+ for layer in model_module.encoder.layers:
226
+ if hasattr(layer.self_attn, 'diffusion_lm'):
227
+ layer.self_attn.diffusion_lm=False
228
+
229
+ # after block is fully denoised, update KV cache
230
+ output = model(
231
+ x_accum[:, block_slice],
232
+ past_key_values=past_key_values,
233
+ use_cache=True
234
+ )
235
+ past_key_values = output.past_key_values
236
+
237
+ if causal_context:
238
+ for layer in model_module.encoder.layers:
239
+ if hasattr(layer.self_attn, 'diffusion_lm'):
240
+ layer.self_attn.diffusion_lm=True
241
+
242
+ if dream_style and num_block < num_blocks - 1:
243
+ # refresh context-next logit for the next block
244
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
245
+
246
+ return x_accum, nfe
config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ada_dlm_loss_ratio": null,
3
+ "ada_perm_ratio_global": null,
4
+ "ada_perm_ratio_per_block": null,
5
+ "adaptive_mask_rate": false,
6
+ "ar_loss_weight": 1.0,
7
+ "architectures": [
8
+ "MinistralDiffEncoderModel"
9
+ ],
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "attn_implementation": "sdpa",
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
15
+ "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
16
+ },
17
+ "block_size": 32,
18
+ "bos_token_id": 1,
19
+ "diff_loss_weight": 1,
20
+ "dlm_arch": "encoder",
21
+ "dlm_loss_weight": null,
22
+ "dlm_paradigm": "bidirectional",
23
+ "dlm_type": "llada",
24
+ "dp_varying_mask_ratio": false,
25
+ "dtype": "float32",
26
+ "enforce_mask": false,
27
+ "eos_token_id": 2,
28
+ "global_loss_avg": false,
29
+ "head_dim": 128,
30
+ "hidden_act": "silu",
31
+ "hidden_size": 4096,
32
+ "initializer_range": 0.02,
33
+ "intermediate_size": 14336,
34
+ "mask_token_id": 100,
35
+ "max_position_embeddings": 262144,
36
+ "mlp_bias": false,
37
+ "model_type": "ministral_dlm",
38
+ "multi_sampling": null,
39
+ "num_ar_layers": 0,
40
+ "num_attention_heads": 32,
41
+ "num_diffusion_layers": 0,
42
+ "num_hidden_layers": 34,
43
+ "num_key_value_heads": 8,
44
+ "num_skip_loss_tokens": 0,
45
+ "prefix_ratio": 0.8,
46
+ "random_length_prob": 0.05,
47
+ "rms_norm_eps": 1e-05,
48
+ "rope_parameters": {
49
+ "beta_fast": 32.0,
50
+ "beta_slow": 1.0,
51
+ "factor": 16.0,
52
+ "llama_4_scaling_beta": 0.1,
53
+ "mscale": 1.0,
54
+ "mscale_all_dim": 1.0,
55
+ "original_max_position_embeddings": 16384,
56
+ "rope_theta": 1000000.0,
57
+ "rope_type": "yarn",
58
+ "type": "yarn"
59
+ },
60
+ "rope_theta": 1000000.0,
61
+ "seq_length": 8192,
62
+ "sliding_window": null,
63
+ "tie_word_embeddings": false,
64
+ "tok_mask_half_life_ratio": null,
65
+ "transformers_version": "5.0.0rc1",
66
+ "use_cache": false,
67
+ "vocab_size": 131072
68
+ }
configuration_ministral_dlm.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Ministral DLM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class MinistralDLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
+ It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 131072):
35
+ Vocabulary size of the Ministral model.
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 14336):
39
+ Dimension of the MLP representations.
40
+ num_hidden_layers (`int`, *optional*, defaults to 34):
41
+ Number of hidden layers in the Transformer decoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 32):
43
+ Number of attention heads for each attention layer.
44
+ num_key_value_heads (`int`, *optional*, defaults to 8):
45
+ Number of key_value heads for Grouped Query Attention.
46
+ head_dim (`int`, *optional*, defaults to 128):
47
+ The attention head dimension.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function.
50
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
51
+ The maximum sequence length.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions.
58
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
59
+ Whether the model's input and output word embeddings should be tied.
60
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
61
+ The base period of the RoPE embeddings.
62
+ rope_parameters (`Dict`, *optional*):
63
+ Dictionary containing the scaling configuration for the RoPE embeddings.
64
+ Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
65
+ attention_bias (`bool`, defaults to `False`):
66
+ Whether to use a bias in the query, key, value and output projection layers.
67
+ attention_dropout (`float`, *optional*, defaults to 0.0):
68
+ The dropout ratio for the attention probabilities.
69
+ mlp_bias (`bool`, *optional*, defaults to `False`):
70
+ Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
+ sliding_window (`int`, *optional*, defaults to None):
72
+ Sliding window attention size.
73
+ seq_length (`int`, *optional*, defaults to 8192):
74
+ Sequence length for training.
75
+ mask_token_id (`int`, *optional*, defaults to -1):
76
+ Token ID for masking in diffusion.
77
+ dlm_type (`str`, *optional*, defaults to 'llada'):
78
+ Type of diffusion language model ('llada', 'dream').
79
+ random_length_prob (`float`, *optional*):
80
+ Probability of using random lengths during training.
81
+ num_ar_layers (`int`, *optional*, defaults to 0):
82
+ Number of autoregressive layers.
83
+ num_diffusion_layers (`int`, *optional*, defaults to 0):
84
+ Number of diffusion layers.
85
+ diff_loss_weight (`float`, *optional*, defaults to 1):
86
+ Weight for diffusion loss.
87
+ enforce_mask (`bool`, *optional*, defaults to False):
88
+ Whether to enforce masking.
89
+ prefix_ratio (`float`, *optional*, defaults to 0.8):
90
+ Ratio for prefix in prefix_bidirectional mode.
91
+ dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
92
+ Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
93
+ dlm_arch (`str`, *optional*, defaults to 'encoder'):
94
+ Architecture type ('encoder', 'encoder_decoder').
95
+ block_size (`int`, *optional*, defaults to 32):
96
+ Block size for block diffusion paradigms.
97
+ tok_mask_half_life_ratio (`float`, *optional*):
98
+ Half-life ratio for token masking.
99
+ adaptive_mask_rate (`bool`, *optional*, defaults to False):
100
+ Whether to use adaptive mask rate.
101
+ multi_sampling (`int`, *optional*):
102
+ Number of samples for multi-sampling.
103
+ num_skip_loss_tokens (`int`, *optional*, defaults to 0):
104
+ Number of tokens to skip in loss calculation.
105
+ dlm_loss_weight (`float`, *optional*):
106
+ Weight for diffusion LM loss.
107
+ ar_loss_weight (`float`, *optional*, defaults to 1.0):
108
+ Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
109
+ global_loss_avg (`bool`, *optional*, defaults to False):
110
+ Whether to use global loss average.
111
+ dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
112
+ Whether to use varying mask ratio for each DP rank during sampling.
113
+ ada_perm_ratio_per_block (`float`, *optional*):
114
+ Adaptive permutation ratio for each block.
115
+ ada_perm_ratio_global (`float`, *optional*):
116
+ Adaptive permutation ratio for global.
117
+ """
118
+
119
+ model_type = "ministral_dlm"
120
+ keys_to_ignore_at_inference = ["past_key_values"]
121
+
122
+ # Default tensor parallel plan for base model `Ministral`
123
+ base_model_tp_plan = {
124
+ "layers.*.self_attn.q_proj": "colwise",
125
+ "layers.*.self_attn.k_proj": "colwise",
126
+ "layers.*.self_attn.v_proj": "colwise",
127
+ "layers.*.self_attn.o_proj": "rowwise",
128
+ "layers.*.mlp.gate_proj": "colwise",
129
+ "layers.*.mlp.up_proj": "colwise",
130
+ "layers.*.mlp.down_proj": "rowwise",
131
+ }
132
+ base_model_pp_plan = {
133
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
134
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
135
+ "norm": (["hidden_states"], ["hidden_states"]),
136
+ }
137
+
138
+ def __init__(
139
+ self,
140
+ vocab_size=131072,
141
+ hidden_size=4096,
142
+ intermediate_size=14336,
143
+ num_hidden_layers=34,
144
+ num_attention_heads=32,
145
+ num_key_value_heads=8,
146
+ head_dim=128,
147
+ hidden_act="silu",
148
+ max_position_embeddings=262144,
149
+ initializer_range=0.02,
150
+ rms_norm_eps=1e-05,
151
+ use_cache=True,
152
+ pad_token_id=None,
153
+ bos_token_id=1,
154
+ eos_token_id=2,
155
+ tie_word_embeddings=False,
156
+ rope_theta=1000000.0,
157
+ rope_parameters=None,
158
+ attention_bias=False,
159
+ attention_dropout=0.0,
160
+ mlp_bias=False,
161
+ sliding_window=None,
162
+ attn_implementation="sdpa",
163
+ seq_length=8192,
164
+ mask_token_id=-1,
165
+ dlm_type='llada',
166
+ random_length_prob=None,
167
+ num_ar_layers=0,
168
+ num_diffusion_layers=0,
169
+ diff_loss_weight=1,
170
+ enforce_mask=False,
171
+ prefix_ratio=0.8,
172
+ dlm_paradigm='bidirectional',
173
+ dlm_arch='encoder',
174
+ block_size=32,
175
+ tok_mask_half_life_ratio=None,
176
+ adaptive_mask_rate=False,
177
+ multi_sampling=None,
178
+ num_skip_loss_tokens=0,
179
+ dlm_loss_weight=None,
180
+ ar_loss_weight=1.0,
181
+ global_loss_avg=False,
182
+ dp_varying_mask_ratio=False,
183
+ ada_perm_ratio_per_block=None,
184
+ ada_perm_ratio_global=None,
185
+ ada_dlm_loss_ratio=None,
186
+ **kwargs,
187
+ ):
188
+ self.vocab_size = vocab_size
189
+ self.max_position_embeddings = max_position_embeddings
190
+ self.hidden_size = hidden_size
191
+ self.intermediate_size = intermediate_size
192
+ self.num_hidden_layers = num_hidden_layers
193
+ self.num_attention_heads = num_attention_heads
194
+
195
+ # for backward compatibility
196
+ if num_key_value_heads is None:
197
+ num_key_value_heads = num_attention_heads
198
+
199
+ self.num_key_value_heads = num_key_value_heads
200
+ self.head_dim = head_dim
201
+ self.hidden_act = hidden_act
202
+ self.initializer_range = initializer_range
203
+ self.rms_norm_eps = rms_norm_eps
204
+ self.use_cache = use_cache
205
+ self.rope_theta = rope_theta
206
+ self.rope_parameters = rope_parameters
207
+ self.attention_bias = attention_bias
208
+ self.attention_dropout = attention_dropout
209
+ self.mlp_bias = mlp_bias
210
+ self.sliding_window = sliding_window
211
+
212
+ rope_config_validation(self)
213
+
214
+ self.attn_implementation = attn_implementation
215
+ self.seq_length = seq_length
216
+
217
+ self.mask_token_id = mask_token_id
218
+ self.dlm_type = dlm_type
219
+ self.random_length_prob = random_length_prob
220
+ self.num_ar_layers = num_ar_layers
221
+ self.num_diffusion_layers = num_diffusion_layers
222
+ self.diff_loss_weight = diff_loss_weight
223
+ self.enforce_mask = enforce_mask
224
+ self.prefix_ratio = prefix_ratio
225
+ self.dlm_paradigm = dlm_paradigm
226
+ self.dlm_arch = dlm_arch
227
+ self.block_size = block_size
228
+ self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
229
+ self.adaptive_mask_rate = adaptive_mask_rate
230
+ self.multi_sampling = multi_sampling
231
+ self.num_skip_loss_tokens = num_skip_loss_tokens
232
+ self.dlm_loss_weight = dlm_loss_weight
233
+ self.ar_loss_weight = ar_loss_weight
234
+ self.global_loss_avg = global_loss_avg
235
+ self.dp_varying_mask_ratio = dp_varying_mask_ratio
236
+ self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
237
+ self.ada_perm_ratio_global = ada_perm_ratio_global
238
+ self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
239
+ super().__init__(
240
+ pad_token_id=pad_token_id,
241
+ bos_token_id=bos_token_id,
242
+ eos_token_id=eos_token_id,
243
+ tie_word_embeddings=tie_word_embeddings,
244
+ **kwargs,
245
+ )
246
+
247
+
248
+ __all__ = ["MinistralDLMConfig"]
249
+
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "5.0.0rc1",
6
+ "use_cache": false
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:581e534c77fd49b8ab1d234f65bf08b88b03bd4dd10e397285a3441260957a8d
3
+ size 16979144720
modeling_ministral.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from transformers.utils.generic import check_model_inputs
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
14
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
15
+ from transformers.modeling_layers import (
16
+ GenericForQuestionAnswering,
17
+ GenericForSequenceClassification,
18
+ GenericForTokenClassification,
19
+ GradientCheckpointingLayer,
20
+ )
21
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
22
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
23
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
24
+ from transformers.processing_utils import Unpack
25
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
26
+ from transformers.utils.generic import maybe_autocast
27
+ from .configuration_ministral_dlm import MinistralDLMConfig
28
+
29
+
30
+ def rotate_half(x):
31
+ """Rotates half the hidden dims of the input."""
32
+ x1 = x[..., : x.shape[-1] // 2]
33
+ x2 = x[..., x.shape[-1] // 2 :]
34
+ return torch.cat((-x2, x1), dim=-1)
35
+
36
+
37
+ @use_kernel_func_from_hub("rotary_pos_emb")
38
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
39
+ """Applies Rotary Position Embedding to the query and key tensors.
40
+
41
+ Args:
42
+ q (`torch.Tensor`): The query tensor.
43
+ k (`torch.Tensor`): The key tensor.
44
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
45
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
46
+ position_ids (`torch.Tensor`, *optional*):
47
+ Deprecated and unused.
48
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
49
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
50
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
51
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
52
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
53
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
54
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
55
+ Returns:
56
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
57
+ """
58
+ cos = cos.unsqueeze(unsqueeze_dim)
59
+ sin = sin.unsqueeze(unsqueeze_dim)
60
+ q_embed = (q * cos) + (rotate_half(q) * sin)
61
+ k_embed = (k * cos) + (rotate_half(k) * sin)
62
+ return q_embed, k_embed
63
+
64
+
65
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
66
+ """
67
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
68
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
69
+ """
70
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
71
+ if n_rep == 1:
72
+ return hidden_states
73
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
74
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
75
+
76
+
77
+ def eager_attention_forward(
78
+ module: nn.Module,
79
+ query: torch.Tensor,
80
+ key: torch.Tensor,
81
+ value: torch.Tensor,
82
+ attention_mask: Optional[torch.Tensor],
83
+ scaling: float,
84
+ dropout: float = 0.0,
85
+ **kwargs: Unpack[TransformersKwargs],
86
+ ):
87
+ key_states = repeat_kv(key, module.num_key_value_groups)
88
+ value_states = repeat_kv(value, module.num_key_value_groups)
89
+
90
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
91
+ if attention_mask is not None:
92
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
93
+ attn_weights = attn_weights + causal_mask
94
+
95
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
96
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
97
+ attn_output = torch.matmul(attn_weights, value_states)
98
+ attn_output = attn_output.transpose(1, 2).contiguous()
99
+
100
+ return attn_output, attn_weights
101
+
102
+
103
+ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
104
+ scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
105
+ return scaling.unsqueeze(-1)
106
+
107
+
108
+ @use_kernelized_func(apply_rotary_pos_emb)
109
+ class Ministral3Attention(nn.Module):
110
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
111
+
112
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
113
+ super().__init__()
114
+ self.config = config
115
+ self.layer_idx = layer_idx
116
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
117
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
118
+ self.scaling = self.head_dim**-0.5
119
+ self.attention_dropout = config.attention_dropout
120
+ self.is_causal = True
121
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
122
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
123
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
124
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
125
+
126
+ self.diffusion_lm = config.diffusion_lm
127
+
128
+ def forward(
129
+ self,
130
+ hidden_states: torch.Tensor,
131
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
132
+ attention_mask: Optional[torch.Tensor],
133
+ past_key_values: Optional[Cache] = None,
134
+ cache_position: Optional[torch.LongTensor] = None,
135
+ use_cache: Optional[bool] = False,
136
+ **kwargs: Unpack[FlashAttentionKwargs],
137
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
138
+ input_shape = hidden_states.shape[:-1]
139
+ hidden_shape = (*input_shape, -1, self.head_dim)
140
+
141
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
142
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
143
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144
+
145
+ cos, sin = position_embeddings
146
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
147
+ query_states = query_states * _get_llama_4_attn_scale(
148
+ cache_position,
149
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
150
+ self.config.rope_parameters.get("original_max_position_embeddings"),
151
+ ).to(query_states.dtype)
152
+
153
+ if past_key_values is not None:
154
+ if use_cache:
155
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
156
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
157
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
158
+ else: ## if use_cache == False, do not update cache
159
+ old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
160
+ key_states = torch.cat([old_k, key_states], dim=-2)
161
+ value_states = torch.cat([old_v, value_states], dim=-2)
162
+
163
+ attention_interface: Callable = eager_attention_forward
164
+ if self.config._attn_implementation != "eager":
165
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
166
+
167
+ if self.diffusion_lm:
168
+ attn_output, attn_weights = attention_interface(
169
+ self,
170
+ query_states,
171
+ key_states,
172
+ value_states,
173
+ None,
174
+ dropout=0.0 if not self.training else self.attention_dropout,
175
+ scaling=self.scaling,
176
+ is_causal=False,
177
+ **kwargs,
178
+ )
179
+
180
+ else:
181
+ attn_output, attn_weights = attention_interface(
182
+ self,
183
+ query_states,
184
+ key_states,
185
+ value_states,
186
+ attention_mask,
187
+ dropout=0.0 if not self.training else self.attention_dropout,
188
+ scaling=self.scaling,
189
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
190
+ **kwargs,
191
+ )
192
+
193
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
194
+ attn_output = self.o_proj(attn_output)
195
+ return attn_output, attn_weights
196
+
197
+
198
+ class Ministral3MLP(nn.Module):
199
+ def __init__(self, config):
200
+ super().__init__()
201
+ self.config = config
202
+ self.hidden_size = config.hidden_size
203
+ self.intermediate_size = config.intermediate_size
204
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
205
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
206
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
207
+ self.act_fn = ACT2FN[config.hidden_act]
208
+
209
+ def forward(self, x):
210
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
211
+ return down_proj
212
+
213
+
214
+ @use_kernel_forward_from_hub("RMSNorm")
215
+ class Ministral3RMSNorm(nn.Module):
216
+ def __init__(self, hidden_size, eps=1e-6):
217
+ """
218
+ Ministral3RMSNorm is equivalent to T5LayerNorm
219
+ """
220
+ super().__init__()
221
+ self.weight = nn.Parameter(torch.ones(hidden_size))
222
+ self.variance_epsilon = eps
223
+
224
+ def forward(self, hidden_states):
225
+ input_dtype = hidden_states.dtype
226
+ hidden_states = hidden_states.to(torch.float32)
227
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
228
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
229
+ return self.weight * hidden_states.to(input_dtype)
230
+
231
+ def extra_repr(self):
232
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
233
+
234
+
235
+ class Ministral3DecoderLayer(GradientCheckpointingLayer):
236
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
237
+ super().__init__()
238
+ self.hidden_size = config.hidden_size
239
+
240
+ if hasattr(config, 'attn_class'):
241
+ attn_class = config.attn_class
242
+ else:
243
+ attn_class = Ministral3Attention
244
+
245
+ self.self_attn = attn_class(config=config, layer_idx=layer_idx)
246
+ self.mlp = Ministral3MLP(config)
247
+ self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
248
+ self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states: torch.Tensor,
253
+ attention_mask: Optional[torch.Tensor] = None,
254
+ position_ids: Optional[torch.LongTensor] = None,
255
+ past_key_values: Optional[Cache] = None,
256
+ use_cache: Optional[bool] = False,
257
+ cache_position: Optional[torch.LongTensor] = None,
258
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
259
+ **kwargs: Unpack[TransformersKwargs],
260
+ ) -> torch.Tensor:
261
+ residual = hidden_states
262
+ hidden_states = self.input_layernorm(hidden_states)
263
+ # Self Attention
264
+ hidden_states, _ = self.self_attn(
265
+ hidden_states=hidden_states,
266
+ attention_mask=attention_mask,
267
+ position_ids=position_ids,
268
+ past_key_values=past_key_values,
269
+ use_cache=use_cache,
270
+ cache_position=cache_position,
271
+ position_embeddings=position_embeddings,
272
+ **kwargs,
273
+ )
274
+ hidden_states = residual + hidden_states
275
+
276
+ # Fully Connected
277
+ residual = hidden_states
278
+ hidden_states = self.post_attention_layernorm(hidden_states)
279
+ hidden_states = self.mlp(hidden_states)
280
+ hidden_states = residual + hidden_states
281
+ return hidden_states
282
+
283
+
284
+ @auto_docstring
285
+ class Ministral3PreTrainedModel(PreTrainedModel):
286
+ config: MinistralDLMConfig
287
+ base_model_prefix = "model"
288
+ supports_gradient_checkpointing = True
289
+ _no_split_modules = ["Ministral3DecoderLayer"]
290
+ _skip_keys_device_placement = ["past_key_values"]
291
+ _supports_flash_attn = True
292
+ _supports_sdpa = True
293
+ _supports_flex_attn = True
294
+
295
+ _can_compile_fullgraph = True
296
+ _supports_attention_backend = True
297
+ _can_record_outputs = {
298
+ "hidden_states": Ministral3DecoderLayer,
299
+ "attentions": Ministral3Attention,
300
+ }
301
+
302
+
303
+ class Ministral3RotaryEmbedding(nn.Module):
304
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
305
+
306
+ def __init__(self, config: MinistralDLMConfig, device=None):
307
+ super().__init__()
308
+ self.max_seq_len_cached = config.max_position_embeddings
309
+ self.original_max_seq_len = config.max_position_embeddings
310
+
311
+ self.config = config
312
+
313
+ self.rope_type = self.config.rope_parameters["rope_type"]
314
+ rope_init_fn: Callable = self.compute_default_rope_parameters
315
+ if self.rope_type != "default":
316
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
317
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
318
+
319
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
320
+ self.original_inv_freq = inv_freq
321
+
322
+
323
+ @staticmethod
324
+ def compute_default_rope_parameters(
325
+ config: Optional[MinistralDLMConfig] = None,
326
+ device: Optional["torch.device"] = None,
327
+ seq_len: Optional[int] = None,
328
+ ) -> tuple["torch.Tensor", float]:
329
+ """
330
+ Computes the inverse frequencies according to the original RoPE implementation
331
+ Args:
332
+ config ([`~transformers.PreTrainedConfig`]):
333
+ The model configuration.
334
+ device (`torch.device`):
335
+ The device to use for initialization of the inverse frequencies.
336
+ seq_len (`int`, *optional*):
337
+ The current sequence length. Unused for this type of RoPE.
338
+ Returns:
339
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
340
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
341
+ """
342
+ base = config.rope_parameters["rope_theta"]
343
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
344
+
345
+ attention_factor = 1.0 # Unused in this type of RoPE
346
+
347
+ # Compute the inverse frequencies
348
+ inv_freq = 1.0 / (
349
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
350
+ )
351
+ return inv_freq, attention_factor
352
+
353
+ @torch.no_grad()
354
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
355
+ def forward(self, x, position_ids):
356
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
357
+ position_ids_expanded = position_ids[:, None, :].float()
358
+
359
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
360
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
361
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
362
+ emb = torch.cat((freqs, freqs), dim=-1)
363
+ cos = emb.cos() * self.attention_scaling
364
+ sin = emb.sin() * self.attention_scaling
365
+
366
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
367
+
368
+
369
+ @auto_docstring
370
+ class Ministral3Model(Ministral3PreTrainedModel):
371
+ def __init__(self, config: MinistralDLMConfig):
372
+ super().__init__(config)
373
+ self.padding_idx = config.pad_token_id
374
+ self.vocab_size = config.vocab_size
375
+
376
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
377
+ self.layers = nn.ModuleList(
378
+ [Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
379
+ )
380
+ self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
381
+ self.rotary_emb = Ministral3RotaryEmbedding(config=config)
382
+ self.gradient_checkpointing = False
383
+
384
+ # Initialize weights and apply final processing
385
+ self.post_init()
386
+
387
+ @check_model_inputs
388
+ @auto_docstring
389
+ def forward(
390
+ self,
391
+ input_ids: Optional[torch.LongTensor] = None,
392
+ attention_mask: Optional[torch.Tensor] = None,
393
+ position_ids: Optional[torch.LongTensor] = None,
394
+ past_key_values: Optional[Cache] = None,
395
+ inputs_embeds: Optional[torch.FloatTensor] = None,
396
+ use_cache: Optional[bool] = None,
397
+ cache_position: Optional[torch.LongTensor] = None,
398
+ **kwargs: Unpack[TransformersKwargs],
399
+ ) -> BaseModelOutputWithPast:
400
+ if (input_ids is None) ^ (inputs_embeds is not None):
401
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
402
+
403
+ if inputs_embeds is None:
404
+ inputs_embeds = self.embed_tokens(input_ids)
405
+
406
+ if use_cache and past_key_values is None:
407
+ past_key_values = DynamicCache(config=self.config)
408
+
409
+ if cache_position is None:
410
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
411
+ cache_position = torch.arange(
412
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
413
+ )
414
+
415
+ if position_ids is None:
416
+ position_ids = cache_position.unsqueeze(0)
417
+
418
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
419
+ causal_mask = mask_function(
420
+ config=self.config,
421
+ input_embeds=inputs_embeds,
422
+ attention_mask=attention_mask,
423
+ cache_position=cache_position,
424
+ past_key_values=past_key_values,
425
+ position_ids=position_ids,
426
+ )
427
+
428
+ hidden_states = inputs_embeds
429
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
430
+
431
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
432
+ hidden_states = decoder_layer(
433
+ hidden_states,
434
+ attention_mask=causal_mask,
435
+ position_ids=position_ids,
436
+ past_key_values=past_key_values,
437
+ use_cache=use_cache,
438
+ cache_position=cache_position,
439
+ position_embeddings=position_embeddings,
440
+ **kwargs,
441
+ )
442
+ hidden_states = self.norm(hidden_states)
443
+ return BaseModelOutputWithPast(
444
+ last_hidden_state=hidden_states,
445
+ past_key_values=past_key_values if use_cache else None,
446
+ )
447
+
448
+
449
+ @auto_docstring
450
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
451
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
452
+ _tp_plan = {"lm_head": "colwise_rep"}
453
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
454
+
455
+ def __init__(self, config):
456
+ super().__init__(config)
457
+ self.model = Ministral3Model(config)
458
+ self.vocab_size = config.vocab_size
459
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
460
+
461
+ # Initialize weights and apply final processing
462
+ self.post_init()
463
+
464
+ @can_return_tuple
465
+ @auto_docstring
466
+ def forward(
467
+ self,
468
+ input_ids: Optional[torch.LongTensor] = None,
469
+ attention_mask: Optional[torch.Tensor] = None,
470
+ position_ids: Optional[torch.LongTensor] = None,
471
+ past_key_values: Optional[Cache] = None,
472
+ inputs_embeds: Optional[torch.FloatTensor] = None,
473
+ labels: Optional[torch.LongTensor] = None,
474
+ use_cache: Optional[bool] = None,
475
+ cache_position: Optional[torch.LongTensor] = None,
476
+ logits_to_keep: Union[int, torch.Tensor] = 0,
477
+ **kwargs: Unpack[TransformersKwargs],
478
+ ) -> CausalLMOutputWithPast:
479
+ r"""
480
+ Example:
481
+
482
+ ```python
483
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
484
+
485
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
486
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
487
+
488
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
489
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
490
+
491
+ >>> # Generate
492
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
493
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
494
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
495
+ ```"""
496
+ outputs: BaseModelOutputWithPast = self.model(
497
+ input_ids=input_ids,
498
+ attention_mask=attention_mask,
499
+ position_ids=position_ids,
500
+ past_key_values=past_key_values,
501
+ inputs_embeds=inputs_embeds,
502
+ use_cache=use_cache,
503
+ cache_position=cache_position,
504
+ **kwargs,
505
+ )
506
+
507
+ hidden_states = outputs.last_hidden_state
508
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
509
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
510
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
511
+
512
+ loss = None
513
+ if labels is not None:
514
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
515
+
516
+ return CausalLMOutputWithPast(
517
+ loss=loss,
518
+ logits=logits,
519
+ past_key_values=outputs.past_key_values,
520
+ hidden_states=outputs.hidden_states,
521
+ attentions=outputs.attentions,
522
+ )
523
+
524
+
525
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
526
+ pass
527
+
528
+
529
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
530
+ pass
531
+
532
+
533
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
534
+ pass
535
+
536
+
537
+ __all__ = [
538
+ "Ministral3ForCausalLM",
539
+ "Ministral3ForQuestionAnswering",
540
+ "Ministral3Model",
541
+ "Ministral3PreTrainedModel",
542
+ "Ministral3ForSequenceClassification",
543
+ "Ministral3ForTokenClassification",
544
+ ]
modeling_ministral_dlm.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Callable, Optional, Tuple, Union
3
+ import random
4
+ import os
5
+ import sys
6
+ import json
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
13
+
14
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
15
+
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+
18
+ from transformers.processing_utils import Unpack
19
+
20
+ from transformers.cache_utils import Cache, DynamicCache
21
+
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from transformers.generation import GenerationMixin
25
+
26
+ import math
27
+
28
+ from .chat_utils import generate_with_prefix_cache_block_diff
29
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
30
+ from .configuration_ministral_dlm import MinistralDLMConfig
31
+
32
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
33
+ # @torch.compile(mode="default")
34
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
35
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
36
+ def fused_flex_attention(q, k, v, block_mask=None):
37
+ return flex_attention(q, k, v, block_mask=block_mask)
38
+
39
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
40
+ class MinistralFlexAttention(Ministral3Attention):
41
+ def __init__(self, *args, **kwargs):
42
+ super().__init__(*args, **kwargs)
43
+
44
+ self.max_seq_length = self.config.seq_length
45
+ self.block_size_orig = self.config.block_size
46
+
47
+ if self.config.dlm_paradigm == 'bidirectional':
48
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
49
+ elif self.config.dlm_paradigm == 'autoregressive':
50
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
51
+ elif self.config.dlm_paradigm == 'block_diff':
52
+ self.block_diff_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size_orig)
53
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
54
+ self.sbd_block_diff_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size_orig)
55
+ else:
56
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
57
+
58
+ self.block_size = self.block_size_orig
59
+ self.mode = self.config.dlm_paradigm
60
+
61
+ import torch._dynamo.config as dcfg
62
+ dcfg.cache_size_limit = 512
63
+
64
+
65
+ def set_attention_mode(self, mode, block_size=None):
66
+ self.mode = mode
67
+ self.block_size = block_size
68
+
69
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
70
+
71
+ def bidirectional_mask(b, h, q, kv):
72
+ return (q >= kv) | (q < kv)
73
+
74
+ def autoregressive_mask(b, h, q, kv):
75
+ return (q >= kv)
76
+
77
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
78
+ """
79
+ Constructs the specialized block diffusion attention mask for training
80
+ composed of three masks:
81
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
82
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
83
+ - **Block Causal Mask (M_BC)**: Attention to update x0
84
+
85
+ Args:
86
+ b, h: Batch and head indices (ignored for mask logic).
87
+ q_idx, kv_idx: Query and Key indices.
88
+ seq_len: Total sequence length.
89
+ block_size: Defines the block structure.
90
+
91
+ Returns:
92
+ A boolean attention mask.
93
+ """
94
+
95
+ # Indicate whether token belongs to xt or x0
96
+ x0_flag_q = (q_idx >= n)
97
+ x0_flag_kv = (kv_idx >= n)
98
+
99
+ # Compute block indices
100
+ block_q = torch.where(x0_flag_q == 1,
101
+ (q_idx - n) // block_size,
102
+ q_idx // block_size)
103
+ block_kv = torch.where(x0_flag_kv == 1,
104
+ (kv_idx - n) // block_size,
105
+ kv_idx // block_size)
106
+
107
+ # **1. Block Diagonal Mask (M_BD) **
108
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
109
+
110
+ # **2. Offset Block-Causal Mask (M_OBC) **
111
+ offset_block_causal = (
112
+ (block_q > block_kv)
113
+ & (x0_flag_kv == 1)
114
+ & (x0_flag_q == 0)
115
+ )
116
+
117
+ # **3. Block-Causal Mask (M_BC) **
118
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
119
+
120
+ # **4. Combine Masks **
121
+ return block_diagonal | offset_block_causal | block_causal
122
+
123
+
124
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
125
+ x0_flag_q = (q_idx >= n)
126
+ x0_flag_kv = (kv_idx >= n)
127
+
128
+ # Compute block indices
129
+ block_q = torch.where(x0_flag_q == 1,
130
+ (q_idx - n) // block_size,
131
+ q_idx // block_size)
132
+ block_kv = torch.where(x0_flag_kv == 1,
133
+ (kv_idx - n) // block_size,
134
+ kv_idx // block_size)
135
+
136
+ # **1. Block Diagonal Mask (M_BD) **
137
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
138
+
139
+ # **2. Offset Block-Causal Mask (M_OBC) **
140
+ offset_block_causal = (
141
+ (block_q > block_kv)
142
+ & (x0_flag_kv == 1)
143
+ & (x0_flag_q == 0)
144
+ )
145
+
146
+ # **3. Fully Causal Mask (M_BC) **
147
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
148
+
149
+ # **4. Combine Masks **
150
+ return block_diagonal | offset_block_causal | fully_causal
151
+
152
+ if mode == 'bidirectional':
153
+ attn_mask = bidirectional_mask
154
+ elif mode == 'autoregressive':
155
+ attn_mask = autoregressive_mask
156
+ elif mode == 'block_diff':
157
+ assert block_size is not None
158
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
159
+ elif mode == 'sbd_block_diff':
160
+ assert block_size is not None
161
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, self.max_seq_length)
162
+ else:
163
+ raise ValueError(f"Unknown attention mode: {mode}")
164
+
165
+ if q_len is not None:
166
+ Q_LEN = q_len
167
+ else:
168
+ if mode in ['block_diff', 'sbd_block_diff']:
169
+ Q_LEN = self.max_seq_length * 2
170
+ else:
171
+ Q_LEN = self.max_seq_length
172
+
173
+ block_mask = create_block_mask(
174
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
175
+ )
176
+
177
+ return block_mask
178
+
179
+
180
+ def forward(
181
+ self,
182
+ hidden_states: torch.Tensor,
183
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
184
+ attention_mask: Optional[torch.Tensor],
185
+ past_key_values: Optional[Cache] = None,
186
+ cache_position: Optional[torch.LongTensor] = None,
187
+ is_training: bool = True,
188
+ **kwargs: Unpack[FlashAttentionKwargs],
189
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
190
+ bsz, q_len, _ = hidden_states.size()
191
+ input_shape = hidden_states.shape[:-1]
192
+ hidden_shape = (*input_shape, -1, self.head_dim)
193
+
194
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
195
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
196
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
197
+
198
+ cos, sin = position_embeddings
199
+
200
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
201
+ # Split query and key states in half along sequence length dimension
202
+ q1, q2 = query_states.chunk(2, dim=2)
203
+ k1, k2 = key_states.chunk(2, dim=2)
204
+
205
+ # Apply RoPE independently to each half
206
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
207
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
208
+
209
+ # Recombine the halves
210
+ query_states = torch.cat([q1, q2], dim=2)
211
+ key_states = torch.cat([k1, k2], dim=2)
212
+ else:
213
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
214
+
215
+ query_states = query_states * _get_llama_4_attn_scale(
216
+ cache_position,
217
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
218
+ self.config.rope_parameters.get("original_max_position_embeddings"),
219
+ ).to(query_states.dtype)
220
+
221
+ if past_key_values is not None:
222
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
223
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
224
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
225
+
226
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
227
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
228
+
229
+ if self.mode == 'bidirectional':
230
+ if q_len != self.bidirectional_mask.shape[-2]:
231
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
232
+ else:
233
+ block_mask = self.bidirectional_mask
234
+
235
+ elif self.mode == 'autoregressive':
236
+ if q_len != self.autoregressive_mask.shape[-2]:
237
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
238
+ else:
239
+ block_mask = self.autoregressive_mask
240
+
241
+ elif self.mode == 'block_diff':
242
+ if self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
243
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
244
+ else:
245
+ block_mask = self.block_diff_mask
246
+ elif self.mode == 'sbd_block_diff':
247
+ if self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
248
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
249
+ else:
250
+ block_mask = self.sbd_block_diff_mask
251
+ else:
252
+ raise ValueError(f"Unknown attention mode: {self.mode}")
253
+
254
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
255
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
256
+
257
+ attn_output = self.o_proj(attn_output)
258
+
259
+ return attn_output, None
260
+
261
+
262
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
263
+ """Return a Bool mask of length len(log_w) with exactly k True."""
264
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
265
+ topk = torch.topk(log_w + g, k).indices
266
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
267
+ mask[topk] = True
268
+ return mask
269
+
270
+
271
+ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
272
+ """
273
+ A single model with:
274
+ - a bidirectional encoder + diffusion‐LM head over A
275
+ - a causal decoder + LM head over B, conditioned on F_A
276
+ """
277
+
278
+ def __init__(self, config: MinistralDLMConfig):
279
+ super().__init__(config)
280
+
281
+ self.mask_token_id = config.mask_token_id
282
+
283
+ diffusion_config = copy.deepcopy(config)
284
+ diffusion_config.diffusion_lm = True
285
+
286
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
287
+ diffusion_config.attn_class = MinistralFlexAttention
288
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
289
+ diffusion_config.attn_class = Ministral3Attention
290
+
291
+ if config.dlm_paradigm == 'autoregressive':
292
+ diffusion_config.diffusion_lm = False
293
+ else:
294
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
295
+
296
+ self.encoder = Ministral3Model(diffusion_config)
297
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
298
+ self.vocab_size = config.vocab_size
299
+
300
+ self.current_iter_ratio = None
301
+
302
+ self.post_init()
303
+
304
+
305
+ def get_input_embeddings(self):
306
+ return self.encoder.embed_tokens
307
+
308
+ def set_input_embeddings(self, value):
309
+ self.encoder.embed_tokens = value
310
+
311
+ def get_output_embeddings(self):
312
+ return self.diffusion_head
313
+
314
+ def set_output_embeddings(self, new_embeddings):
315
+ self.diffusion_head = new_embeddings
316
+
317
+
318
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
319
+ b, l = input_ids.shape
320
+ device = input_ids.device
321
+
322
+ if self.config.dp_varying_mask_ratio:
323
+ # Enable different random seeds for each DP rank during sampling
324
+ import torch.distributed as dist
325
+ dp_rank = 0
326
+ if dist.is_initialized():
327
+ try:
328
+ dp_rank = dist.get_rank()
329
+ except Exception:
330
+ dp_rank = 0
331
+ # Use a local generator to avoid affecting global RNG state
332
+ generator = torch.Generator(device=device)
333
+ generator.manual_seed(torch.seed() + dp_rank)
334
+ else:
335
+ generator = None
336
+
337
+ if self.config.adaptive_mask_rate:
338
+ assert block_size is not None
339
+
340
+ # --- simple linear window mapping ---
341
+ bs_min = getattr(self.config, "t_bs_min", 16)
342
+ bs_max = getattr(self.config, "t_bs_max", 128)
343
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
344
+
345
+ # fraction in [0,1] (unclamped first)
346
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
347
+ # upper bound decreases linearly from 1.0 -> 0.5
348
+ u_max = 1.0 - w * frac
349
+ # clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
350
+ u_max = max(0.6, min(1.0, u_max))
351
+ u_min = u_max - w # ensures width = w
352
+
353
+ # sample t ~ Uniform(u_min, u_max)
354
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
355
+ else:
356
+ t = torch.rand(b, device=device, generator=generator)
357
+
358
+ p_mask = (1 - eps) * t + eps # shape: (b,)
359
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
360
+
361
+ masked_indices = torch.rand((b, l), device=device) < p_mask
362
+
363
+ if loss_mask is not None:
364
+ masked_indices[loss_mask == 0] = 0
365
+
366
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
367
+
368
+ return noisy_batch, masked_indices, p_mask
369
+
370
+
371
+ def forward_process_exp(
372
+ self,
373
+ input_ids: torch.Tensor,
374
+ eps: float = 1e-3,
375
+ block_size: int | None = None,
376
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
377
+ loss_mask: Optional[torch.Tensor] = None,
378
+ ):
379
+ """
380
+ Two-stage corruption with optional per-block sampling.
381
+
382
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
383
+ • Stage 2: sample exactly k positions with weights
384
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
385
+ uniform when m→1).
386
+
387
+ If `block_size` is given, the procedure is run *independently*
388
+ inside each contiguous block of that length (last block may be shorter).
389
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
390
+
391
+ Args
392
+ ----
393
+ input_ids : (B, L) LongTensor
394
+ eps : minimum corruption ratio
395
+ block_size: if not None, operate block-wise with per-block m sampling
396
+ half_life_ratio : controls steepness when m→0
397
+ """
398
+ B, L = input_ids.shape
399
+ device = input_ids.device
400
+ dtype = torch.float32
401
+
402
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
403
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
404
+
405
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
406
+ for b in range(B):
407
+ if block_size is None:
408
+ # ---------- Per-batch sampling (original behavior) ----------
409
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
410
+ k_tot = int(round(m * L))
411
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
412
+
413
+ # Fill p_mask for this batch
414
+ p_mask[b, :] = m
415
+
416
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
417
+
418
+ # ------- single pool over the whole sentence -------------
419
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
420
+
421
+ pos = torch.arange(L, device=device, dtype=dtype)
422
+ log_w = (lam_base * slope * pos).clone()
423
+
424
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
425
+
426
+ else:
427
+ # ---------- Per-block sampling ----------
428
+ num_blocks = math.ceil(L / block_size)
429
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
430
+
431
+ for blk in range(num_blocks):
432
+ start = blk * block_size
433
+ end = min((blk + 1) * block_size, L)
434
+ blk_len = end - start
435
+
436
+ # Sample m per block
437
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
438
+
439
+ # Fill p_mask for this block
440
+ p_mask[b, start:end] = m_blk
441
+
442
+ # per-block budget
443
+ k_blk = int(round(m_blk * blk_len))
444
+ k_blk = max(0, min(k_blk, blk_len))
445
+ if k_blk == 0:
446
+ continue
447
+
448
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
449
+
450
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
451
+ log_w = lam_base * slope * pos
452
+
453
+ blk_mask = gumbel_topk(log_w, k_blk)
454
+ masked_indices[b, start:end] = blk_mask
455
+
456
+ if loss_mask is not None:
457
+ masked_indices[loss_mask == 0] = 0
458
+
459
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
460
+ return noisy_batch, masked_indices, p_mask
461
+
462
+
463
+ def forward(
464
+ self,
465
+ input_ids: torch.LongTensor,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ labels: Optional[torch.LongTensor] = None,
469
+ split_len: Optional[int] = None,
470
+ past_key_values: Optional[Cache] = None,
471
+ block_size: Optional[int] = None,
472
+ block_diff_ppl: bool = False,
473
+ eps: float = 1e-3,
474
+ is_teacher: bool = False,
475
+ masked_indices: Optional[torch.Tensor] = None,
476
+ p_mask: Optional[torch.Tensor] = None,
477
+ teacher_logits: Optional[torch.Tensor] = None,
478
+ masked_indices_teacher: Optional[torch.Tensor] = None,
479
+ loss_mask: Optional[torch.Tensor] = None,
480
+ ce_loss_weight: float = 1.0,
481
+ output_last_hidden_states_only: bool = False,
482
+ **kwargs,
483
+ ) -> CausalLMOutputWithPast:
484
+
485
+ batch_size, seq_len = input_ids.shape
486
+
487
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
488
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
489
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
490
+ input_ids = input_ids[:, :random_length]
491
+ labels = labels[:, :random_length]
492
+
493
+ if attention_mask is not None:
494
+ attention_mask = attention_mask[:, :random_length]
495
+ if position_ids is not None:
496
+ position_ids = position_ids[:, :random_length]
497
+ if loss_mask is not None:
498
+ loss_mask = loss_mask[:, :random_length]
499
+
500
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
501
+ if labels is not None and block_size is None:
502
+ if torch.rand(1) < self.config.random_length_prob:
503
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
504
+ else:
505
+ block_size = self.config.block_size
506
+
507
+ else:
508
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
509
+
510
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
511
+ if masked_indices is not None:
512
+ assert p_mask is not None
513
+
514
+ if loss_mask is not None:
515
+ masked_indices[loss_mask == 0] = 0
516
+
517
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
518
+
519
+ else:
520
+ if self.config.tok_mask_half_life_ratio is not None:
521
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
522
+ else:
523
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
524
+
525
+ else:
526
+ noisy_inputs = input_ids
527
+ masked_indices = None
528
+ p_mask = None
529
+
530
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
531
+ for layer in self.encoder.layers:
532
+ if hasattr(layer.self_attn, 'set_attention_mode'):
533
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
534
+
535
+ input_ids_len = noisy_inputs.shape[1]
536
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
537
+ if position_ids is None:
538
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
539
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
540
+
541
+ if block_diff_ppl:
542
+ if position_ids is None:
543
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
544
+
545
+ enc_out = self.encoder(
546
+ past_key_values=past_key_values,
547
+ input_ids=noisy_inputs,
548
+ attention_mask=attention_mask,
549
+ position_ids=position_ids,
550
+ is_training=(labels is not None) or (block_diff_ppl),
551
+ **kwargs,
552
+ )
553
+
554
+ if output_last_hidden_states_only:
555
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
556
+
557
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
558
+
559
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
560
+ if self.config.dlm_paradigm == 'sbd_block_diff':
561
+ causal_logits = logits[:, input_ids_len:]
562
+ else:
563
+ causal_logits = None
564
+
565
+ logits = logits[:, :input_ids_len]
566
+
567
+ loss = None
568
+ if labels is not None:
569
+ if self.config.dlm_paradigm == 'autoregressive':
570
+ shift_logits = logits[..., :-1, :].contiguous()
571
+ shift_labels = labels[..., 1:].contiguous()
572
+
573
+ if loss_mask is None:
574
+ loss_fct = CrossEntropyLoss()
575
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
576
+ shift_labels = shift_labels.view(-1)
577
+ loss = loss_fct(shift_logits, shift_labels)
578
+
579
+ else:
580
+ loss_mask = loss_mask[..., 1:].contiguous()
581
+
582
+ loss_fct = CrossEntropyLoss(reduction='none')
583
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
584
+ shift_labels = shift_labels.view(-1)
585
+ shift_labels = shift_labels.to(shift_logits.device)
586
+
587
+ token_losses = loss_fct(shift_logits, shift_labels)
588
+
589
+ flat_loss_mask = loss_mask.reshape(-1)
590
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
591
+
592
+ else:
593
+ # Handle DREAM vs LLADA style losses
594
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
595
+ logits = logits[..., :-1, :].contiguous()
596
+ labels = labels[..., 1:].contiguous()
597
+ masked_indices = masked_indices[:, 1:]
598
+ p_mask = p_mask[:, 1:]
599
+
600
+ if self.config.ada_perm_ratio_per_block is not None:
601
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
602
+ block_size = self.config.block_size
603
+ batch_size, seq_len = masked_indices.shape
604
+ num_blocks = seq_len // block_size
605
+
606
+ # Get the max logit (confidence) for each position
607
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
608
+
609
+ # Create a mask for tokens to include in loss
610
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
611
+
612
+ for blk in range(num_blocks):
613
+ start = blk * block_size
614
+ end = min((blk + 1) * block_size, seq_len)
615
+
616
+ # Get masked indices within this block
617
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
618
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
619
+
620
+ for b in range(batch_size):
621
+ # Get positions that are masked in this block for this batch
622
+ masked_positions = torch.where(block_masked[b])[0]
623
+ num_masked = len(masked_positions)
624
+
625
+ if num_masked > 0:
626
+ # Number of tokens to keep (top by confidence)
627
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
628
+
629
+ # Get confidence values for masked positions
630
+ masked_confidence = block_confidence[b, masked_positions]
631
+
632
+ # Get indices of top-k confident tokens
633
+ _, topk_indices = torch.topk(masked_confidence, k)
634
+ selected_positions = masked_positions[topk_indices]
635
+
636
+ # Mark these positions in the selected mask
637
+ selected_mask[b, start + selected_positions] = True
638
+
639
+ # Calculate loss only for selected positions
640
+ token_loss = torch.nn.functional.cross_entropy(
641
+ logits[selected_mask],
642
+ labels[selected_mask],
643
+ reduction='none'
644
+ ) / p_mask[selected_mask]
645
+
646
+ num_mask_tokens = selected_mask.sum()
647
+
648
+ else:
649
+ # Calculate token-wise cross entropy loss for masked positions in B
650
+ token_loss = torch.nn.functional.cross_entropy(
651
+ logits[masked_indices],
652
+ labels[masked_indices],
653
+ reduction='none'
654
+ ) / p_mask[masked_indices]
655
+
656
+ num_mask_tokens = masked_indices.sum()
657
+
658
+ if self.config.global_loss_avg:
659
+ loss = token_loss.sum()
660
+ else:
661
+ loss = token_loss.sum() / num_mask_tokens
662
+
663
+ if self.config.ada_dlm_loss_ratio is not None:
664
+ assert self.current_iter_ratio is not None
665
+ assert self.config.dlm_loss_weight is not None
666
+
667
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
668
+ loss = dlm_loss_weight * loss
669
+
670
+ elif self.config.dlm_loss_weight is not None:
671
+ loss = self.config.dlm_loss_weight * loss
672
+
673
+ if self.config.dlm_paradigm == 'sbd_block_diff':
674
+ causal_logits = causal_logits[..., :-1, :].contiguous()
675
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
676
+
677
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
678
+ causal_labels = labels.view(-1)
679
+ else:
680
+ causal_labels = labels[..., 1:].contiguous().view(-1)
681
+
682
+ if self.config.global_loss_avg:
683
+ loss_fct = CrossEntropyLoss(reduction='sum')
684
+ ar_loss = loss_fct(causal_logits, causal_labels)
685
+
686
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
687
+ self.loss_ar = ar_loss.detach().item() / seq_len
688
+
689
+ loss = loss + self.config.ar_loss_weight * ar_loss
690
+ else:
691
+ loss_fct = CrossEntropyLoss()
692
+ ar_loss = loss_fct(causal_logits, causal_labels)
693
+
694
+ self.loss_diffusion = loss.detach().item()
695
+ self.loss_ar = ar_loss.detach().item()
696
+
697
+ loss = loss + self.config.ar_loss_weight * ar_loss
698
+
699
+ if self.config.global_loss_avg:
700
+ if self.config.dlm_paradigm == 'sbd_block_diff':
701
+ loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
702
+ else:
703
+ loss = (loss, num_mask_tokens)
704
+
705
+ return CausalLMOutputWithPast(
706
+ loss=loss if not is_teacher else logits,
707
+ logits=logits,
708
+ past_key_values=enc_out.past_key_values,
709
+ hidden_states=None,
710
+ attentions=None,
711
+ )
712
+
713
+
714
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=False, temperature=0):
715
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
716
+ model=self,
717
+ prompt=prompt_ids,
718
+ gen_length=max_new_tokens,
719
+ steps=steps,
720
+ block_length=block_length,
721
+ remasking="low_confidence",
722
+ temperature=temperature,
723
+ mask_id=self.mask_token_id,
724
+ threshold=threshold,
725
+ shift_logits=shift_logits,
726
+ neg_entropy=False,
727
+ causal_context=causal_context,
728
+ )
729
+
730
+ return out_ids, nfe
731
+
732
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
733
+