Jorg0 m1ngcheng committed on
Commit
aee550f
·
0 Parent(s):

Duplicate from inclusionAI/LLaDA2.1-mini


Co-authored-by: mingcheng, aka 明城 <m1ngcheng@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,460 @@
+ ---
+ license: apache-2.0
+ library_name: transformers
+ tags:
+ - dllm
+ - diffusion
+ - llm
+ - text_generation
+ ---
+ # LLaDA2.1-mini
+
+ 🚀 **LLaDA2.1-flash** is now live on **ZenmuxAI**! Try it via API 🛠️ or Chat 💬: https://zenmux.ai/inclusionai/llada2.1-flash
+
+ **LLaDA2.1-mini** is a diffusion language model in the LLaDA series featuring a token-editing enhancement. It significantly improves inference speed while delivering strong task performance.
+
+ <div align="center">
+ <img src="https://mdn.alipayobjects.com/huamei_qa8qxu/afts/img/A*uOo8QKQMiBwAAAAAgNAAAAgAemJ7AQ/original" width="800" />
+ </div>
+
+
+ <div align="center">
+ <img src="https://mdn.alipayobjects.com/huamei_qa8qxu/afts/img/A*biwvQpCmKjEAAAAAULAAAAgAemJ7AQ/original" width="800" />
+ </div>
+
+ ---
+ ## Model Performance
+
+ <table>
+ <thead>
+ <tr>
+ <th align="left"><b>Benchmark</b></th>
+ <th align="center"><b>Qwen3-8B<br>(no_think)</b><br><sub>(Score)</sub></th>
+ <th align="center"><b>Ling-mini-2.0</b><br><br><sub>(Score)</sub></th>
+ <th align="center"><b>LLaDA2.0-mini</b><br><br><sub>(Score | TPF)</sub></th>
+ <th align="center"><b>LLaDA2.1-mini<br>(S Mode)</b><br><sub>(Score | TPF)</sub></th>
+ <th align="center"><b>LLaDA2.1-mini<br>(Q Mode)</b><br><sub>(Score | TPF)</sub></th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td align="left"><b>Average</b></td>
+ <td align="center">61.59</td>
+ <td align="center">64.72</td>
+ <td align="center">63.39 | 2.60</td>
+ <td align="center">62.07 | 5.34</td>
+ <td align="center">63.90 | 3.12</td>
+ </tr>
+ <tr><td colspan="6" align="center"><b>Knowledge</b></td></tr>
+ <tr>
+ <td align="left">GPQA</td>
+ <td align="center">48.01</td>
+ <td align="center">59.41</td>
+ <td align="center">47.76 | 2.73</td>
+ <td align="center">48.36 | 3.62</td>
+ <td align="center">53.28 | 2.12</td>
+ </tr>
+ <tr>
+ <td align="left">MMLU-Pro</td>
+ <td align="center">65.83</td>
+ <td align="center">67.18</td>
+ <td align="center">64.27 | 2.15</td>
+ <td align="center">63.42 | 4.22</td>
+ <td align="center">64.84 | 2.41</td>
+ </tr>
+ <tr>
+ <td align="left">C-EVAL</td>
+ <td align="center">80.6</td>
+ <td align="center">82.17</td>
+ <td align="center">81.80 | 1.78</td>
+ <td align="center">78.40 | 3.39</td>
+ <td align="center">78.59 | 1.91</td>
+ </tr>
+ <tr>
+ <td align="left">PHYBench</td>
+ <td align="center">9.76</td>
+ <td align="center">14.59</td>
+ <td align="center">11.70 | 2.48</td>
+ <td align="center">12.75 | 4.41</td>
+ <td align="center">13.05 | 2.52</td>
+ </tr>
+ <tr>
+ <td align="left">TriviaQA</td>
+ <td align="center">52.51</td>
+ <td align="center">55.63</td>
+ <td align="center">51.33 | 1.54</td>
+ <td align="center">53.33 | 3.21</td>
+ <td align="center">54.24 | 2.02</td>
+ </tr>
+ <tr><td colspan="6" align="center"><b>Reasoning</b></td></tr>
+ <tr>
+ <td align="left">BIG-Bench Hard</td>
+ <td align="center">79.48</td>
+ <td align="center">83.70</td>
+ <td align="center">78.21 | 2.36</td>
+ <td align="center">78.42 | 5.02</td>
+ <td align="center">80.58 | 2.86</td>
+ </tr>
+ <tr>
+ <td align="left">BIG-Bench Extra Hard</td>
+ <td align="center">18.27</td>
+ <td align="center">14.81</td>
+ <td align="center">16.47 | 2.03</td>
+ <td align="center">15.30 | 3.19</td>
+ <td align="center">15.78 | 1.66</td>
+ </tr>
+ <tr>
+ <td align="left">bbh-zh</td>
+ <td align="center">80.09</td>
+ <td align="center">66.11</td>
+ <td align="center">75.75 | 2.77</td>
+ <td align="center">67.65 | 3.89</td>
+ <td align="center">70.40 | 2.35</td>
+ </tr>
+ <tr>
+ <td align="left">MuSR</td>
+ <td align="center">70.02</td>
+ <td align="center">71.36</td>
+ <td align="center">71.48 | 1.45</td>
+ <td align="center">70.43 | 2.48</td>
+ <td align="center">71.89 | 1.56</td>
+ </tr>
+ <tr>
+ <td align="left">ZebraLogic</td>
+ <td align="center">37.48</td>
+ <td align="center">79.85</td>
+ <td align="center">64.20 | 2.30</td>
+ <td align="center">68.50 | 5.38</td>
+ <td align="center">77.10 | 2.93</td>
+ </tr>
+ <tr>
+ <td align="left">PrOntoQA</td>
+ <td align="center">93.12</td>
+ <td align="center">96.06</td>
+ <td align="center">86.00 | 2.36</td>
+ <td align="center">87.50 | 4.86</td>
+ <td align="center">84.50 | 2.73</td>
+ </tr>
+ <tr>
+ <td align="left">PIQA</td>
+ <td align="center">88.30</td>
+ <td align="center">87.54</td>
+ <td align="center">86.51 | 1.45</td>
+ <td align="center">84.87 | 2.59</td>
+ <td align="center">86.89 | 1.45</td>
+ </tr>
+ <tr>
+ <td align="left">OCNLI</td>
+ <td align="center">61.49</td>
+ <td align="center">60.17</td>
+ <td align="center">64.51 | 4.06</td>
+ <td align="center">61.02 | 1.78</td>
+ <td align="center">61.59 | 1.23</td>
+ </tr>
+ <tr>
+ <td align="left">HellaSwag</td>
+ <td align="center">79.56</td>
+ <td align="center">69.02</td>
+ <td align="center">79.01 | 1.50</td>
+ <td align="center">75.71 | 2.39</td>
+ <td align="center">76.19 | 1.49</td>
+ </tr>
+ <tr>
+ <td align="left">KOR-Bench</td>
+ <td align="center">54.96</td>
+ <td align="center">63.2</td>
+ <td align="center">49.92 | 2.45</td>
+ <td align="center">46.64 | 4.28</td>
+ <td align="center">48.00 | 2.35</td>
+ </tr>
+ <tr>
+ <td align="left">DROP</td>
+ <td align="center">84.56</td>
+ <td align="center">78.80</td>
+ <td align="center">81.89 | 2.02</td>
+ <td align="center">81.55 | 5.84</td>
+ <td align="center">82.37 | 2.87</td>
+ </tr>
+ <tr>
+ <td align="left">SQuAD 2.0</td>
+ <td align="center">85.21</td>
+ <td align="center">75.56</td>
+ <td align="center">86.50 | 2.47</td>
+ <td align="center">84.51 | 4.33</td>
+ <td align="center">85.13 | 3.09</td>
+ </tr>
+ <tr><td colspan="6" align="center"><b>Coding</b></td></tr>
+ <tr>
+ <td align="left">LiveCodeBench</td>
+ <td align="center">26.76</td>
+ <td align="center">42.29</td>
+ <td align="center">31.83 | 3.34</td>
+ <td align="center">28.85 | 6.42</td>
+ <td align="center">30.40 | 3.63</td>
+ </tr>
+ <tr>
+ <td align="left">CRUXEval-O</td>
+ <td align="center">74.06</td>
+ <td align="center">76.12</td>
+ <td align="center">71.62 | 2.78</td>
+ <td align="center">70.62 | 5.85</td>
+ <td align="center">73.75 | 3.35</td>
+ </tr>
+ <tr>
+ <td align="left">MBPP+</td>
+ <td align="center">72.69</td>
+ <td align="center">77.25</td>
+ <td align="center">78.24 | 3.43</td>
+ <td align="center">73.28 | 10.59</td>
+ <td align="center">74.07 | 6.30</td>
+ </tr>
+ <tr>
+ <td align="left">HumanEval+</td>
+ <td align="center">79.5</td>
+ <td align="center">80.03</td>
+ <td align="center">81.40 | 5.16</td>
+ <td align="center">80.49 | 12.32</td>
+ <td align="center">82.93 | 7.77</td>
+ </tr>
+ <tr>
+ <td align="left">MultiPL-E</td>
+ <td align="center">61.70</td>
+ <td align="center">67.09</td>
+ <td align="center">67.46 | 2.78</td>
+ <td align="center">64.16 | 7.23</td>
+ <td align="center">67.17 | 4.01</td>
+ </tr>
+ <tr>
+ <td align="left">BigCodeBench-Full</td>
+ <td align="center">36.05</td>
+ <td align="center">35.00</td>
+ <td align="center">32.89 | 2.87</td>
+ <td align="center">30.18 | 7.33</td>
+ <td align="center">34.39 | 4.09</td>
+ </tr>
+ <tr>
+ <td align="left">BIRD-SQL</td>
+ <td align="center">36.11</td>
+ <td align="center">39.67</td>
+ <td align="center">39.34 | 1.96</td>
+ <td align="center">37.32 | 4.48</td>
+ <td align="center">38.40 | 2.42</td>
+ </tr>
+ <tr>
+ <td align="left">Spider</td>
+ <td align="center">72.80</td>
+ <td align="center">76.43</td>
+ <td align="center">76.76 | 3.93</td>
+ <td align="center">75.78 | 7.98</td>
+ <td align="center">77.55 | 5.48</td>
+ </tr>
+ <tr><td colspan="6" align="center"><b>Math</b></td></tr>
+ <tr>
+ <td align="left">AIME 2025</td>
+ <td align="center">22.08</td>
+ <td align="center">47.66</td>
+ <td align="center">36.67 | 2.41</td>
+ <td align="center">36.67 | 6.34</td>
+ <td align="center">43.33 | 3.29</td>
+ </tr>
+ <tr>
+ <td align="left">OlympiadBench</td>
+ <td align="center">55.33</td>
+ <td align="center">72.30</td>
+ <td align="center">67.70 | 2.63</td>
+ <td align="center">64.30 | 7.08</td>
+ <td align="center">66.67 | 3.99</td>
+ </tr>
+ <tr>
+ <td align="left">GSM-Plus</td>
+ <td align="center">85.56</td>
+ <td align="center">87.18</td>
+ <td align="center">86.50 | 2.41</td>
+ <td align="center">85.88 | 6.82</td>
+ <td align="center">86.55 | 3.69</td>
+ </tr>
+ <tr>
+ <td align="left">CMATH</td>
+ <td align="center">95.42</td>
+ <td align="center">96.40</td>
+ <td align="center">95.72 | 1.98</td>
+ <td align="center">95.63 | 4.94</td>
+ <td align="center">94.99 | 2.56</td>
+ </tr>
+ <tr>
+ <td align="left">Omni-MATH</td>
+ <td align="center">33.20</td>
+ <td align="center">48.80</td>
+ <td align="center">41.70 | 2.57</td>
+ <td align="center">41.70 | 6.41</td>
+ <td align="center">43.60 | 3.56</td>
+ </tr>
+ <tr><td colspan="6" align="center"><b>Agent & Alignment</b></td></tr>
+ <tr>
+ <td align="left">IFEval-strict-prompt</td>
+ <td align="center">84.29</td>
+ <td align="center">76.16</td>
+ <td align="center">80.78 | 1.24</td>
+ <td align="center">81.33 | 1.83</td>
+ <td align="center">83.18 | 1.25</td>
+ </tr>
+ <tr>
+ <td align="left">BFCL v3</td>
+ <td align="center">70.12</td>
+ <td align="center">53.75</td>
+ <td align="center">70.72 | 4.26</td>
+ <td align="center">72.06 | 7.39</td>
+ <td align="center">73.61 | 5.14</td>
+ </tr>
+ <tr>
+ <td align="left">Nexus FC</td>
+ <td align="center">37.71</td>
+ <td align="center">34.38</td>
+ <td align="center">35.18 | 4.06</td>
+ <td align="center">31.59 | 8.27</td>
+ <td align="center">33.69 | 4.91</td>
+ </tr>
+ </tbody>
+ </table>
+
+ <sub>S Mode = Speed Mode, Q Mode = Quality Mode; TPF = tokens per forward pass (higher means fewer model calls per generated token, i.e. faster decoding).</sub>
+
+ ---
+
+ ## 🚀 Highlights
+ + **Error-Correcting Editing:** a structural innovation that makes dLLM generation editable, so the model can revise already-decoded tokens.
+ + **Speed vs. Quality Mode:** the 16B mini model achieves ultra-fast inference in Speed Mode while remaining competitive across a wide range of tasks in Quality Mode.
+ + **Reinforcement Learning on a 100B-scale dLLM:** a tailored algorithm and framework enable reinforcement learning for large dLLMs.
+
+ ## 🗺️ What's Next
+
+ + **Powerful Agentic/Tool-Use Capability with LLaDA:** the next update will ship powerful **agentic** and long-horizon tool-use capabilities.
+ + **Extreme Editing:** the next update will feature stronger and more extensive editing, aimed at correcting more errors in parallel reasoning.
+ + **Explore More Training Paradigms:** we plan to explore training paradigms beyond SFT and RL for dLLMs.
+
+ ---
+
+ ## 📦 Model Variants
+
+ | Model ID | Description | Hugging Face Link |
+ | --- | --- | --- |
+ | `inclusionAI/LLaDA2.1-mini` | Instruction-tuned model, ready for downstream applications. | [🤗 Model Card](https://huggingface.co/inclusionAI/LLaDA2.1-mini) |
+ | `inclusionAI/LLaDA2.1-flash` | Instruction-tuned model, ready for downstream applications. | [🤗 Model Card](https://huggingface.co/inclusionAI/LLaDA2.1-flash) |
+
+ ---
+
+ ## 🔍 Model Overview
+ **LLaDA2.1-mini** has the following specifications:
+
+ + **Type**: Mixture-of-Experts (MoE) Diffusion Language Model
+ + **Total Parameters (Non-Embedding)**: 16B
+ + **Number of Layers**: 20
+ + **Attention Heads**: 16
+ + **Context Length**: 32,768 tokens
+ + **Position Embedding**: Rotary (RoPE)
+ + **Vocabulary Size**: 157,184
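+
+ These values can also be read programmatically from the `config.json` shipped with the checkpoint; a minimal sketch (the field names follow that file):
+
+ ```python
+ from transformers import AutoConfig
+
+ # trust_remote_code is required because the config class lives in this repository
+ config = AutoConfig.from_pretrained("inclusionAI/LLaDA2.1-mini", trust_remote_code=True)
+ print(config.num_hidden_layers)        # 20
+ print(config.num_attention_heads)      # 16
+ print(config.max_position_embeddings)  # 32768
+ print(config.vocab_size)               # 157184
+ print(config.num_experts, config.num_experts_per_tok)  # 256 8
+ ```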
+
+ ---
+
+ ### 🤗 Hugging Face Transformers
+ Make sure you have `transformers` and its dependencies installed. The minimal example below loads the model with `trust_remote_code=True` (required for the custom LLaDA2 classes) and runs generation:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_path = "/path/to/LLaDA2.1-mini"
+ device = "auto"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path, trust_remote_code=True, device_map=device,
+ )
+ model = model.to(torch.bfloat16)
+ model.eval()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ prompt = """Calculate 1+5-28*0.5-200=?"""  # expected answer: 6 - 14 - 200 = -208
+ input_ids = tokenizer.apply_chat_template(
+     [{"role": "user", "content": prompt}],
+     add_generation_prompt=True,
+     tokenize=True,
+     return_tensors="pt",
+ ).to(model.device)  # keep the prompt on the same device as the model
+ generated_tokens = model.generate(
+     inputs=input_ids,
+     eos_early_stop=True,
+     gen_length=512,
+     block_length=32,
+     threshold=0.5,       # Speed Mode denoising threshold (see Best Practices below)
+     editing_threshold=0,
+     temperature=0.0,
+ )
+ generated_answer = tokenizer.decode(
+     generated_tokens[0],
+     skip_special_tokens=True,
+ )
+ print(generated_answer)
+ ```
+
+ ### Best Practices
+ To achieve optimal performance, we recommend the following settings:
+
+ 1. **Sampling Parameters**:
+    We recommend the following general sampling parameters: `block_length=32`, `temperature=0.0`, `top_p=None`, and `top_k=None`. We are currently exploring more diverse sampling configurations.
+
+ 2. **Denoising Thresholds**:
+    There are three denoising parameters: `threshold`, `editing_threshold`, and `max_post_steps`. We recommend `threshold=0.7`, `editing_threshold=0.5` for **Quality Mode** and `threshold=0.5`, `editing_threshold=0.0` for **Speed Mode**; a side-by-side sketch follows this list. For both modes, we suggest setting `max_post_steps` to a value greater than 5; 16 is a balanced default, which was used for most of our internal testing.
+
+    Note: a low `threshold` may cause stuttering as a trade-off for faster inference.
+
+ 3. **Adequate Output Length**:
+    We recommend an output length of 16,384 tokens for most scenarios.
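+
+ A minimal sketch contrasting the two modes, reusing `model`, `tokenizer`, and `input_ids` from the Transformers example above (we assume here that `max_post_steps` is passed to `generate` alongside the other denoising parameters):
+
+ ```python
+ # Recommended settings from the Best Practices list above.
+ quality_mode = dict(threshold=0.7, editing_threshold=0.5, max_post_steps=16)
+ speed_mode = dict(threshold=0.5, editing_threshold=0.0, max_post_steps=16)
+
+ for name, params in (("quality", quality_mode), ("speed", speed_mode)):
+     tokens = model.generate(
+         inputs=input_ids,
+         eos_early_stop=True,
+         gen_length=512,
+         block_length=32,   # recommended block size
+         temperature=0.0,   # greedy denoising
+         **params,
+     )
+     print(name, tokenizer.decode(tokens[0], skip_special_tokens=True))
+ ```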
+
+ ---
+
+ ## 🤖 ModelScope
+ If you're in mainland China, we strongly recommend using our model from 🤖 [ModelScope](https://modelscope.cn/models/inclusionAI/LLaDA2.1-mini).
+
+ ---
+
+ ## Deployment
+ ### SGLang
+ SGLang enables dLLM inference either through offline batching or by launching an HTTP server for online requests. You can launch the SGLang server with the following command:
+
+ ```bash
+ python3 -m sglang.launch_server \
+     --model-path inclusionAI/LLaDA2.1-mini \
+     --dllm-algorithm JointThreshold \
+     --tp-size 1 \
+     --trust-remote-code \
+     --mem-fraction-static 0.8 \
+     --max-running-requests 1 \
+     --attention-backend flashinfer
+ ```
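+
+ Once the server is up, it accepts requests on SGLang's OpenAI-compatible HTTP API. A minimal sketch, assuming the default host and port (`http://127.0.0.1:30000`):
+
+ ```python
+ import requests
+
+ # Chat-completions endpoint exposed by sglang.launch_server
+ resp = requests.post(
+     "http://127.0.0.1:30000/v1/chat/completions",
+     json={
+         "model": "inclusionAI/LLaDA2.1-mini",
+         "messages": [{"role": "user", "content": "Calculate 1+5-28*0.5-200=?"}],
+         "max_tokens": 512,
+     },
+     timeout=600,
+ )
+ print(resp.json()["choices"][0]["message"]["content"])
+ ```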
+
+ ### Environment Preparation
+ The pull request (PR) adding this support has been submitted and merged into the SGLang repository; please prepare your environment with the latest SGLang version.
+
+ ---
+ ## 🌐 License
+ This project is licensed under the terms of the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+
+ ---
+
+ ## 🤝 Contact & Collaboration
+ For questions, collaborations, or feedback, please reach out via [Hugging Face](https://huggingface.co/inclusionAI/LLaDA2.1-mini) or open an issue in the [repository](https://github.com/inclusionAI).
+
+ 👉 Join us in advancing open, efficient, and intelligent language models!
+
+ ---
+
+ ## Citation
+ ```bibtex
+ @misc{bie2026llada21speedingtextdiffusion,
+       title={LLaDA2.1: Speeding Up Text Diffusion via Token Editing},
+       author={Tiwei Bie and Maosong Cao and Xiang Cao and Bingsen Chen and Fuyuan Chen and Kun Chen and Lun Du and Daozhuo Feng and Haibo Feng and Mingliang Gong and Zhuocheng Gong and Yanmei Gu and Jian Guan and Kaiyuan Guan and Hongliang He and Zenan Huang and Juyong Jiang and Zhonghui Jiang and Zhenzhong Lan and Chengxi Li and Jianguo Li and Zehuan Li and Huabin Liu and Lin Liu and Guoshan Lu and Yuan Lu and Yuxin Ma and Xingyu Mou and Zhenxuan Pan and Kaida Qiu and Yuji Ren and Jianfeng Tan and Yiding Tian and Zian Wang and Lanning Wei and Tao Wu and Yipeng Xing and Wentao Ye and Liangyu Zha and Tianze Zhang and Xiaolu Zhang and Junbo Zhao and Da Zheng and Hao Zhong and Wanli Zhong and Jun Zhou and Junlin Zhou and Liwang Zhu and Muzhi Zhu and Yihong Zhuang},
+       year={2026},
+       eprint={2602.08676},
+       archivePrefix={arXiv},
+       primaryClass={cs.LG},
+       url={https://arxiv.org/abs/2602.08676},
+ }
+ ```
config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "architectures": [
+     "LLaDA2MoeModelLM"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_llada2_moe.LLaDA2MoeConfig",
+     "AutoModel": "modeling_llada2_moe.LLaDA2MoeModel",
+     "AutoModelForCausalLM": "modeling_llada2_moe.LLaDA2MoeModelLM"
+   },
+   "dtype": "bfloat16",
+   "embedding_dropout": 0.0,
+   "first_k_dense_replace": 1,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "llada2_moe",
+   "moe_intermediate_size": 512,
+   "moe_router_enable_expert_bias": true,
+   "n_group": 8,
+   "norm_head": false,
+   "norm_softmax": false,
+   "norm_topk_prob": true,
+   "num_attention_heads": 16,
+   "num_experts": 256,
+   "num_experts_per_tok": 8,
+   "num_hidden_layers": 20,
+   "num_key_value_heads": 4,
+   "num_shared_experts": 1,
+   "output_dropout": 0.0,
+   "output_router_logits": false,
+   "pad_token_id": 156892,
+   "partial_rotary_factor": 0.5,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 600000,
+   "rotary_dim": 64,
+   "routed_scaling_factor": 2.5,
+   "router_dtype": "fp32",
+   "score_function": "sigmoid",
+   "sliding_window": 4096,
+   "tie_word_embeddings": false,
+   "topk_group": 4,
+   "transformers_version": "4.57.1",
+   "use_bias": false,
+   "use_cache": false,
+   "use_qkv_bias": false,
+   "use_rmsnorm": true,
+   "use_sliding_window": false,
+   "using_split_qkv_in_self_attention": false,
+   "vocab_size": 157184
+ }
configuration_llada2_moe.py ADDED
@@ -0,0 +1,88 @@
+ """LLaDA2 MoE model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class LLaDA2MoeConfig(PretrainedConfig):
+     model_type = "llada2_moe"
+
+     def __init__(
+         self,
+         vocab_size=30592,
+         hidden_size=1024,
+         intermediate_size=None,
+         num_hidden_layers=24,
+         num_attention_heads=16,
+         num_key_value_heads=0,
+         hidden_act="silu",
+         use_qkv_bias=False,  # llada2 only
+         use_qk_norm=True,
+         use_bias=True,  # llada2 only
+         rms_norm_eps=1e-05,
+         norm_head=False,  # llada2 only
+         tie_word_embeddings=False,  # PretrainedConfig key; the default value is changed here.
+         embedding_dropout=0.1,
+         attention_dropout=0.1,
+         output_dropout=0.1,
+         initializer_range=0.02,
+         max_position_embeddings=16384,
+         rope_theta=10000.0,
+         use_cache=True,
+         use_sliding_window=False,
+         sliding_window=4096,
+         max_window_layers=28,
+         rope_scaling=None,
+         pad_token_id=126081,
+         num_experts=16,
+         num_shared_experts=0,
+         num_experts_per_tok=2,
+         n_group=8,
+         topk_group=4,
+         routed_scaling_factor=2.5,
+         moe_intermediate_size=None,
+         first_k_dense_replace=0,
+         head_dim=None,
+         output_router_logits=False,
+         partial_rotary_factor=0.5,
+         **kwargs,
+     ):
+         self.num_hidden_layers = num_hidden_layers
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.use_qkv_bias = use_qkv_bias
+         self.use_qk_norm = use_qk_norm
+         self.use_bias = use_bias
+         self.norm_head = norm_head
+         self.rms_norm_eps = rms_norm_eps
+         self.embedding_dropout = embedding_dropout
+         self.attention_dropout = attention_dropout
+         self.output_dropout = output_dropout
+         self.initializer_range = initializer_range
+         self.max_position_embeddings = max_position_embeddings
+         self.rope_theta = rope_theta
+         self.use_cache = use_cache
+         self.use_sliding_window = use_sliding_window
+         self.sliding_window = sliding_window
+         self.max_window_layers = max_window_layers
+         self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
+         self.rope_scaling = rope_scaling
+
+         # MoE configs
+         self.num_experts = num_experts
+         self.num_shared_experts = num_shared_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.n_group = n_group
+         self.topk_group = topk_group
+         self.moe_intermediate_size = moe_intermediate_size
+         self.first_k_dense_replace = first_k_dense_replace
+         self.output_router_logits = output_router_logits
+         self.routed_scaling_factor = routed_scaling_factor
+         self.partial_rotary_factor = partial_rotary_factor
+
+         super().__init__(
+             pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+         )
model-00000-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21a904092ad20835b30400a1ee939f765d88214381e0256f1a5b598da9b425cc
+ size 5735111000
model-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f89a9d7bffaa06dffee29ecca93c870657ae670cb9a848c052adeb84487f4df
+ size 3825430152
model-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:441c1b58669f0cabc1c6f22ab83a56f8c2e38eba2a2e534d01257d7c8247695a
+ size 3825430152
model-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:602798afb9acfccefe5f2186a59ee24b4847f3f587f3fa589173f02fc917c4a6
+ size 3825431744
model-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7bd5fc04fc64bc8ef9e54dbbb630c570417dc45ee2bbd47b0d71dffc61830ccf
+ size 3825431976
model-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7b097875ab571ecfb710cecb3cacb46b4de67a88b5eacdf8942a8133384e03a
+ size 3825431976
model-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d77a97241a25a79c973aaed3395eb21f8c940e40da71a1279ae19ea8431ba5ee
+ size 3825431976
model-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61f8f4a0c54780d41250b9265075169971f2e4fe6b52f4ccd766ddd2305bc764
+ size 3825431976
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_llada2_moe.py ADDED
@@ -0,0 +1,1434 @@
+ # Copyright 2025 Antgroup and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch LLaDA2MoE model."""
+
+ import math
+ from typing import List, Callable, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_attn_mask_utils import (
+     _prepare_4d_causal_attention_mask_for_sdpa,
+ )
+ from transformers.modeling_outputs import (
+     MoeModelOutputWithPast,
+     MoeCausalLMOutputWithPast,
+ )
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.pytorch_utils import (
+     ALL_LAYERNORM_LAYERS,
+ )
+ from transformers.utils import (
+     TransformersKwargs,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from .configuration_llada2_moe import LLaDA2MoeConfig
+ from transformers.generation.utils import GenerationMixin
+
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "LLaDA2MoeConfig"
+
+
+ def _get_unpad_data(attention_mask):
+     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = seqlens_in_batch.max().item()
+     cu_seqlens = F.pad(
+         torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+     )
+     return (
+         indices,
+         cu_seqlens,
+         max_seqlen_in_batch,
+     )
+
+
+ class LLaDA2MoeRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         LLaDA2MoeRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+
+ ALL_LAYERNORM_LAYERS.append(LLaDA2MoeRMSNorm)
+
+
+ class LLaDA2MoeRotaryEmbedding(nn.Module):
+     def __init__(self, config: LLaDA2MoeConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get(
+                 "rope_type", config.rope_scaling.get("type")
+             )
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = (
+             self.inv_freq[None, :, None]
+             .float()
+             .expand(position_ids.shape[0], -1, 1)
+             .to(x.device)
+         )
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = (
+             x.device.type
+             if isinstance(x.device.type, str) and x.device.type != "mps"
+             else "cpu"
+         )
+         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+             freqs = (
+                 inv_freq_expanded.float() @ position_ids_expanded.float()
+             ).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`):
+             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+             used to pass offsetted position ids when working with a KV-cache.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+
+     # Keep half or full tensor for later concatenation
+     rotary_dim = cos.shape[-1]
+     q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+     k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+     # Apply rotary embeddings on the first half or full tensor
+     q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+     k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+     # Concatenate back to full shape
+     q_embed = torch.cat([q_embed, q_pass], dim=-1)
+     k_embed = torch.cat([k_embed, k_pass], dim=-1)
+     return q_embed, k_embed
+
+
+ class LLaDA2MoeMLP(nn.Module):
+     def __init__(self, config: LLaDA2MoeConfig, intermediate_size: int):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class LLaDA2MoeGate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.top_k = config.num_experts_per_tok
+         self.num_experts = config.num_experts
+
+         self.n_group = config.n_group
+         self.topk_group = config.topk_group
+
+         # topk selection algorithm
+         self.gating_dim = config.hidden_size
+         self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
+         self.routed_scaling_factor = config.routed_scaling_factor
+
+         self.register_buffer("expert_bias", torch.zeros(self.num_experts))
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         import torch.nn.init as init
+
+         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+     def group_limited_topk(
+         self,
+         scores: torch.Tensor,
+     ):
+         num_tokens, _ = scores.size()
+         # Organize the experts into groups
+         group_scores = (
+             scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+         )
+         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+         group_mask = torch.zeros_like(group_scores)
+         group_mask.scatter_(1, group_idx, 1)
+
+         # Mask the experts based on selection groups
+         score_mask = (
+             group_mask.unsqueeze(-1)
+             .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
+             .reshape(num_tokens, -1)
+         )
+
+         masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
+         probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
+
+         return probs, top_indices
+
+     def forward(self, hidden_states):
+         # compute gating score
+         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+         logits = F.linear(
+             hidden_states.type(torch.float32), self.weight.type(torch.float32)
+         )
+
+         scores = torch.sigmoid(logits.float()).type_as(logits)
+
+         scores_for_routing = scores + self.expert_bias
+         _, topk_idx = self.group_limited_topk(scores_for_routing)
+
+         scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
+
+         topk_weight = (
+             scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+             if self.top_k > 1
+             else scores
+         )
+         topk_weight = topk_weight * self.routed_scaling_factor
+
+         return topk_idx, topk_weight, logits
+
+
+ class LLaDA2MoeSparseMoeBlock(nn.Module):
+     """
+     A mixed expert module containing shared experts.
+     """
+
+     def __init__(self, config: LLaDA2MoeConfig):
+         super().__init__()
+         self.config = config
+         self.num_experts_per_tok = config.num_experts_per_tok
+         self._setup_experts()
+         self.gate = LLaDA2MoeGate(config)
+         if config.num_shared_experts is not None:
+             self.shared_experts = LLaDA2MoeMLP(
+                 config=config,
+                 intermediate_size=config.moe_intermediate_size
+                 * config.num_shared_experts,
+             )
+
+     def _setup_experts(self):
+         self.experts = nn.ModuleList(
+             [
+                 LLaDA2MoeMLP(
+                     config=self.config,
+                     intermediate_size=self.config.moe_intermediate_size,
+                 )
+                 for _ in range(self.config.num_experts)
+             ]
+         )
+
+     def forward(self, hidden_states):
+         identity = hidden_states
+         bsz, seq_len, h = hidden_states.shape
+         topk_idx, topk_weight, router_logits = self.gate(hidden_states)
+         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+         flat_topk_idx = topk_idx.view(-1)
+         if self.training:
+             hidden_states = hidden_states.repeat_interleave(
+                 self.num_experts_per_tok, dim=0
+             )
+             y = torch.empty_like(hidden_states)
+             for i, expert in enumerate(self.experts):
+                 y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+             y = y.to(hidden_states.dtype).view(bsz, seq_len, h)
+         else:
+             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(
+                 bsz, seq_len, h
+             )
+         if self.config.num_shared_experts is not None:
+             y = y + self.shared_experts(identity)
+         return y, (
+             router_logits.view(bsz, seq_len, -1),
+             topk_idx.view(bsz, seq_len, -1),
+         )
+
+     @torch.no_grad()
+     def moe_infer(self, x, topk_ids, topk_weight):
+         cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
+         cnts.scatter_(1, topk_ids, 1)
+         tokens_per_expert = cnts.sum(dim=0)
+         idxs = topk_ids.view(-1).argsort()
+         sorted_tokens = x[idxs // topk_ids.shape[1]]
+         tokens_per_expert = tokens_per_expert.cpu().numpy()
+         outputs = []
+         start_idx = 0
+         for i, num_tokens_tensor in enumerate(tokens_per_expert):
+             num_tokens = num_tokens_tensor.item()
+             if num_tokens == 0:
+                 continue
+             end_idx = start_idx + num_tokens
+             expert = self.experts[i]
+             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+             expert_out = expert(tokens_for_this_expert)
+             outputs.append(expert_out.to(x.device))
+             start_idx = end_idx
+
+         outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+         new_x = torch.empty_like(outs)
+         new_x[idxs] = outs
+         final_out = (
+             new_x.view(*topk_ids.shape, -1)
+             .type(topk_weight.dtype)
+             .mul_(topk_weight.unsqueeze(dim=-1))
+             .sum(dim=1)
+             .type(new_x.dtype)
+         )
+         return final_out
+
+
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
+
+     # upcast attention to fp32
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+         query.dtype
+     )
+     attn_weights = nn.functional.dropout(
+         attn_weights, p=dropout, training=module.training
+     )
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->LLaDA2Moe
+ class LLaDA2MoeAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: LLaDA2MoeConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         if layer_idx is None:
+             logger.warning_once(
+                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
+                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                 "when creating this class."
+             )
+         self.attention_dropout = config.attention_dropout
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim or self.hidden_size // self.num_heads
+         partial_rotary_factor = (
+             config.partial_rotary_factor
+             if hasattr(config, "partial_rotary_factor")
+             else 1.0
+         )
+         self.rope_dim = int(self.head_dim * partial_rotary_factor)
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+         self.scaling = self.head_dim**-0.5
+         self.is_causal = False
+
+         self.query_key_value = nn.Linear(
+             self.hidden_size,
+             (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
+             bias=config.use_qkv_bias,
+         )
+
+         if self.config.use_qk_norm:
+             self.query_layernorm = LLaDA2MoeRMSNorm(
+                 self.head_dim, eps=config.rms_norm_eps
+             )
+             self.key_layernorm = LLaDA2MoeRMSNorm(
+                 self.head_dim, eps=config.rms_norm_eps
+             )
+         self.dense = nn.Linear(
+             self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias
+         )
+         self.sliding_window = getattr(config, "sliding_window", None)
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+         return (
+             tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+             .transpose(1, 2)
+             .contiguous()
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         position_embeddings: Optional[
+             Tuple[torch.Tensor, torch.Tensor]
+         ] = None,  # necessary, but kept here for BC
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         input_shape = hidden_states.shape[:-1]
+
+         bsz, q_len, _ = hidden_states.size()
+
+         qkv = self.query_key_value(hidden_states)
+         qkv = qkv.view(
+             bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim
+         )
+
+         query_states, key_states, value_states = qkv.split(
+             [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
+         )
+         query_states = query_states.transpose(1, 2)
+         key_states = key_states.transpose(1, 2)
+         value_states = value_states.transpose(1, 2)
+
+         if self.config.use_qk_norm:
+             query_states = self.query_layernorm(query_states)
+             key_states = self.key_layernorm(key_states)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin
+         )
+
+         if past_key_value is not None:
+             if self.layer_idx is None:
+                 raise ValueError(
+                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                     "with a layer index."
+                 )
+             cache_kwargs = {"sin": sin, "cos": cos}
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, cache_kwargs
+             )
+
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation
+             ]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             sliding_window=self.sliding_window,  # diff with Llama
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.dense(attn_output)
+
+         return attn_output, attn_weights, past_key_value
+
+
+ class LLaDA2MoeDecoderLayer(nn.Module):
+     def __init__(self, config: LLaDA2MoeConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.attention = LLaDA2MoeAttention(config=config, layer_idx=layer_idx)
+
+         self.mlp = (
+             LLaDA2MoeSparseMoeBlock(config)
+             if (
+                 config.num_experts is not None
+                 and layer_idx >= config.first_k_dense_replace
+             )
+             else LLaDA2MoeMLP(config=config, intermediate_size=config.intermediate_size)
+         )
+         self.input_layernorm = LLaDA2MoeRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = LLaDA2MoeRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         output_router_logits: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         position_embeddings: Optional[
+             Tuple[torch.Tensor, torch.Tensor]
+         ] = None,  # necessary, but kept here for BC
+         **kwargs,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*):
+                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                 query_sequence_length, key_sequence_length)` if default attention is used.
+             position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+                 config.n_positions - 1]`.
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
+                 cached past key and value projection states
+             output_attentions (`bool`, *optional*):
+                 Whether to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_router_logits (`bool`, *optional*):
+                 Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                 and should not be returned during inference.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.attention(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             position_embeddings=position_embeddings,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         if isinstance(hidden_states, tuple):
+             hidden_states, router_logits = hidden_states
+         else:
+             router_logits = None
+         hidden_states = residual + hidden_states.to(residual.device)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         if output_router_logits:
+             outputs += (router_logits,)
+
+         return outputs
+
+
+ LLADA2MOE_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+     etc.)
644
+
645
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
646
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
647
+ and behavior.
648
+
649
+ Parameters:
650
+ config ([`LLaDA2MoeConfig`]):
651
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
652
+ load the weights associated with the model, only the configuration. Check out the
653
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
654
+ """
655
+
656
+
657
+ @add_start_docstrings(
658
+ "The bare LLaDA2Moe Model outputting raw hidden-states without any specific head on top.",
659
+ LLADA2MOE_START_DOCSTRING,
660
+ )
661
+ class LLaDA2MoePreTrainedModel(PreTrainedModel):
662
+ config_class = LLaDA2MoeConfig
663
+ base_model_prefix = "model"
664
+ supports_gradient_checkpointing = True
665
+ _no_split_modules = ["LLaDA2MoeDecoderLayer"]
666
+ _skip_keys_device_placement = ["past_key_values"]
667
+ _supports_flash_attn_2 = False
668
+ _supports_sdpa = True
669
+ _supports_flex_attn = True
670
+ _supports_cache_class = True
671
+
672
+ def _init_weights(self, module):
673
+ std = self.config.initializer_range
674
+ if isinstance(module, nn.Linear):
675
+ module.weight.data.normal_(mean=0.0, std=std)
676
+ if module.bias is not None:
677
+ module.bias.data.zero_()
678
+ elif isinstance(module, nn.Embedding):
679
+ module.weight.data.normal_(mean=0.0, std=std)
680
+ if module.padding_idx is not None:
681
+ module.weight.data[module.padding_idx].zero_()
682
+
683
+
684
+LLADA2MOE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance;
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
+              cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaDA2Moe Model outputting raw hidden-states without any specific head on top.",
+    LLADA2MOE_START_DOCSTRING,
+)
+class LLaDA2MoeModel(LLaDA2MoePreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LLaDA2MoeDecoderLayer`].
+
+    Args:
+        config: LLaDA2MoeConfig
+    """
+
+    def __init__(self, config: LLaDA2MoeConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.word_embeddings = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                LLaDA2MoeDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self._use_sdpa = config._attn_implementation == "sdpa"
+        self._use_flex_attention = config._attn_implementation == "flex_attention"
+        self.norm = LLaDA2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LLaDA2MoeRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.word_embeddings = value
+
+    @add_start_docstrings_to_model_forward(LLADA2MOE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape[:2]
+        elif inputs_embeds is not None:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
+        )
+
+        if position_ids is None:
+            position_ids = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+            position_ids = position_ids.unsqueeze(0)
+        if attention_mask is not None and attention_mask.size() == (
+            batch_size,
+            1,
+            seq_length,
+            seq_length,
+        ):
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_seen_tokens,
+            )
+        else:
+            raise ValueError(
+                f"LLaDA2.0 only supports a block attention mask with shape "
+                f"{(batch_size, 1, seq_length, seq_length)}, but got attention_mask="
+                f"{attention_mask.size() if attention_mask is not None else None}!"
+            )
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    output_router_logits,
+                    use_cache,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    output_router_logits=output_router_logits,
+                    use_cache=use_cache,
+                    position_embeddings=position_embeddings,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if output_router_logits and layer_outputs[-1] is not None:
+                all_router_logits += (layer_outputs[-1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_router_logits,
+                ]
+                if v is not None
+            )
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            router_logits=all_router_logits,
+        )
+
+
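+# Note: `LLaDA2MoeModel.forward` requires an explicit 4-D block attention mask of
+# shape (batch_size, 1, seq_length, seq_length); it does not build one from a 2-D
+# padding mask. A minimal sketch of the expected construction, mirroring
+# `LLaDA2MoeModelLM.generate` below (`num_blocks` and `block_length` are
+# illustrative values, not fixed by the model):
+#
+#     block_mask = torch.tril(torch.ones(num_blocks, num_blocks))
+#     attention_mask = (
+#         block_mask.repeat_interleave(block_length, dim=0)
+#         .repeat_interleave(block_length, dim=1)
+#         .unsqueeze(0)
+#         .unsqueeze(0)
+#     )  # 1 = may attend; tokens see their own block and all earlier blocks
+
+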
+class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: LLaDA2MoeConfig):
+        super().__init__(config)
+        self.model = LLaDA2MoeModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.word_embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(LLADA2MOE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> model = LLaDA2MoeModelLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        loss = None
+        aux_loss = None
+        hidden_states = outputs[0]
+
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        if labels is not None:
+            # Unlike autoregressive LMs, LLaDA2.0 predicts the token at the *same*
+            # position as its label, so logits and labels are not shifted.
+            shift_logits = logits
+            shift_labels = labels
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        token_type_ids=None,
+        **kwargs,
+    ):
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = (
+                    past_key_values.get_max_length()
+                    if hasattr(past_key_values, "get_max_length")
+                    else past_key_values.get_max_cache_shape()
+                )
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+            if (
+                attention_mask is not None
+                and attention_mask.shape[1] > input_ids.shape[1]
+            ):
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+    @staticmethod
+    def _top_k_logits(logits, k):
+        if k is None or k <= 0:
+            return logits
+        else:
+            values, _ = torch.topk(logits, k)
+            min_values = values[..., -1, None]
+            return torch.where(
+                logits < min_values, torch.full_like(logits, float("-inf")), logits
+            )
+
+    @staticmethod
+    def _top_p_logits(logits, p):
+        if p is None or p >= 1.0:
+            return logits
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_mask = cumulative_probs > p
+        # Shift the mask right so the first token that crosses the threshold is kept.
+        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+        sorted_mask[..., 0] = False
+        mask_indices = torch.scatter(
+            torch.full_like(logits, False, dtype=torch.bool),
+            -1,
+            sorted_indices,
+            sorted_mask,
+        )
+        return logits.masked_fill(mask_indices, float("-inf"))
+
+    def _sample_with_temperature_topk_topp(
+        self, logits, temperature=1.0, top_k=0, top_p=1.0
+    ):
+        orig_shape = logits.shape[:-1]
+        vocab_size = logits.shape[-1]
+        logits = logits.reshape(-1, vocab_size)
+        if temperature == 0.0:
+            # Greedy decoding: return the argmax token and its probability.
+            token = torch.argmax(logits, dim=-1, keepdim=True)
+            probs = F.softmax(logits, dim=-1)
+            token_prob = torch.gather(probs, -1, token)
+            return token.view(*orig_shape), token_prob.view(*orig_shape)
+
+        if temperature > 0 and temperature != 1.0:
+            logits = logits / temperature
+        logits = self._top_k_logits(logits, top_k)
+        logits = self._top_p_logits(logits, top_p)
+        probs = F.softmax(logits, dim=-1)
+        token = torch.multinomial(probs, num_samples=1)
+        token_prob = torch.gather(probs, -1, token)
+        return token.view(*orig_shape), token_prob.view(*orig_shape)
+
+    @staticmethod
+    def _get_num_transfer_tokens(block_length, steps):
+        if steps == 0:
+            return torch.tensor([], dtype=torch.int64)
+        base = block_length // steps
+        remainder = block_length % steps
+        num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
+        num_transfer_tokens[:remainder] += 1
+        return num_transfer_tokens
+
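+    # For reference: `_get_num_transfer_tokens` spreads a block's token budget as
+    # evenly as possible across the refinement steps, front-loading the remainder,
+    # e.g. `_get_num_transfer_tokens(block_length=32, steps=5)` returns
+    # `tensor([7, 7, 6, 6, 6])`.
+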
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        temperature: float = 0.0,
+        block_length: int = 32,
+        steps: int = 32,
+        gen_length: int = 2048,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        eos_early_stop: bool = False,
+        minimal_topk: int = 1,
+        threshold: float = 0.95,
+        editing_threshold: float = 0.9,
+        max_post_steps: int = 16,
+        eos_id: int = 156892,
+        mask_id: int = 156895,
+        num_to_transfer: int = 1,
+    ):
+        r"""
+        Generates tokens using a block-wise, iterative refinement strategy.
+
+        This method operates differently from standard autoregressive generation. It first creates a template of the
+        full desired length, filled with a special `mask_id`. It then processes this template in segments (`blocks`)
+        and iteratively "denoises" or "refines" the `mask_id` tokens into actual tokens over a series of `steps` for
+        each block. A custom block-diagonal causal attention mask ensures that generation within a block can attend to
+        all previous blocks but not future ones.
+
+        <Tip warning={true}>
+
+        This is a specialized generation method. The quality and speed of the output are highly dependent on the
+        interplay between `block_length`, `steps`, and `threshold`. It aims to achieve faster generation through
+        parallel decoding within blocks, which is a departure from the token-by-token generation of standard
+        `.generate()` methods.
+
+        </Tip>
+
+        Parameters:
+            inputs (`torch.Tensor`):
+                The token sequence used as a prompt for the generation.
+            temperature (`float`, *optional*, defaults to 0.0):
+                The value used to modulate the next-token probabilities. A value of 0.0 corresponds to greedy decoding.
+            block_length (`int`, *optional*, defaults to 32):
+                The size of each generation block. The model generates text in parallel within these blocks. This is a
+                key parameter for controlling the granularity of the generation process.
+            steps (`int`, *optional*, defaults to 32):
+                The number of iterative refinement (or "denoising") steps to perform for each block. Within each block,
+                the model will try to replace `mask_id` tokens with real tokens for this many iterations.
+            gen_length (`int`, *optional*, defaults to 2048):
+                The maximum number of tokens to generate, excluding the prompt.
+            top_p (`float`, *optional*):
+                If set to a float value between 0 and 1, only the most probable tokens with probabilities that add up to
+                `top_p` or higher are kept for generation (nucleus sampling).
+            top_k (`int`, *optional*):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            eos_early_stop (`bool`, *optional*, defaults to `False`):
+                If `True`, generation will stop as soon as a valid End-Of-Sequence token is generated and confirmed,
+                even if `gen_length` has not been reached.
+            minimal_topk (`int`, *optional*, defaults to 1):
+                A parameter used to dynamically adjust the number of refinement `steps`. The effective number of steps
+                is capped at `gen_length // minimal_topk`.
+            threshold (`float`, *optional*, defaults to 0.95):
+                The confidence probability threshold for accepting a sampled token. During each refinement step, a
+                sampled token is only kept if its probability is above this threshold. If not enough tokens meet the
+                threshold, the ones with the highest confidence are chosen.
+            editing_threshold (`float`, *optional*, defaults to 0.9):
+                The confidence threshold for **editing**. Existing tokens (non-masked) are replaced by newly
+                sampled tokens if the model's confidence in the new token exceeds this threshold and the token has
+                changed.
+            max_post_steps (`int`, *optional*, defaults to 16):
+                Number of global refinement iterations after all mask tokens are resolved.
+            eos_id (`int`, *optional*, defaults to 156892):
+                The token ID for the end-of-sequence token. Used for `eos_early_stop`.
+            mask_id (`int`, *optional*, defaults to 156895):
+                The token ID used as a placeholder for tokens that are yet to be generated. This is central to the
+                iterative refinement algorithm.
+            num_to_transfer (`int`, *optional*, defaults to 1):
+                The minimum number of masked tokens committed per refinement step: when fewer than `num_to_transfer`
+                tokens clear `threshold`, the most confident masked tokens are committed instead.
+
+        Return:
+            `torch.Tensor`: A tensor containing the generated token IDs, starting
+            after the prompt and stopping at the first `eos_id` or `gen_length`.
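+
+        Example (an illustrative sketch; the checkpoint path is a placeholder and
+        `trust_remote_code=True` is assumed for loading this custom class):
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CHECKPOINT, trust_remote_code=True)
+        >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CHECKPOINT, trust_remote_code=True)
+
+        >>> messages = [{"role": "user", "content": "Why does the sun rise in the east?"}]
+        >>> input_ids = tokenizer.apply_chat_template(
+        ...     messages, add_generation_prompt=True, return_tensors="pt"
+        ... )
+        >>> out = model.generate(inputs=input_ids, gen_length=512, block_length=32, eos_early_stop=True)
+        >>> print(tokenizer.decode(out[0], skip_special_tokens=True))
+        ```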
+        """
+
+        steps = min(steps, gen_length // minimal_topk)
+        input_ids = inputs.to(self.device)
+
+        prompt_length = input_ids.shape[1]
+        num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
+        total_length = num_blocks * block_length
+
+        # Block-diagonal causal mask: every token attends to its own block and to
+        # all earlier blocks, but never to future blocks.
+        block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=self.device))
+        block_diffusion_attention_mask = (
+            block_mask.repeat_interleave(block_length, dim=0)
+            .repeat_interleave(block_length, dim=1)
+            .unsqueeze(0)
+            .unsqueeze(0)
+        ).to(torch.bfloat16)
+
+        position_ids = torch.arange(total_length, device=self.device).unsqueeze(0)
+        x = torch.full((1, total_length), mask_id, dtype=torch.long, device=self.device)
+        x[:, :prompt_length] = input_ids.clone()
+
+        prompt_index_full = torch.zeros_like(x, dtype=torch.bool)
+        prompt_index_full[:, :prompt_length] = True
+
+        prefill_blocks = prompt_length // block_length
+
+        for num_block in range(prefill_blocks, num_blocks):
+            current_window_end = (num_block + 1) * block_length
+            cur_x = x[:, :current_window_end]
+            cur_attn_mask = block_diffusion_attention_mask[
+                :, :, :current_window_end, :current_window_end
+            ]
+            cur_position_ids = position_ids[:, :current_window_end]
+
+            block_start_pos = num_block * block_length
+
+            # Refine the current block until every mask token is resolved, then run
+            # up to `max_post_steps` extra editing passes over the resolved tokens.
+            post_steps = 0
+            while True:
+                old_block_tokens = cur_x[:, -block_length:].clone()
+                active_block_mask = cur_x[:, -block_length:] == mask_id
+                if not torch.any(active_block_mask):
+                    post_steps += 1
+                    if post_steps > max_post_steps:
+                        break
+                prompt_mask_in_block = torch.zeros(
+                    block_length, dtype=torch.bool, device=self.device
+                )
+                if block_start_pos < prompt_length:
+                    prompt_end_in_block = min(
+                        prompt_length - block_start_pos, block_length
+                    )
+                    prompt_mask_in_block[:prompt_end_in_block] = True
+
+                outputs = self.forward(
+                    cur_x,
+                    attention_mask=cur_attn_mask,
+                    position_ids=cur_position_ids,
+                    output_attentions=True,
+                )
+                logits = outputs.logits
+
+                active_logits = logits[:, -block_length:, :]
+                x0, x0_p = self._sample_with_temperature_topk_topp(
+                    active_logits, temperature=temperature, top_k=top_k, top_p=top_p
+                )
+                # Commit masked positions whose confidence clears `threshold`; if
+                # fewer than `num_to_transfer` do, commit the most confident ones.
+                mask_transfer_index = torch.zeros_like(x0, dtype=torch.bool)
+                if active_block_mask.sum() > 0:
+                    mask_confidence = torch.where(active_block_mask, x0_p, -torch.inf)
+                    high_conf_mask = (
+                        mask_confidence[0] > threshold
+                    ) & active_block_mask[0]
+                    num_high_confidence = high_conf_mask.sum().item()
+
+                    if num_high_confidence >= num_to_transfer:
+                        mask_transfer_index[0] = high_conf_mask
+                    else:
+                        num_available = active_block_mask.sum().item()
+                        if num_available > 0:
+                            _, idx = torch.topk(
+                                mask_confidence[0],
+                                k=min(num_to_transfer, num_available),
+                            )
+                            mask_transfer_index[0, idx] = True
+
+                # Editing: overwrite already-committed, non-prompt tokens when the
+                # model confidently prefers a different token.
+                editing_transfer_index = torch.zeros_like(x0, dtype=torch.bool)
+                non_mask_positions = ~active_block_mask
+                non_prompt_positions = ~prompt_mask_in_block
+                editable_positions = non_mask_positions & non_prompt_positions[None, :]
+                editing_confidence = torch.where(editable_positions, x0_p, -torch.inf)
+                high_conf_editing = (
+                    editing_confidence[0] > editing_threshold
+                ) & editable_positions[0]
+
+                token_changed = x0[0] != old_block_tokens[0]
+                editing_transfer_index[0] = high_conf_editing & token_changed
+                final_transfer_index = mask_transfer_index | editing_transfer_index
+
+                if final_transfer_index.any():
+                    cur_x[:, -block_length:][final_transfer_index] = x0[
+                        final_transfer_index
+                    ]
+
+                if active_block_mask.sum() == 0 and not editing_transfer_index.any():
+                    break
+
+            x[:, :current_window_end] = cur_x
+            if eos_early_stop:
+                generated_part = x[0, prompt_length:current_window_end]
+                if (generated_part == mask_id).sum() == 0:
+                    eos_positions = (generated_part == eos_id).nonzero(as_tuple=True)[0]
+                    if len(eos_positions) > 0:
+                        break
+
+        generated_answer = x[:, : prompt_length + gen_length]
+        # Truncate the output at the first EOS token in the generated region.
+        eos_positions = (generated_answer[0][input_ids.shape[1] :] == eos_id).nonzero(
+            as_tuple=True
+        )[0]
+        if len(eos_positions) > 0:
+            first_eos_position = eos_positions[0].item()
+        else:
+            first_eos_position = gen_length
+
+        return generated_answer[
+            :, input_ids.shape[1] : input_ids.shape[1] + first_eos_position + 1
+        ]
special_tokens_map.json ADDED
@@ -0,0 +1,8 @@
+{
+  "bos_token": "<|startoftext|>",
+  "cls_token": "[CLS]",
+  "eos_token": "<|endoftext|>",
+  "gmask_token": "[gMASK]",
+  "pad_token": "<|endoftext|>",
+  "mask_token": "<|mask|>"
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
+{
+  "add_bos_token": false,
+  "add_eos_token": false,
+  "bos_token": "<|startoftext|>",
+  "chat_template": "{% set thinking_option = 'off' %}\n{{- '<role>SYSTEM</role>' }}\n{%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n' }}\n{%- endif %}\n{%- if tools %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call>\\n\" }}\n{%- endif %}\n{{- 'detailed thinking ' + thinking_option + '<|role_end|>' }}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if message.role == \"user\" %}\n {{- '<role>HUMAN</role>' + message.content + '<|role_end|>' }}\n {%- elif message.role == \"system\" and not loop.first %}\n {{- '<role>SYSTEM</role>' + message.content + '<|role_end|>' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if reasoning_content %}\n {{- '<role>ASSISTANT</role>' + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<role>ASSISTANT</role>' + content }}\n {%- endif %}\n {%- else %}\n {{- '<role>ASSISTANT</role>' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|role_end|>' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<role>OBSERVATION</role>' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|role_end|>' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<role>ASSISTANT</role>' }}\n{%- endif %}",
+  "clean_up_tokenization_spaces": false,
+  "cls_token": "[CLS]",
+  "eos_token": "<|endoftext|>",
+  "mask_token": "<|mask|>",
+  "fast_tokenizer": true,
+  "gmask_token": "[gMASK]",
+  "merges_file": null,
+  "model_max_length": 32768,
+  "pad_token": "<|endoftext|>",
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "trust_remote_code": true,
+  "vocab_file": null
+}
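
For reference, tracing the chat template above by hand: a single-turn conversation with `add_generation_prompt=True` and no system message or tools renders roughly as follows (an illustrative sketch of the template logic, not output captured from the tokenizer):

```
<role>SYSTEM</role>detailed thinking off<|role_end|><role>HUMAN</role>Hello!<|role_end|><role>ASSISTANT</role>
```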