YongganFu commited on
Commit
21e8ae0
·
verified ·
1 Parent(s): 06bb923

Upload model

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
chat_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def add_gumbel_noise(logits, temperature):
7
+ '''
8
+ The Gumbel max is a method for sampling categorical distributions.
9
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
10
+ Thus, we use float64.
11
+ '''
12
+ if temperature == 0:
13
+ return logits
14
+ logits = logits.to(torch.float64)
15
+ noise = torch.rand_like(logits, dtype=torch.float64)
16
+ gumbel_noise = (- torch.log(noise)) ** temperature
17
+ return logits.exp() / gumbel_noise
18
+
19
+
20
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
21
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
22
+ x0 = torch.argmax(logits_with_noise, dim=-1)
23
+
24
+ if remasking == 'low_confidence':
25
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
26
+ p = F.softmax(logits, dim=-1)
27
+ x0_p = torch.squeeze(
28
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
29
+ elif remasking == 'top_p_margin':
30
+ # Compute probabilities
31
+ p = F.softmax(logits, dim=-1) # (B, L, V)
32
+ # Top-2 per position
33
+ top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
34
+ margin = top2[..., 0] - top2[..., 1] # (B, L)
35
+
36
+ # Normalize margin to [0,1] over MASKED positions per row
37
+ plus_inf = torch.full_like(margin, float('inf'))
38
+ minus_inf = torch.full_like(margin, float('-inf'))
39
+ masked_for_min = torch.where(mask_index, margin, plus_inf)
40
+ masked_for_max = torch.where(mask_index, margin, minus_inf)
41
+ row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
42
+ row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
43
+ denom = (row_max - row_min)
44
+
45
+ # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
46
+ normalized = torch.zeros_like(margin)
47
+ nonzero = denom > 0
48
+ normalized = torch.where(
49
+ mask_index & nonzero,
50
+ (margin - row_min) / (denom + 1e-12),
51
+ normalized
52
+ )
53
+ normalized = torch.where(
54
+ mask_index & (~nonzero),
55
+ torch.ones_like(normalized),
56
+ normalized
57
+ )
58
+ x0_p = normalized # ∈ [0,1] on masked positions
59
+ elif remasking == 'random':
60
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
61
+ else:
62
+ raise NotImplementedError(remasking)
63
+
64
+ # Calculate negative entropy if requested
65
+ if neg_entropy:
66
+ # p = F.softmax(logits.to(torch.float64), dim=-1)
67
+ p = F.softmax(logits, dim=-1)
68
+ epsilon = 1e-10
69
+ log_probs = torch.log(p + epsilon)
70
+ confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
71
+ else:
72
+ confidence_scores = x0_p
73
+
74
+ x0 = torch.where(mask_index, x0, x)
75
+ confidence = torch.where(mask_index, confidence_scores, -np.inf)
76
+
77
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
78
+ if threshold is not None:
79
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
80
+ # print(f'confidence: {confidence}')
81
+ for j in range(confidence.shape[0]):
82
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
83
+ transfer_index[j, select_index] = True
84
+ if threshold is not None:
85
+ for k in range(1, num_transfer_tokens[j]):
86
+ if confidence[j, select_index[k]] < threshold:
87
+ transfer_index[j, select_index[k]] = False
88
+ return x0, transfer_index
89
+
90
+
91
+ def get_num_transfer_tokens(mask_index, steps: int):
92
+ mask_num = mask_index.sum(dim=1, keepdim=True)
93
+ base = mask_num // steps
94
+ remainder = mask_num % steps
95
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
96
+ for i in range(mask_num.size(0)):
97
+ num_transfer_tokens[i, : int(remainder[i])] += 1
98
+ return num_transfer_tokens
99
+
100
+
101
+ @torch.no_grad()
102
+ def generate_with_prefix_cache_block_diff(
103
+ model,
104
+ prompt,
105
+ steps=128,
106
+ gen_length=128,
107
+ block_length=128,
108
+ temperature=0.,
109
+ remasking='low_confidence',
110
+ mask_id=126336,
111
+ threshold=None,
112
+ factor=None,
113
+ shift_logits=False,
114
+ neg_entropy=False,
115
+ causal_context=False,
116
+ ):
117
+ dream_style=shift_logits
118
+ # Initialize the accumulator
119
+ x_accum = prompt.clone()
120
+
121
+ assert gen_length % block_length == 0
122
+ num_blocks = gen_length // block_length
123
+
124
+ assert steps % num_blocks == 0
125
+ steps_per_block = steps // num_blocks
126
+
127
+ nfe = 0
128
+
129
+ if causal_context:
130
+ model_module = model.module if hasattr(model, "module") else model
131
+ for layer in model_module.encoder.layers:
132
+ if hasattr(layer.self_attn, 'diffusion_lm'):
133
+ layer.self_attn.diffusion_lm=False
134
+
135
+ # Compute KV cache for the prompt initially
136
+ output = model(prompt, use_cache=True, use_causal_mask=causal_context)
137
+ past_key_values = output.past_key_values
138
+
139
+ if causal_context:
140
+ for layer in model_module.encoder.layers:
141
+ if hasattr(layer.self_attn, 'diffusion_lm'):
142
+ layer.self_attn.diffusion_lm=True
143
+
144
+ # For dream_style: store the "next token logit" of the context
145
+ next_logits_context = None
146
+ if dream_style:
147
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
148
+
149
+ for num_block in range(num_blocks):
150
+ # Create a new block with mask tokens (no seeding)
151
+ mask_block = torch.ones(
152
+ (prompt.shape[0], block_length),
153
+ dtype=prompt.dtype,
154
+ device=prompt.device
155
+ ) * mask_id
156
+
157
+ # Append the block of masks
158
+ x_accum = torch.cat([x_accum, mask_block], dim=1)
159
+ current_block_start = prompt.size(1) + num_block * block_length
160
+ block_slice = slice(current_block_start, current_block_start + block_length)
161
+
162
+ # Build the initial mask for this block
163
+ mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
164
+
165
+ # Precompute the transfer schedule for this block
166
+ if dream_style:
167
+ # still denoise *all* positions (0..Lb-1), since none are seeded
168
+ schedule_mask = mask_block_idx0
169
+ else:
170
+ schedule_mask = mask_block_idx0
171
+
172
+ num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
173
+
174
+ # Denoise the current block
175
+ for i in range(steps_per_block):
176
+ mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
177
+ if mask_block_idx.sum() == 0:
178
+ break
179
+
180
+ nfe += 1
181
+
182
+ # Forward only the current noisy block using cached context
183
+ logits_block = model(
184
+ x_accum[:, block_slice],
185
+ past_key_values=past_key_values,
186
+ use_cache=False
187
+ ).logits
188
+
189
+ if dream_style:
190
+ # Align logits so that each masked position has a predictor:
191
+ # prepend context-next logit, then use logits_block[:-1]
192
+ if block_length == 1:
193
+ logits_use = next_logits_context # (B, 1, V)
194
+ else:
195
+ logits_use = torch.cat(
196
+ [next_logits_context, logits_block[:, :-1, :]],
197
+ dim=1
198
+ ) # (B, Lb, V)
199
+
200
+ mask_use = mask_block_idx # (B, Lb)
201
+ x_use = x_accum[:, block_slice] # (B, Lb)
202
+
203
+ x0, transfer_idx = get_transfer_index(
204
+ logits_use, temperature, remasking, mask_use, x_use,
205
+ num_transfer_tokens=num_transfer_tokens[:, i],
206
+ threshold=threshold, neg_entropy=neg_entropy
207
+ )
208
+ cur = x_accum[:, block_slice].clone()
209
+ cur[transfer_idx] = x0[transfer_idx]
210
+ x_accum[:, block_slice] = cur
211
+
212
+ else:
213
+ # non-AR (same-position) case
214
+ x0, transfer_idx = get_transfer_index(
215
+ logits_block, temperature, remasking, mask_block_idx,
216
+ x_accum[:, block_slice],
217
+ num_transfer_tokens=num_transfer_tokens[:, i],
218
+ threshold=threshold, neg_entropy=neg_entropy
219
+ )
220
+ cur = x_accum[:, block_slice].clone()
221
+ cur[transfer_idx] = x0[transfer_idx]
222
+ x_accum[:, block_slice] = cur
223
+
224
+ if causal_context:
225
+ for layer in model_module.encoder.layers:
226
+ if hasattr(layer.self_attn, 'diffusion_lm'):
227
+ layer.self_attn.diffusion_lm=False
228
+
229
+ # after block is fully denoised, update KV cache
230
+ output = model(
231
+ x_accum[:, block_slice],
232
+ past_key_values=past_key_values,
233
+ use_cache=True,
234
+ use_causal_mask=causal_context
235
+ )
236
+ past_key_values = output.past_key_values
237
+
238
+ if causal_context:
239
+ for layer in model_module.encoder.layers:
240
+ if hasattr(layer.self_attn, 'diffusion_lm'):
241
+ layer.self_attn.diffusion_lm=True
242
+
243
+ if dream_style and num_block < num_blocks - 1:
244
+ # refresh context-next logit for the next block
245
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
246
+
247
+ return x_accum, nfe
config.json ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ada_dlm_loss_ratio": null,
3
+ "ada_perm_ratio_global": null,
4
+ "ada_perm_ratio_per_block": null,
5
+ "adaptive_mask_rate": false,
6
+ "ar_loss_weight": 1.0,
7
+ "architectures": [
8
+ "MinistralDiffEncoderModel"
9
+ ],
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "attn_implementation": "sdpa",
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
15
+ "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
16
+ },
17
+ "block_size": 32,
18
+ "bos_token_id": 1,
19
+ "diff_loss_weight": 1,
20
+ "dlm_arch": "encoder",
21
+ "dlm_loss_weight": null,
22
+ "dlm_paradigm": "bidirectional",
23
+ "dlm_type": "llada",
24
+ "dp_varying_mask_ratio": false,
25
+ "enforce_mask": false,
26
+ "eos_token_id": 2,
27
+ "global_loss_avg": false,
28
+ "head_dim": 128,
29
+ "hidden_act": "silu",
30
+ "hidden_size": 3072,
31
+ "initializer_range": 0.02,
32
+ "intermediate_size": 9216,
33
+ "mask_token_id": 100,
34
+ "max_position_embeddings": 262144,
35
+ "mlp_bias": false,
36
+ "model_type": "ministral_dlm",
37
+ "multi_sampling": null,
38
+ "num_ar_layers": 0,
39
+ "num_attention_heads": 32,
40
+ "num_diffusion_layers": 0,
41
+ "num_hidden_layers": 26,
42
+ "num_key_value_heads": 8,
43
+ "num_skip_loss_tokens": 0,
44
+ "prefix_ratio": 0.8,
45
+ "random_length_prob": 0,
46
+ "rms_norm_eps": 1e-05,
47
+ "rope_parameters": {
48
+ "beta_fast": 32.0,
49
+ "beta_slow": 1.0,
50
+ "factor": 16.0,
51
+ "llama_4_scaling_beta": 0.1,
52
+ "mscale": 1.0,
53
+ "mscale_all_dim": 1.0,
54
+ "original_max_position_embeddings": 16384,
55
+ "rope_theta": 1000000.0,
56
+ "rope_type": "yarn",
57
+ "type": "yarn"
58
+ },
59
+ "rope_scaling": {
60
+ "beta_fast": 32.0,
61
+ "beta_slow": 1.0,
62
+ "factor": 16.0,
63
+ "llama_4_scaling_beta": 0.1,
64
+ "mscale": 1.0,
65
+ "mscale_all_dim": 1.0,
66
+ "original_max_position_embeddings": 16384,
67
+ "rope_theta": 1000000.0,
68
+ "rope_type": "yarn",
69
+ "type": "yarn"
70
+ },
71
+ "rope_theta": 1000000.0,
72
+ "sliding_window": null,
73
+ "tie_word_embeddings": false,
74
+ "tok_mask_half_life_ratio": null,
75
+ "torch_dtype": "bfloat16",
76
+ "transformers_version": "4.55.4",
77
+ "use_cache": false,
78
+ "vocab_size": 131072
79
+ }
configuration_ministral_dlm.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Ministral DLM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class MinistralDLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
+ It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 131072):
35
+ Vocabulary size of the Ministral model.
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 14336):
39
+ Dimension of the MLP representations.
40
+ num_hidden_layers (`int`, *optional*, defaults to 34):
41
+ Number of hidden layers in the Transformer decoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 32):
43
+ Number of attention heads for each attention layer.
44
+ num_key_value_heads (`int`, *optional*, defaults to 8):
45
+ Number of key_value heads for Grouped Query Attention.
46
+ head_dim (`int`, *optional*, defaults to 128):
47
+ The attention head dimension.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function.
50
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
51
+ The maximum sequence length.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions.
58
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
59
+ Whether the model's input and output word embeddings should be tied.
60
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
61
+ The base period of the RoPE embeddings.
62
+ rope_parameters (`Dict`, *optional*):
63
+ Dictionary containing the scaling configuration for the RoPE embeddings.
64
+ Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
65
+ attention_bias (`bool`, defaults to `False`):
66
+ Whether to use a bias in the query, key, value and output projection layers.
67
+ attention_dropout (`float`, *optional*, defaults to 0.0):
68
+ The dropout ratio for the attention probabilities.
69
+ mlp_bias (`bool`, *optional*, defaults to `False`):
70
+ Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
+ sliding_window (`int`, *optional*, defaults to None):
72
+ Sliding window attention size.
73
+ mask_token_id (`int`, *optional*, defaults to -1):
74
+ Token ID for masking in diffusion.
75
+ dlm_type (`str`, *optional*, defaults to 'llada'):
76
+ Type of diffusion language model ('llada', 'dream').
77
+ random_length_prob (`float`, *optional*):
78
+ Probability of using random lengths during training.
79
+ num_ar_layers (`int`, *optional*, defaults to 0):
80
+ Number of autoregressive layers.
81
+ num_diffusion_layers (`int`, *optional*, defaults to 0):
82
+ Number of diffusion layers.
83
+ diff_loss_weight (`float`, *optional*, defaults to 1):
84
+ Weight for diffusion loss.
85
+ enforce_mask (`bool`, *optional*, defaults to False):
86
+ Whether to enforce masking.
87
+ prefix_ratio (`float`, *optional*, defaults to 0.8):
88
+ Ratio for prefix in prefix_bidirectional mode.
89
+ dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
90
+ Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
91
+ dlm_arch (`str`, *optional*, defaults to 'encoder'):
92
+ Architecture type ('encoder', 'encoder_decoder').
93
+ block_size (`int`, *optional*, defaults to 32):
94
+ Block size for block diffusion paradigms.
95
+ tok_mask_half_life_ratio (`float`, *optional*):
96
+ Half-life ratio for token masking.
97
+ adaptive_mask_rate (`bool`, *optional*, defaults to False):
98
+ Whether to use adaptive mask rate.
99
+ multi_sampling (`int`, *optional*):
100
+ Number of samples for multi-sampling.
101
+ num_skip_loss_tokens (`int`, *optional*, defaults to 0):
102
+ Number of tokens to skip in loss calculation.
103
+ dlm_loss_weight (`float`, *optional*):
104
+ Weight for diffusion LM loss.
105
+ ar_loss_weight (`float`, *optional*, defaults to 1.0):
106
+ Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
107
+ global_loss_avg (`bool`, *optional*, defaults to False):
108
+ Whether to use global loss average.
109
+ dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
110
+ Whether to use varying mask ratio for each DP rank during sampling.
111
+ ada_perm_ratio_per_block (`float`, *optional*):
112
+ Adaptive permutation ratio for each block.
113
+ ada_perm_ratio_global (`float`, *optional*):
114
+ Adaptive permutation ratio for global.
115
+ """
116
+
117
+ model_type = "ministral_dlm"
118
+ keys_to_ignore_at_inference = ["past_key_values"]
119
+
120
+ # Default tensor parallel plan for base model `Ministral`
121
+ base_model_tp_plan = {
122
+ "layers.*.self_attn.q_proj": "colwise",
123
+ "layers.*.self_attn.k_proj": "colwise",
124
+ "layers.*.self_attn.v_proj": "colwise",
125
+ "layers.*.self_attn.o_proj": "rowwise",
126
+ "layers.*.mlp.gate_proj": "colwise",
127
+ "layers.*.mlp.up_proj": "colwise",
128
+ "layers.*.mlp.down_proj": "rowwise",
129
+ }
130
+ base_model_pp_plan = {
131
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
132
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
133
+ "norm": (["hidden_states"], ["hidden_states"]),
134
+ }
135
+
136
+ def __init__(
137
+ self,
138
+ vocab_size=131072,
139
+ hidden_size=4096,
140
+ intermediate_size=14336,
141
+ num_hidden_layers=34,
142
+ num_attention_heads=32,
143
+ num_key_value_heads=8,
144
+ head_dim=128,
145
+ hidden_act="silu",
146
+ max_position_embeddings=262144,
147
+ initializer_range=0.02,
148
+ rms_norm_eps=1e-05,
149
+ use_cache=True,
150
+ pad_token_id=None,
151
+ bos_token_id=1,
152
+ eos_token_id=2,
153
+ tie_word_embeddings=False,
154
+ rope_theta=1000000.0,
155
+ rope_parameters=None,
156
+ rope_scaling=None,
157
+ attention_bias=False,
158
+ attention_dropout=0.0,
159
+ mlp_bias=False,
160
+ sliding_window=None,
161
+ attn_implementation="sdpa",
162
+ mask_token_id=-1,
163
+ dlm_type='llada',
164
+ random_length_prob=None,
165
+ num_ar_layers=0,
166
+ num_diffusion_layers=0,
167
+ diff_loss_weight=1,
168
+ enforce_mask=False,
169
+ prefix_ratio=0.8,
170
+ dlm_paradigm='bidirectional',
171
+ dlm_arch='encoder',
172
+ block_size=32,
173
+ tok_mask_half_life_ratio=None,
174
+ adaptive_mask_rate=False,
175
+ multi_sampling=None,
176
+ num_skip_loss_tokens=0,
177
+ dlm_loss_weight=None,
178
+ ar_loss_weight=1.0,
179
+ global_loss_avg=False,
180
+ dp_varying_mask_ratio=False,
181
+ ada_perm_ratio_per_block=None,
182
+ ada_perm_ratio_global=None,
183
+ ada_dlm_loss_ratio=None,
184
+ **kwargs,
185
+ ):
186
+ self.vocab_size = vocab_size
187
+ self.max_position_embeddings = max_position_embeddings
188
+ self.hidden_size = hidden_size
189
+ self.intermediate_size = intermediate_size
190
+ self.num_hidden_layers = num_hidden_layers
191
+ self.num_attention_heads = num_attention_heads
192
+
193
+ # for backward compatibility
194
+ if num_key_value_heads is None:
195
+ num_key_value_heads = num_attention_heads
196
+
197
+ self.num_key_value_heads = num_key_value_heads
198
+ self.head_dim = head_dim
199
+ self.hidden_act = hidden_act
200
+ self.initializer_range = initializer_range
201
+ self.rms_norm_eps = rms_norm_eps
202
+ self.use_cache = use_cache
203
+ self.rope_theta = rope_theta
204
+ self.rope_parameters = rope_parameters
205
+ self.rope_scaling = rope_scaling
206
+ self.attention_bias = attention_bias
207
+ self.attention_dropout = attention_dropout
208
+ self.mlp_bias = mlp_bias
209
+ self.sliding_window = sliding_window
210
+
211
+ rope_config_validation(self)
212
+
213
+ self.attn_implementation = attn_implementation
214
+
215
+ self.mask_token_id = mask_token_id
216
+ self.dlm_type = dlm_type
217
+ self.random_length_prob = random_length_prob
218
+ self.num_ar_layers = num_ar_layers
219
+ self.num_diffusion_layers = num_diffusion_layers
220
+ self.diff_loss_weight = diff_loss_weight
221
+ self.enforce_mask = enforce_mask
222
+ self.prefix_ratio = prefix_ratio
223
+ self.dlm_paradigm = dlm_paradigm
224
+ self.dlm_arch = dlm_arch
225
+ self.block_size = block_size
226
+ self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
227
+ self.adaptive_mask_rate = adaptive_mask_rate
228
+ self.multi_sampling = multi_sampling
229
+ self.num_skip_loss_tokens = num_skip_loss_tokens
230
+ self.dlm_loss_weight = dlm_loss_weight
231
+ self.ar_loss_weight = ar_loss_weight
232
+ self.global_loss_avg = global_loss_avg
233
+ self.dp_varying_mask_ratio = dp_varying_mask_ratio
234
+ self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
235
+ self.ada_perm_ratio_global = ada_perm_ratio_global
236
+ self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
237
+ super().__init__(
238
+ pad_token_id=pad_token_id,
239
+ bos_token_id=bos_token_id,
240
+ eos_token_id=eos_token_id,
241
+ tie_word_embeddings=tie_word_embeddings,
242
+ **kwargs,
243
+ )
244
+
245
+
246
+ __all__ = ["MinistralDLMConfig"]
247
+
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.55.4",
6
+ "use_cache": false
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf9ef7826f4d23f1c7029720ca07bbb56d36fcfdefe93eaf42ddc14bb61f3db8
3
+ size 7663347000
modeling_ministral.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from transformers.utils.generic import check_model_inputs
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
13
+ from transformers.integrations import use_kernel_forward_from_hub
14
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_layers import (
17
+ GenericForQuestionAnswering,
18
+ GenericForSequenceClassification,
19
+ GenericForTokenClassification,
20
+ GradientCheckpointingLayer,
21
+ )
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
+ from transformers.processing_utils import Unpack
26
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
+ # from transformers.utils.generic import maybe_autocast
28
+ from .configuration_ministral_dlm import MinistralDLMConfig
29
+
30
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
+
32
+ def rotate_half(x):
33
+ """Rotates half the hidden dims of the input."""
34
+ x1 = x[..., : x.shape[-1] // 2]
35
+ x2 = x[..., x.shape[-1] // 2 :]
36
+ return torch.cat((-x2, x1), dim=-1)
37
+
38
+ # @use_kernel_func_from_hub("rotary_pos_emb")
39
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
40
+ """Applies Rotary Position Embedding to the query and key tensors.
41
+
42
+ Args:
43
+ q (`torch.Tensor`): The query tensor.
44
+ k (`torch.Tensor`): The key tensor.
45
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
46
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
47
+ position_ids (`torch.Tensor`, *optional*):
48
+ Deprecated and unused.
49
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
50
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
51
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
52
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
53
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
54
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
55
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
56
+ Returns:
57
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
58
+ """
59
+ cos = cos.unsqueeze(unsqueeze_dim)
60
+ sin = sin.unsqueeze(unsqueeze_dim)
61
+ q_embed = (q * cos) + (rotate_half(q) * sin)
62
+ k_embed = (k * cos) + (rotate_half(k) * sin)
63
+ return q_embed, k_embed
64
+
65
+
66
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
67
+ """
68
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
69
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
70
+ """
71
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
72
+ if n_rep == 1:
73
+ return hidden_states
74
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
75
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
76
+
77
+
78
+ def eager_attention_forward(
79
+ module: nn.Module,
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor],
84
+ scaling: float,
85
+ dropout: float = 0.0,
86
+ **kwargs: Unpack[TransformersKwargs],
87
+ ):
88
+ key_states = repeat_kv(key, module.num_key_value_groups)
89
+ value_states = repeat_kv(value, module.num_key_value_groups)
90
+
91
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
92
+ if attention_mask is not None:
93
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
94
+ attn_weights = attn_weights + causal_mask
95
+
96
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
97
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
98
+ attn_output = torch.matmul(attn_weights, value_states)
99
+ attn_output = attn_output.transpose(1, 2).contiguous()
100
+
101
+ return attn_output, attn_weights
102
+
103
+
104
+ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
105
+ scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
106
+ return scaling.unsqueeze(-1)
107
+
108
+
109
+ # @use_kernelized_func(apply_rotary_pos_emb)
110
+ class Ministral3Attention(nn.Module):
111
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
112
+
113
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
114
+ super().__init__()
115
+ self.config = config
116
+ self.layer_idx = layer_idx
117
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
118
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
119
+ self.scaling = self.head_dim**-0.5
120
+ self.attention_dropout = config.attention_dropout
121
+ self.is_causal = True
122
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
123
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
124
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
125
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
126
+
127
+ self.diffusion_lm = config.diffusion_lm
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.Tensor,
132
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
133
+ attention_mask: Optional[torch.Tensor],
134
+ past_key_values: Optional[Cache] = None,
135
+ cache_position: Optional[torch.LongTensor] = None,
136
+ use_cache: Optional[bool] = False,
137
+ **kwargs: Unpack[FlashAttentionKwargs],
138
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
139
+ input_shape = hidden_states.shape[:-1]
140
+ hidden_shape = (*input_shape, -1, self.head_dim)
141
+
142
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
143
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
145
+
146
+ cos, sin = position_embeddings
147
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
148
+ query_states = query_states * _get_llama_4_attn_scale(
149
+ cache_position,
150
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
151
+ self.config.rope_parameters.get("original_max_position_embeddings"),
152
+ ).to(query_states.dtype)
153
+
154
+ if past_key_values is not None:
155
+ if use_cache:
156
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
157
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
158
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
159
+ else: ## if use_cache == False, do not update cache
160
+ old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
161
+ key_states = torch.cat([old_k, key_states], dim=-2)
162
+ value_states = torch.cat([old_v, value_states], dim=-2)
163
+
164
+ attention_interface: Callable = eager_attention_forward
165
+ if self.config._attn_implementation != "eager":
166
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
167
+
168
+ if self.diffusion_lm:
169
+ attn_output, attn_weights = attention_interface(
170
+ self,
171
+ query_states,
172
+ key_states,
173
+ value_states,
174
+ None,
175
+ dropout=0.0 if not self.training else self.attention_dropout,
176
+ scaling=self.scaling,
177
+ is_causal=False,
178
+ **kwargs,
179
+ )
180
+
181
+ else:
182
+ attn_output, attn_weights = attention_interface(
183
+ self,
184
+ query_states,
185
+ key_states,
186
+ value_states,
187
+ attention_mask,
188
+ dropout=0.0 if not self.training else self.attention_dropout,
189
+ scaling=self.scaling,
190
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
191
+ **kwargs,
192
+ )
193
+
194
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
195
+ attn_output = self.o_proj(attn_output)
196
+ return attn_output, attn_weights
197
+
198
+
199
+ class Ministral3MLP(nn.Module):
200
+ def __init__(self, config):
201
+ super().__init__()
202
+ self.config = config
203
+ self.hidden_size = config.hidden_size
204
+ self.intermediate_size = config.intermediate_size
205
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
206
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
207
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
208
+ self.act_fn = ACT2FN[config.hidden_act]
209
+
210
+ def forward(self, x):
211
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
212
+ return down_proj
213
+
214
+
215
+ @use_kernel_forward_from_hub("RMSNorm")
216
+ class Ministral3RMSNorm(nn.Module):
217
+ def __init__(self, hidden_size, eps=1e-6):
218
+ """
219
+ Ministral3RMSNorm is equivalent to T5LayerNorm
220
+ """
221
+ super().__init__()
222
+ self.weight = nn.Parameter(torch.ones(hidden_size))
223
+ self.variance_epsilon = eps
224
+
225
+ def forward(self, hidden_states):
226
+ input_dtype = hidden_states.dtype
227
+ hidden_states = hidden_states.to(torch.float32)
228
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
229
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
230
+ return self.weight * hidden_states.to(input_dtype)
231
+
232
+ def extra_repr(self):
233
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
234
+
235
+
236
+ class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: MinistralDLMConfig, layer_idx: int):
238
+ super().__init__()
239
+ self.hidden_size = config.hidden_size
240
+
241
+ if hasattr(config, 'attn_class'):
242
+ attn_class = config.attn_class
243
+ else:
244
+ attn_class = Ministral3Attention
245
+
246
+ self.self_attn = attn_class(config=config, layer_idx=layer_idx)
247
+ self.mlp = Ministral3MLP(config)
248
+ self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
+ self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ position_ids: Optional[torch.LongTensor] = None,
256
+ past_key_values: Optional[Cache] = None,
257
+ use_cache: Optional[bool] = False,
258
+ cache_position: Optional[torch.LongTensor] = None,
259
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
260
+ **kwargs: Unpack[TransformersKwargs],
261
+ ) -> torch.Tensor:
262
+ residual = hidden_states
263
+ hidden_states = self.input_layernorm(hidden_states)
264
+ # Self Attention
265
+ hidden_states, _ = self.self_attn(
266
+ hidden_states=hidden_states,
267
+ attention_mask=attention_mask,
268
+ position_ids=position_ids,
269
+ past_key_values=past_key_values,
270
+ use_cache=use_cache,
271
+ cache_position=cache_position,
272
+ position_embeddings=position_embeddings,
273
+ **kwargs,
274
+ )
275
+ hidden_states = residual + hidden_states
276
+
277
+ # Fully Connected
278
+ residual = hidden_states
279
+ hidden_states = self.post_attention_layernorm(hidden_states)
280
+ hidden_states = self.mlp(hidden_states)
281
+ hidden_states = residual + hidden_states
282
+ return hidden_states
283
+
284
+
285
+ @auto_docstring
286
+ class Ministral3PreTrainedModel(PreTrainedModel):
287
+ config: MinistralDLMConfig
288
+ base_model_prefix = "model"
289
+ supports_gradient_checkpointing = True
290
+ _no_split_modules = ["Ministral3DecoderLayer"]
291
+ _skip_keys_device_placement = ["past_key_values"]
292
+ _supports_flash_attn = True
293
+ _supports_sdpa = True
294
+ _supports_flex_attn = True
295
+
296
+ _can_compile_fullgraph = True
297
+ _supports_attention_backend = True
298
+ _can_record_outputs = {
299
+ "hidden_states": Ministral3DecoderLayer,
300
+ "attentions": Ministral3Attention,
301
+ }
302
+
303
+
304
+ class Ministral3RotaryEmbedding(nn.Module):
305
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
306
+
307
+ def __init__(self, config: MinistralDLMConfig, device=None):
308
+ super().__init__()
309
+ self.max_seq_len_cached = config.max_position_embeddings
310
+ self.original_max_seq_len = config.max_position_embeddings
311
+
312
+ self.config = config
313
+
314
+ self.rope_type = self.config.rope_parameters["rope_type"]
315
+ rope_init_fn: Callable = self.compute_default_rope_parameters
316
+ if self.rope_type != "default":
317
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
318
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
319
+
320
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
321
+ self.original_inv_freq = inv_freq
322
+
323
+
324
+ @staticmethod
325
+ def compute_default_rope_parameters(
326
+ config: Optional[MinistralDLMConfig] = None,
327
+ device: Optional["torch.device"] = None,
328
+ seq_len: Optional[int] = None,
329
+ ) -> tuple["torch.Tensor", float]:
330
+ """
331
+ Computes the inverse frequencies according to the original RoPE implementation
332
+ Args:
333
+ config ([`~transformers.PreTrainedConfig`]):
334
+ The model configuration.
335
+ device (`torch.device`):
336
+ The device to use for initialization of the inverse frequencies.
337
+ seq_len (`int`, *optional*):
338
+ The current sequence length. Unused for this type of RoPE.
339
+ Returns:
340
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
341
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
342
+ """
343
+ base = config.rope_parameters["rope_theta"]
344
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
345
+
346
+ attention_factor = 1.0 # Unused in this type of RoPE
347
+
348
+ # Compute the inverse frequencies
349
+ inv_freq = 1.0 / (
350
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
351
+ )
352
+ return inv_freq, attention_factor
353
+
354
+ @torch.no_grad()
355
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
356
+ def forward(self, x, position_ids):
357
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
358
+ position_ids_expanded = position_ids[:, None, :].float()
359
+
360
+ # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
361
+ # with maybe_autocast(device_type=device_type, enabled=False): # Force float32
362
+
363
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
364
+ emb = torch.cat((freqs, freqs), dim=-1)
365
+ cos = emb.cos() * self.attention_scaling
366
+ sin = emb.sin() * self.attention_scaling
367
+
368
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
369
+
370
+
371
+ @auto_docstring
372
+ class Ministral3Model(Ministral3PreTrainedModel):
373
+ def __init__(self, config: MinistralDLMConfig):
374
+ super().__init__(config)
375
+ self.padding_idx = config.pad_token_id
376
+ self.vocab_size = config.vocab_size
377
+
378
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
379
+ self.layers = nn.ModuleList(
380
+ [Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
381
+ )
382
+ self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
+ self.rotary_emb = Ministral3RotaryEmbedding(config=config)
384
+ self.gradient_checkpointing = False
385
+
386
+ # Initialize weights and apply final processing
387
+ self.post_init()
388
+
389
+ @check_model_inputs
390
+ @auto_docstring
391
+ def forward(
392
+ self,
393
+ input_ids: Optional[torch.LongTensor] = None,
394
+ attention_mask: Optional[torch.Tensor] = None,
395
+ position_ids: Optional[torch.LongTensor] = None,
396
+ past_key_values: Optional[Cache] = None,
397
+ inputs_embeds: Optional[torch.FloatTensor] = None,
398
+ use_cache: Optional[bool] = None,
399
+ cache_position: Optional[torch.LongTensor] = None,
400
+ **kwargs: Unpack[TransformersKwargs],
401
+ ) -> BaseModelOutputWithPast:
402
+ if (input_ids is None) ^ (inputs_embeds is not None):
403
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
404
+
405
+ if inputs_embeds is None:
406
+ inputs_embeds = self.embed_tokens(input_ids)
407
+
408
+ if use_cache and past_key_values is None:
409
+ # past_key_values = DynamicCache(config=self.config)
410
+ past_key_values = DynamicCache()
411
+
412
+ if cache_position is None:
413
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
414
+ cache_position = torch.arange(
415
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
416
+ )
417
+
418
+ if position_ids is None:
419
+ position_ids = cache_position.unsqueeze(0)
420
+
421
+ if kwargs.get("use_causal_mask", False):
422
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
423
+ causal_mask = mask_function(
424
+ config=self.config,
425
+ input_embeds=inputs_embeds,
426
+ attention_mask=attention_mask,
427
+ cache_position=cache_position,
428
+ past_key_values=past_key_values,
429
+ position_ids=position_ids,
430
+ )
431
+
432
+ else:
433
+ causal_mask = None
434
+
435
+ hidden_states = inputs_embeds
436
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
437
+
438
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
439
+ hidden_states = decoder_layer(
440
+ hidden_states,
441
+ attention_mask=causal_mask,
442
+ position_ids=position_ids,
443
+ past_key_values=past_key_values,
444
+ use_cache=use_cache,
445
+ cache_position=cache_position,
446
+ position_embeddings=position_embeddings,
447
+ **kwargs,
448
+ )
449
+ hidden_states = self.norm(hidden_states)
450
+ return BaseModelOutputWithPast(
451
+ last_hidden_state=hidden_states,
452
+ past_key_values=past_key_values if use_cache else None,
453
+ )
454
+
455
+
456
+ @auto_docstring
457
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
458
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
459
+ _tp_plan = {"lm_head": "colwise_rep"}
460
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
461
+
462
+ def __init__(self, config):
463
+ super().__init__(config)
464
+ self.model = Ministral3Model(config)
465
+ self.vocab_size = config.vocab_size
466
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
467
+
468
+ # Initialize weights and apply final processing
469
+ self.post_init()
470
+
471
+ @can_return_tuple
472
+ @auto_docstring
473
+ def forward(
474
+ self,
475
+ input_ids: Optional[torch.LongTensor] = None,
476
+ attention_mask: Optional[torch.Tensor] = None,
477
+ position_ids: Optional[torch.LongTensor] = None,
478
+ past_key_values: Optional[Cache] = None,
479
+ inputs_embeds: Optional[torch.FloatTensor] = None,
480
+ labels: Optional[torch.LongTensor] = None,
481
+ use_cache: Optional[bool] = None,
482
+ cache_position: Optional[torch.LongTensor] = None,
483
+ logits_to_keep: Union[int, torch.Tensor] = 0,
484
+ **kwargs: Unpack[TransformersKwargs],
485
+ ) -> CausalLMOutputWithPast:
486
+ r"""
487
+ Example:
488
+
489
+ ```python
490
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
491
+
492
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
493
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
494
+
495
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
496
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
497
+
498
+ >>> # Generate
499
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
500
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
501
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
502
+ ```"""
503
+ outputs: BaseModelOutputWithPast = self.model(
504
+ input_ids=input_ids,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_values=past_key_values,
508
+ inputs_embeds=inputs_embeds,
509
+ use_cache=use_cache,
510
+ cache_position=cache_position,
511
+ **kwargs,
512
+ )
513
+
514
+ hidden_states = outputs.last_hidden_state
515
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
516
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
517
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
518
+
519
+ loss = None
520
+ if labels is not None:
521
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
522
+
523
+ return CausalLMOutputWithPast(
524
+ loss=loss,
525
+ logits=logits,
526
+ past_key_values=outputs.past_key_values,
527
+ hidden_states=outputs.hidden_states,
528
+ attentions=outputs.attentions,
529
+ )
530
+
531
+
532
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
533
+ pass
534
+
535
+
536
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
537
+ pass
538
+
539
+
540
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
541
+ pass
542
+
543
+
544
+ __all__ = [
545
+ "Ministral3ForCausalLM",
546
+ "Ministral3ForQuestionAnswering",
547
+ "Ministral3Model",
548
+ "Ministral3PreTrainedModel",
549
+ "Ministral3ForSequenceClassification",
550
+ "Ministral3ForTokenClassification",
551
+ ]
modeling_ministral_dlm.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional, Tuple, Union
4
+ import random
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
+
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.generation import GenerationMixin
27
+
28
+ import math
29
+
30
+ from .chat_utils import generate_with_prefix_cache_block_diff
31
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
+ from .configuration_ministral_dlm import MinistralDLMConfig
33
+
34
+
35
+ @dataclass
36
+ class MinistralDiffOutputWithPast(ModelOutput):
37
+ loss: torch.FloatTensor | None = None
38
+ logits: torch.FloatTensor | None = None
39
+ causal_logits: torch.FloatTensor | None = None
40
+ past_key_values: Cache | None = None
41
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
42
+ attentions: tuple[torch.FloatTensor, ...] | None = None
43
+
44
+
45
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
46
+ # @torch.compile(mode="default")
47
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
48
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
49
+ def fused_flex_attention(q, k, v, block_mask=None):
50
+ return flex_attention(q, k, v, block_mask=block_mask)
51
+
52
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
53
+ class MinistralFlexAttention(Ministral3Attention):
54
+ def __init__(self, *args, **kwargs):
55
+ super().__init__(*args, **kwargs)
56
+
57
+ self.block_size_orig = self.config.block_size
58
+
59
+ if self.config.dlm_paradigm == 'bidirectional':
60
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
61
+ elif self.config.dlm_paradigm == 'autoregressive':
62
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
63
+ elif self.config.dlm_paradigm == 'block_diff':
64
+ self.block_diff_mask = None
65
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
66
+ self.sbd_block_diff_mask = None
67
+ else:
68
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
69
+
70
+ self.block_size = self.block_size_orig
71
+ self.mode = self.config.dlm_paradigm
72
+
73
+ import torch._dynamo.config as dcfg
74
+ dcfg.cache_size_limit = 512
75
+
76
+
77
+ def set_attention_mode(self, mode, block_size=None):
78
+ self.mode = mode
79
+ self.block_size = block_size
80
+
81
+ def compute_block_mask(self, mode, q_len, block_size=None):
82
+
83
+ def bidirectional_mask(b, h, q, kv):
84
+ return (q >= kv) | (q < kv)
85
+
86
+ def autoregressive_mask(b, h, q, kv):
87
+ return (q >= kv)
88
+
89
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
90
+ """
91
+ Constructs the specialized block diffusion attention mask for training
92
+ composed of three masks:
93
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
94
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
95
+ - **Block Causal Mask (M_BC)**: Attention to update x0
96
+ Args:
97
+ b, h: Batch and head indices (ignored for mask logic).
98
+ q_idx, kv_idx: Query and Key indices.
99
+ seq_len: Total sequence length.
100
+ block_size: Defines the block structure.
101
+ Returns:
102
+ A boolean attention mask.
103
+ """
104
+
105
+ # Indicate whether token belongs to xt or x0
106
+ x0_flag_q = (q_idx >= n)
107
+ x0_flag_kv = (kv_idx >= n)
108
+
109
+ # Compute block indices
110
+ block_q = torch.where(x0_flag_q == 1,
111
+ (q_idx - n) // block_size,
112
+ q_idx // block_size)
113
+ block_kv = torch.where(x0_flag_kv == 1,
114
+ (kv_idx - n) // block_size,
115
+ kv_idx // block_size)
116
+
117
+ # **1. Block Diagonal Mask (M_BD) **
118
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
119
+
120
+ # **2. Offset Block-Causal Mask (M_OBC) **
121
+ offset_block_causal = (
122
+ (block_q > block_kv)
123
+ & (x0_flag_kv == 1)
124
+ & (x0_flag_q == 0)
125
+ )
126
+
127
+ # **3. Block-Causal Mask (M_BC) **
128
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
129
+
130
+ # **4. Combine Masks **
131
+ return block_diagonal | offset_block_causal | block_causal
132
+
133
+
134
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
135
+ x0_flag_q = (q_idx >= n)
136
+ x0_flag_kv = (kv_idx >= n)
137
+
138
+ # Compute block indices
139
+ block_q = torch.where(x0_flag_q == 1,
140
+ (q_idx - n) // block_size,
141
+ q_idx // block_size)
142
+ block_kv = torch.where(x0_flag_kv == 1,
143
+ (kv_idx - n) // block_size,
144
+ kv_idx // block_size)
145
+
146
+ # **1. Block Diagonal Mask (M_BD) **
147
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
148
+
149
+ # **2. Offset Block-Causal Mask (M_OBC) **
150
+ offset_block_causal = (
151
+ (block_q > block_kv)
152
+ & (x0_flag_kv == 1)
153
+ & (x0_flag_q == 0)
154
+ )
155
+
156
+ # **3. Fully Causal Mask (M_BC) **
157
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
158
+
159
+ # **4. Combine Masks **
160
+ return block_diagonal | offset_block_causal | fully_causal
161
+
162
+ if mode == 'bidirectional':
163
+ attn_mask = bidirectional_mask
164
+ elif mode == 'autoregressive':
165
+ attn_mask = autoregressive_mask
166
+ elif mode == 'block_diff':
167
+ assert block_size is not None
168
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, q_len//2)
169
+ elif mode == 'sbd_block_diff':
170
+ assert block_size is not None
171
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, q_len//2)
172
+ else:
173
+ raise ValueError(f"Unknown attention mode: {mode}")
174
+
175
+ block_mask = create_block_mask(
176
+ attn_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len
177
+ )
178
+
179
+ return block_mask
180
+
181
+
182
+ def forward(
183
+ self,
184
+ hidden_states: torch.Tensor,
185
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
186
+ attention_mask: Optional[torch.Tensor],
187
+ past_key_values: Optional[Cache] = None,
188
+ cache_position: Optional[torch.LongTensor] = None,
189
+ is_training: bool = True,
190
+ **kwargs: Unpack[FlashAttentionKwargs],
191
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
192
+ bsz, q_len, _ = hidden_states.size()
193
+ input_shape = hidden_states.shape[:-1]
194
+ hidden_shape = (*input_shape, -1, self.head_dim)
195
+
196
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
197
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
198
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
199
+
200
+ cos, sin = position_embeddings
201
+
202
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
203
+ # Split query and key states in half along sequence length dimension
204
+ q1, q2 = query_states.chunk(2, dim=2)
205
+ k1, k2 = key_states.chunk(2, dim=2)
206
+
207
+ # Apply RoPE independently to each half
208
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
209
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
210
+
211
+ # Recombine the halves
212
+ query_states = torch.cat([q1, q2], dim=2)
213
+ key_states = torch.cat([k1, k2], dim=2)
214
+ else:
215
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
216
+
217
+ query_states = query_states * _get_llama_4_attn_scale(
218
+ cache_position,
219
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
220
+ self.config.rope_parameters.get("original_max_position_embeddings"),
221
+ ).to(query_states.dtype)
222
+
223
+ if past_key_values is not None:
224
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
225
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
226
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
227
+
228
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
229
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
230
+
231
+ if self.mode == 'bidirectional':
232
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
233
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
234
+ else:
235
+ block_mask = self.bidirectional_mask
236
+
237
+ elif self.mode == 'autoregressive':
238
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
239
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
240
+ else:
241
+ block_mask = self.autoregressive_mask
242
+
243
+ elif self.mode == 'block_diff':
244
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
245
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
246
+ else:
247
+ block_mask = self.block_diff_mask
248
+ elif self.mode == 'sbd_block_diff':
249
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
250
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
251
+ else:
252
+ block_mask = self.sbd_block_diff_mask
253
+ else:
254
+ raise ValueError(f"Unknown attention mode: {self.mode}")
255
+
256
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
257
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
258
+
259
+ attn_output = self.o_proj(attn_output)
260
+
261
+ return attn_output, None
262
+
263
+
264
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
265
+ """Return a Bool mask of length len(log_w) with exactly k True."""
266
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
267
+ topk = torch.topk(log_w + g, k).indices
268
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
269
+ mask[topk] = True
270
+ return mask
271
+
272
+
273
+ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
274
+ """
275
+ A single model with:
276
+ - a bidirectional encoder + diffusion‐LM head over A
277
+ - a causal decoder + LM head over B, conditioned on F_A
278
+ """
279
+
280
+ def __init__(self, config: MinistralDLMConfig):
281
+ super().__init__(config)
282
+
283
+ self.mask_token_id = config.mask_token_id
284
+
285
+ diffusion_config = copy.deepcopy(config)
286
+ diffusion_config.diffusion_lm = True
287
+
288
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
289
+ diffusion_config.attn_class = MinistralFlexAttention
290
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
291
+ diffusion_config.attn_class = Ministral3Attention
292
+
293
+ if config.dlm_paradigm == 'autoregressive':
294
+ diffusion_config.diffusion_lm = False
295
+ else:
296
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
297
+
298
+ self.encoder = Ministral3Model(diffusion_config)
299
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
300
+ self.vocab_size = config.vocab_size
301
+
302
+ self.current_iter_ratio = None
303
+
304
+ self.post_init()
305
+
306
+
307
+ def get_input_embeddings(self):
308
+ return self.encoder.embed_tokens
309
+
310
+ def set_input_embeddings(self, value):
311
+ self.encoder.embed_tokens = value
312
+
313
+ def get_output_embeddings(self):
314
+ return self.diffusion_head
315
+
316
+ def set_output_embeddings(self, new_embeddings):
317
+ self.diffusion_head = new_embeddings
318
+
319
+
320
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
321
+ b, l = input_ids.shape
322
+ device = input_ids.device
323
+
324
+ if self.config.dp_varying_mask_ratio:
325
+ # Enable different random seeds for each DP rank during sampling
326
+ import torch.distributed as dist
327
+ dp_rank = 0
328
+ if dist.is_initialized():
329
+ try:
330
+ dp_rank = dist.get_rank()
331
+ except Exception:
332
+ dp_rank = 0
333
+ # Use a local generator to avoid affecting global RNG state
334
+ generator = torch.Generator(device=device)
335
+ generator.manual_seed(torch.seed() + dp_rank)
336
+ else:
337
+ generator = None
338
+
339
+ if self.config.adaptive_mask_rate:
340
+ assert block_size is not None
341
+
342
+ # --- simple linear window mapping ---
343
+ bs_min = getattr(self.config, "t_bs_min", 16)
344
+ bs_max = getattr(self.config, "t_bs_max", 128)
345
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
346
+
347
+ # fraction in [0,1] (unclamped first)
348
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
349
+ # upper bound decreases linearly from 1.0 -> 0.5
350
+ u_max = 1.0 - w * frac
351
+ # clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
352
+ u_max = max(0.6, min(1.0, u_max))
353
+ u_min = u_max - w # ensures width = w
354
+
355
+ # sample t ~ Uniform(u_min, u_max)
356
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
357
+ else:
358
+ t = torch.rand(b, device=device, generator=generator)
359
+
360
+ p_mask = (1 - eps) * t + eps # shape: (b,)
361
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
362
+
363
+ masked_indices = torch.rand((b, l), device=device) < p_mask
364
+
365
+ if loss_mask is not None:
366
+ masked_indices[loss_mask == 0] = 0
367
+
368
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
369
+
370
+ return noisy_batch, masked_indices, p_mask
371
+
372
+
373
+ def forward_process_exp(
374
+ self,
375
+ input_ids: torch.Tensor,
376
+ eps: float = 1e-3,
377
+ block_size: int | None = None,
378
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
379
+ loss_mask: Optional[torch.Tensor] = None,
380
+ ):
381
+ """
382
+ Two-stage corruption with optional per-block sampling.
383
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
384
+ • Stage 2: sample exactly k positions with weights
385
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
386
+ uniform when m→1).
387
+ If `block_size` is given, the procedure is run *independently*
388
+ inside each contiguous block of that length (last block may be shorter).
389
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
390
+ Args
391
+ ----
392
+ input_ids : (B, L) LongTensor
393
+ eps : minimum corruption ratio
394
+ block_size: if not None, operate block-wise with per-block m sampling
395
+ half_life_ratio : controls steepness when m→0
396
+ """
397
+ B, L = input_ids.shape
398
+ device = input_ids.device
399
+ dtype = torch.float32
400
+
401
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
402
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
403
+
404
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
405
+ for b in range(B):
406
+ if block_size is None:
407
+ # ---------- Per-batch sampling (original behavior) ----------
408
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
409
+ k_tot = int(round(m * L))
410
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
411
+
412
+ # Fill p_mask for this batch
413
+ p_mask[b, :] = m
414
+
415
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
416
+
417
+ # ------- single pool over the whole sentence -------------
418
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
419
+
420
+ pos = torch.arange(L, device=device, dtype=dtype)
421
+ log_w = (lam_base * slope * pos).clone()
422
+
423
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
424
+
425
+ else:
426
+ # ---------- Per-block sampling ----------
427
+ num_blocks = math.ceil(L / block_size)
428
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
429
+
430
+ for blk in range(num_blocks):
431
+ start = blk * block_size
432
+ end = min((blk + 1) * block_size, L)
433
+ blk_len = end - start
434
+
435
+ # Sample m per block
436
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
437
+
438
+ # Fill p_mask for this block
439
+ p_mask[b, start:end] = m_blk
440
+
441
+ # per-block budget
442
+ k_blk = int(round(m_blk * blk_len))
443
+ k_blk = max(0, min(k_blk, blk_len))
444
+ if k_blk == 0:
445
+ continue
446
+
447
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
448
+
449
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
450
+ log_w = lam_base * slope * pos
451
+
452
+ blk_mask = gumbel_topk(log_w, k_blk)
453
+ masked_indices[b, start:end] = blk_mask
454
+
455
+ if loss_mask is not None:
456
+ masked_indices[loss_mask == 0] = 0
457
+
458
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
459
+ return noisy_batch, masked_indices, p_mask
460
+
461
+
462
+ def forward(
463
+ self,
464
+ input_ids: torch.LongTensor,
465
+ attention_mask: Optional[torch.Tensor] = None,
466
+ position_ids: Optional[torch.LongTensor] = None,
467
+ labels: Optional[torch.LongTensor] = None,
468
+ split_len: Optional[int] = None,
469
+ past_key_values: Optional[Cache] = None,
470
+ block_size: Optional[int] = None,
471
+ block_diff_ppl: bool = False,
472
+ eps: float = 1e-3,
473
+ is_teacher: bool = False,
474
+ masked_indices: Optional[torch.Tensor] = None,
475
+ p_mask: Optional[torch.Tensor] = None,
476
+ teacher_logits: Optional[torch.Tensor] = None,
477
+ masked_indices_teacher: Optional[torch.Tensor] = None,
478
+ loss_mask: Optional[torch.Tensor] = None,
479
+ ce_loss_weight: float = 1.0,
480
+ output_last_hidden_states_only: bool = False,
481
+ skip_loss: bool = False,
482
+ **kwargs,
483
+ ) -> CausalLMOutputWithPast:
484
+
485
+ batch_size, seq_len = input_ids.shape
486
+
487
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
488
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
489
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
490
+ input_ids = input_ids[:, :random_length]
491
+ labels = labels[:, :random_length]
492
+
493
+ if attention_mask is not None:
494
+ attention_mask = attention_mask[:, :random_length]
495
+ if position_ids is not None:
496
+ position_ids = position_ids[:, :random_length]
497
+ if loss_mask is not None:
498
+ loss_mask = loss_mask[:, :random_length]
499
+
500
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
501
+ if labels is not None and block_size is None:
502
+ if torch.rand(1) < self.config.random_length_prob:
503
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
504
+ else:
505
+ block_size = self.config.block_size
506
+
507
+ else:
508
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
509
+
510
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
511
+ if masked_indices is not None:
512
+ # assert p_mask is not None
513
+
514
+ if loss_mask is not None:
515
+ masked_indices[loss_mask == 0] = 0
516
+
517
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
518
+
519
+ else:
520
+ if self.config.tok_mask_half_life_ratio is not None:
521
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
522
+ else:
523
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
524
+
525
+ else:
526
+ noisy_inputs = input_ids
527
+ masked_indices = None
528
+ p_mask = None
529
+
530
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
531
+ for layer in self.encoder.layers:
532
+ if hasattr(layer.self_attn, 'set_attention_mode'):
533
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
534
+
535
+ input_ids_len = noisy_inputs.shape[1]
536
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
537
+ if position_ids is None:
538
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
539
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
540
+
541
+ if block_diff_ppl:
542
+ if position_ids is None:
543
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
544
+
545
+ enc_out = self.encoder(
546
+ past_key_values=past_key_values,
547
+ input_ids=noisy_inputs,
548
+ attention_mask=attention_mask,
549
+ position_ids=position_ids,
550
+ is_training=(labels is not None) or (block_diff_ppl),
551
+ **kwargs,
552
+ )
553
+
554
+ if output_last_hidden_states_only:
555
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
556
+
557
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
558
+ causal_logits = None
559
+
560
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
561
+ if self.config.dlm_paradigm == 'sbd_block_diff':
562
+ causal_logits = logits[:, input_ids_len:]
563
+ else:
564
+ causal_logits = None
565
+
566
+ logits = logits[:, :input_ids_len]
567
+
568
+ loss = None
569
+ if labels is not None and not skip_loss:
570
+ if self.config.dlm_paradigm == 'autoregressive':
571
+ shift_logits = logits[..., :-1, :].contiguous()
572
+ shift_labels = labels[..., 1:].contiguous()
573
+
574
+ if loss_mask is None:
575
+ loss_fct = CrossEntropyLoss()
576
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
577
+ shift_labels = shift_labels.view(-1)
578
+ loss = loss_fct(shift_logits, shift_labels)
579
+
580
+ else:
581
+ loss_mask = loss_mask[..., 1:].contiguous()
582
+
583
+ loss_fct = CrossEntropyLoss(reduction='none')
584
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
585
+ shift_labels = shift_labels.view(-1)
586
+ shift_labels = shift_labels.to(shift_logits.device)
587
+
588
+ token_losses = loss_fct(shift_logits, shift_labels)
589
+
590
+ flat_loss_mask = loss_mask.reshape(-1)
591
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
592
+
593
+ else:
594
+ # Handle DREAM vs LLADA style losses
595
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
596
+ logits = logits[..., :-1, :].contiguous()
597
+ labels = labels[..., 1:].contiguous()
598
+ masked_indices = masked_indices[:, 1:]
599
+ p_mask = p_mask[:, 1:]
600
+
601
+ if self.config.ada_perm_ratio_per_block is not None:
602
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
603
+ block_size = self.config.block_size
604
+ batch_size, seq_len = masked_indices.shape
605
+ num_blocks = seq_len // block_size
606
+
607
+ # Get the max logit (confidence) for each position
608
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
609
+
610
+ # Create a mask for tokens to include in loss
611
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
612
+
613
+ for blk in range(num_blocks):
614
+ start = blk * block_size
615
+ end = min((blk + 1) * block_size, seq_len)
616
+
617
+ # Get masked indices within this block
618
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
619
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
620
+
621
+ for b in range(batch_size):
622
+ # Get positions that are masked in this block for this batch
623
+ masked_positions = torch.where(block_masked[b])[0]
624
+ num_masked = len(masked_positions)
625
+
626
+ if num_masked > 0:
627
+ # Number of tokens to keep (top by confidence)
628
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
629
+
630
+ # Get confidence values for masked positions
631
+ masked_confidence = block_confidence[b, masked_positions]
632
+
633
+ # Get indices of top-k confident tokens
634
+ _, topk_indices = torch.topk(masked_confidence, k)
635
+ selected_positions = masked_positions[topk_indices]
636
+
637
+ # Mark these positions in the selected mask
638
+ selected_mask[b, start + selected_positions] = True
639
+
640
+ # Calculate loss only for selected positions
641
+ token_loss = torch.nn.functional.cross_entropy(
642
+ logits[selected_mask],
643
+ labels[selected_mask],
644
+ reduction='none'
645
+ ) / p_mask[selected_mask]
646
+
647
+ num_mask_tokens = selected_mask.sum()
648
+
649
+ else:
650
+ # Calculate token-wise cross entropy loss for masked positions in B
651
+ token_loss = torch.nn.functional.cross_entropy(
652
+ logits[masked_indices],
653
+ labels[masked_indices],
654
+ reduction='none'
655
+ ) / p_mask[masked_indices]
656
+
657
+ num_mask_tokens = masked_indices.sum()
658
+
659
+ if self.config.global_loss_avg:
660
+ loss = token_loss.sum()
661
+ else:
662
+ loss = token_loss.sum() / num_mask_tokens
663
+
664
+ if self.config.ada_dlm_loss_ratio is not None:
665
+ assert self.current_iter_ratio is not None
666
+ assert self.config.dlm_loss_weight is not None
667
+
668
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
669
+ loss = dlm_loss_weight * loss
670
+
671
+ elif self.config.dlm_loss_weight is not None:
672
+ loss = self.config.dlm_loss_weight * loss
673
+
674
+ if self.config.dlm_paradigm == 'sbd_block_diff':
675
+ causal_logits = causal_logits[..., :-1, :].contiguous()
676
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
677
+
678
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
679
+ causal_labels = labels.view(-1)
680
+ else:
681
+ causal_labels = labels[..., 1:].contiguous().view(-1)
682
+
683
+ if self.config.global_loss_avg:
684
+ loss_fct = CrossEntropyLoss(reduction='sum')
685
+ ar_loss = loss_fct(causal_logits, causal_labels)
686
+
687
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
688
+ self.loss_ar = ar_loss.detach().item() / seq_len
689
+
690
+ loss = loss + self.config.ar_loss_weight * ar_loss
691
+ else:
692
+ loss_fct = CrossEntropyLoss()
693
+ ar_loss = loss_fct(causal_logits, causal_labels)
694
+
695
+ self.loss_diffusion = loss.detach().item()
696
+ self.loss_ar = ar_loss.detach().item()
697
+
698
+ loss = loss + self.config.ar_loss_weight * ar_loss
699
+
700
+ if self.config.global_loss_avg:
701
+ if self.config.dlm_paradigm == 'sbd_block_diff':
702
+ loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
703
+ else:
704
+ loss = (loss, num_mask_tokens)
705
+
706
+ return MinistralDiffOutputWithPast(
707
+ loss=loss if not is_teacher else logits,
708
+ logits=logits,
709
+ causal_logits=causal_logits,
710
+ past_key_values=enc_out.past_key_values,
711
+ hidden_states=None,
712
+ attentions=None,
713
+ )
714
+
715
+
716
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0):
717
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
718
+ model=self,
719
+ prompt=prompt_ids,
720
+ gen_length=max_new_tokens,
721
+ steps=steps,
722
+ block_length=block_length,
723
+ remasking="low_confidence",
724
+ temperature=temperature,
725
+ mask_id=self.mask_token_id,
726
+ threshold=threshold,
727
+ shift_logits=shift_logits,
728
+ neg_entropy=False,
729
+ causal_context=causal_context,
730
+ )
731
+
732
+ return out_ids, nfe
733
+
734
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]