CharyZeng committed on
Commit d9556ce · verified · 1 Parent(s): bf30e76

Add 4-layer reduced variant: keep layers 0,1,2,42 renumbered to 0..3 + MTP
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
DeepSeek_V4.pdf ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4cbe4fcbd2888b25b2890a98cc6ef4ce0489df7c93e140b6f853c451d3f5c52
size 4479907
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,240 @@
---
license: mit
library_name: transformers
---
# DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence

<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->
<!-- markdownlint-disable no-duplicate-header -->

<div align="center">
<img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek-V4" />
</div>
<hr>
<div align="center" style="line-height: 1;">
<a href="https://www.deepseek.com/" target="_blank" style="margin: 2px;">
<img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://chat.deepseek.com/" target="_blank" style="margin: 2px;">
<img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-DeepSeek%20V4-536af5?color=536af5&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>
<div align="center" style="line-height: 1;">
<a href="https://huggingface.co/deepseek-ai" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://twitter.com/deepseek_ai" target="_blank" style="margin: 2px;">
<img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>
<div align="center" style="line-height: 1;">
<a href="LICENSE" style="margin: 2px;">
<img alt="License" src="https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
</a>
</div>

<p align="center">
<a href="https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/DeepSeek_V4.pdf"><b>Technical Report</b>👁️</a>
</p>

## Introduction

We present a preview version of the **DeepSeek-V4** series, including two strong Mixture-of-Experts (MoE) language models — **DeepSeek-V4-Pro** with 1.6T parameters (49B activated) and **DeepSeek-V4-Flash** with 284B parameters (13B activated) — both supporting a context length of **one million tokens**.

The DeepSeek-V4 series incorporates several key upgrades in architecture and optimization:

1. **Hybrid Attention Architecture:** We design a hybrid attention mechanism combining Compressed Sparse Attention (CSA) and Heavily Compressed Attention (HCA) to dramatically improve long-context efficiency. In the 1M-token context setting, DeepSeek-V4-Pro requires only **27% of single-token inference FLOPs** and **10% of KV cache** compared with DeepSeek-V3.2.
2. **Manifold-Constrained Hyper-Connections (mHC):** We incorporate mHC to strengthen conventional residual connections, enhancing stability of signal propagation across layers while preserving model expressivity.
3. **Muon Optimizer:** We employ the Muon optimizer for faster convergence and greater training stability.

We pre-train both models on more than **32T** diverse and high-quality tokens, followed by a comprehensive post-training pipeline. The post-training features a two-stage paradigm: independent cultivation of domain-specific experts (through SFT and RL with GRPO), followed by unified model consolidation via on-policy distillation, integrating distinct proficiencies across diverse domains into a single model.

**DeepSeek-V4-Pro-Max**, the maximum reasoning effort mode of DeepSeek-V4-Pro, significantly advances the knowledge capabilities of open-source models, firmly establishing itself as the best open-source model available today. It achieves top-tier performance on coding benchmarks and significantly narrows the gap with leading closed-source models on reasoning and agentic tasks. Meanwhile, **DeepSeek-V4-Flash-Max** achieves reasoning performance comparable to the Pro version when given a larger thinking budget, though its smaller parameter scale naturally places it slightly behind on pure knowledge tasks and the most complex agentic workflows.

<div align="center">
<img src="assets/dsv4_performance.png" />
</div>

## Model Downloads

<div align="center">

| **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Precision** | **Download** |
| :---: | :---: | :---: | :---: | :---: | :---: |
| DeepSeek-V4-Flash-Base | 284B | 13B | 1M | FP8 Mixed | [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash-Base) \| [ModelScope](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash-Base) |
| DeepSeek-V4-Flash | 284B | 13B | 1M | FP4 + FP8 Mixed* | [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) \| [ModelScope](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Flash) |
| DeepSeek-V4-Pro-Base | 1.6T | 49B | 1M | FP8 Mixed | [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro-Base) \| [ModelScope](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro-Base) |
| DeepSeek-V4-Pro | 1.6T | 49B | 1M | FP4 + FP8 Mixed* | [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro) \| [ModelScope](https://modelscope.cn/models/deepseek-ai/DeepSeek-V4-Pro) |

</div>

*\*FP4 + FP8 Mixed: MoE expert parameters use FP4 precision; most other parameters use FP8.*

## Evaluation Results

### Base Model

<div align="center">

| Benchmark (Metric) | # Shots | DeepSeek-V3.2-Base | DeepSeek-V4-Flash-Base | DeepSeek-V4-Pro-Base |
| :--- | :---: | :---: | :---: | :---: |
| Architecture | - | MoE | MoE | MoE |
| # Activated Params | - | 37B | 13B | 49B |
| # Total Params | - | 671B | 284B | 1.6T |
| **World Knowledge** | | | | |
| AGIEval (EM) | 0-shot | 80.1 | 82.6 | **83.1** |
| MMLU (EM) | 5-shot | 87.8 | 88.7 | **90.1** |
| MMLU-Redux (EM) | 5-shot | 87.5 | 89.4 | **90.8** |
| MMLU-Pro (EM) | 5-shot | 65.5 | 68.3 | **73.5** |
| MMMLU (EM) | 5-shot | 87.9 | 88.8 | **90.3** |
| C-Eval (EM) | 5-shot | 90.4 | 92.1 | **93.1** |
| CMMLU (EM) | 5-shot | 88.9 | 90.4 | **90.8** |
| MultiLoKo (EM) | 5-shot | 38.7 | 42.2 | **51.1** |
| Simple-QA verified (EM) | 25-shot | 28.3 | 30.1 | **55.2** |
| SuperGPQA (EM) | 5-shot | 45.0 | 46.5 | **53.9** |
| FACTS Parametric (EM) | 25-shot | 27.1 | 33.9 | **62.6** |
| TriviaQA (EM) | 5-shot | 83.3 | 82.8 | **85.6** |
| **Language & Reasoning** | | | | |
| BBH (EM) | 3-shot | **87.6** | 86.9 | 87.5 |
| DROP (F1) | 1-shot | 88.2 | 88.6 | **88.7** |
| HellaSwag (EM) | 0-shot | 86.4 | 85.7 | **88.0** |
| WinoGrande (EM) | 0-shot | 78.9 | 79.5 | **81.5** |
| CLUEWSC (EM) | 5-shot | 83.5 | 82.2 | **85.2** |
| **Code & Math** | | | | |
| BigCodeBench (Pass@1) | 3-shot | **63.9** | 56.8 | 59.2 |
| HumanEval (Pass@1) | 0-shot | 62.8 | 69.5 | **76.8** |
| GSM8K (EM) | 8-shot | 91.1 | 90.8 | **92.6** |
| MATH (EM) | 4-shot | 60.5 | 57.4 | **64.5** |
| MGSM (EM) | 8-shot | 81.3 | **85.7** | 84.4 |
| CMath (EM) | 3-shot | 92.6 | **93.6** | 90.9 |
| **Long Context** | | | | |
| LongBench-V2 (EM) | 1-shot | 40.2 | 44.7 | **51.5** |

</div>

### Instruct Model

DeepSeek-V4-Pro and DeepSeek-V4-Flash both support three reasoning effort modes:

| Reasoning Mode | Characteristics | Typical Use Cases | Response Format |
| :--- | :--- | :--- | :--- |
| Non-think | Fast, intuitive responses | Routine daily tasks, low-risk decisions | `</think>` summary |
| Think High | Conscious logical analysis, slower but more accurate | Complex problem-solving, planning | `<think>` thinking `</think>` summary |
| Think Max | Push reasoning to its fullest extent | Exploring the boundary of model reasoning capability | Special system prompt + `<think>` thinking `</think>` summary |
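
The three modes differ only in how the assistant turn is opened (per the Response Format column above and the encoding docs in this repo). A minimal sketch of that mapping; the dictionary and function names here are ours, not an official API:

```python
# Mapping from reasoning mode to the prefix of the assistant turn.
# Key names are illustrative; the token strings come from the encoding docs.
MODE_ASSISTANT_PREFIX = {
    "non_think": "</think>",   # thinking block is closed immediately
    "think_high": "<think>",   # model writes reasoning, then a summary
    "think_max": "<think>",    # same, preceded by a special system prompt
}

def assistant_prefix(mode: str) -> str:
    return MODE_ASSISTANT_PREFIX[mode]

assistant_prefix("non_think")
# => "</think>"
```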

#### DeepSeek-V4-Pro-Max vs Frontier Models

<div align="center">

| Benchmark (Metric) | Opus-4.6 Max | GPT-5.4 xHigh | Gemini-3.1-Pro High | K2.6 Thinking | GLM-5.1 Thinking | DS-V4-Pro Max |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
| **Knowledge & Reasoning** | | | | | | |
| MMLU-Pro (EM) | 89.1 | 87.5 | **91.0** | 87.1 | 86.0 | 87.5 |
| SimpleQA-Verified (Pass@1) | 46.2 | 45.3 | **75.6** | 36.9 | 38.1 | 57.9 |
| Chinese-SimpleQA (Pass@1) | 76.4 | 76.8 | **85.9** | 75.9 | 75.0 | 84.4 |
| GPQA Diamond (Pass@1) | 91.3 | 93.0 | **94.3** | 90.5 | 86.2 | 90.1 |
| HLE (Pass@1) | 40.0 | 39.8 | **44.4** | 36.4 | 34.7 | 37.7 |
| LiveCodeBench (Pass@1) | 88.8 | - | 91.7 | 89.6 | - | **93.5** |
| Codeforces (Rating) | - | 3168 | 3052 | - | - | **3206** |
| HMMT 2026 Feb (Pass@1) | 96.2 | **97.7** | 94.7 | 92.7 | 89.4 | 95.2 |
| IMOAnswerBench (Pass@1) | 75.3 | **91.4** | 81.0 | 86.0 | 83.8 | 89.8 |
| Apex (Pass@1) | 34.5 | 54.1 | **60.9** | 24.0 | 11.5 | 38.3 |
| Apex Shortlist (Pass@1) | 85.9 | 78.1 | 89.1 | 75.5 | 72.4 | **90.2** |
| **Long Context** | | | | | | |
| MRCR 1M (MMR) | **92.9** | - | 76.3 | - | - | 83.5 |
| CorpusQA 1M (ACC) | **71.7** | - | 53.8 | - | - | 62.0 |
| **Agentic** | | | | | | |
| Terminal Bench 2.0 (Acc) | 65.4 | **75.1** | 68.5 | 66.7 | 63.5 | 67.9 |
| SWE Verified (Resolved) | **80.8** | - | 80.6 | 80.2 | - | 80.6 |
| SWE Pro (Resolved) | 57.3 | 57.7 | 54.2 | **58.6** | 58.4 | 55.4 |
| SWE Multilingual (Resolved) | **77.5** | - | - | 76.7 | 73.3 | 76.2 |
| BrowseComp (Pass@1) | 83.7 | 82.7 | **85.9** | 83.2 | 79.3 | 83.4 |
| HLE w/ tools (Pass@1) | 53.1 | 52.0 | 51.6 | **54.0** | 50.4 | 48.2 |
| GDPval-AA (Elo) | 1619 | **1674** | 1314 | 1482 | 1535 | 1554 |
| MCPAtlas Public (Pass@1) | **73.8** | 67.2 | 69.2 | 66.6 | 71.8 | 73.6 |
| Toolathlon (Pass@1) | 47.2 | **54.6** | 48.8 | 50.0 | 40.7 | 51.8 |

</div>

#### Comparison across Modes

<div align="center">

| Benchmark (Metric) | V4-Flash Non-Think | V4-Flash High | V4-Flash Max | V4-Pro Non-Think | V4-Pro High | V4-Pro Max |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
| **Knowledge & Reasoning** | | | | | | |
| MMLU-Pro (EM) | 83.0 | 86.4 | 86.2 | 82.9 | 87.1 | **87.5** |
| SimpleQA-Verified (Pass@1) | 23.1 | 28.9 | 34.1 | 45.0 | 46.2 | **57.9** |
| Chinese-SimpleQA (Pass@1) | 71.5 | 73.2 | 78.9 | 75.8 | 77.7 | **84.4** |
| GPQA Diamond (Pass@1) | 71.2 | 87.4 | 88.1 | 72.9 | 89.1 | **90.1** |
| HLE (Pass@1) | 8.1 | 29.4 | 34.8 | 7.7 | 34.5 | **37.7** |
| LiveCodeBench (Pass@1) | 55.2 | 88.4 | 91.6 | 56.8 | 89.8 | **93.5** |
| Codeforces (Rating) | - | 2816 | 3052 | - | 2919 | **3206** |
| HMMT 2026 Feb (Pass@1) | 40.8 | 91.9 | 94.8 | 31.7 | 94.0 | **95.2** |
| IMOAnswerBench (Pass@1) | 41.9 | 85.1 | 88.4 | 35.3 | 88.0 | **89.8** |
| Apex (Pass@1) | 1.0 | 19.1 | 33.0 | 0.4 | 27.4 | **38.3** |
| Apex Shortlist (Pass@1) | 9.3 | 72.1 | 85.7 | 9.2 | 85.5 | **90.2** |
| **Long Context** | | | | | | |
| MRCR 1M (MMR) | 37.5 | 76.9 | 78.7 | 44.7 | 83.3 | **83.5** |
| CorpusQA 1M (ACC) | 15.5 | 59.3 | 60.5 | 35.6 | 56.5 | **62.0** |
| **Agentic** | | | | | | |
| Terminal Bench 2.0 (Acc) | 49.1 | 56.6 | 56.9 | 59.1 | 63.3 | **67.9** |
| SWE Verified (Resolved) | 73.7 | 78.6 | 79.0 | 73.6 | 79.4 | **80.6** |
| SWE Pro (Resolved) | 49.1 | 52.3 | 52.6 | 52.1 | 54.4 | **55.4** |
| SWE Multilingual (Resolved) | 69.7 | 70.2 | 73.3 | 69.8 | 74.1 | **76.2** |
| BrowseComp (Pass@1) | - | 53.5 | 73.2 | - | 80.4 | **83.4** |
| HLE w/ tools (Pass@1) | - | 40.3 | 45.1 | - | 44.7 | **48.2** |
| MCPAtlas (Pass@1) | 64.0 | 67.4 | 69.0 | 69.4 | **74.2** | 73.6 |
| GDPval-AA (Elo) | - | - | 1395 | - | - | **1554** |
| Toolathlon (Pass@1) | 40.7 | 43.5 | 47.8 | 46.3 | 49.0 | **51.8** |

</div>

## Chat Template

This release does not include a Jinja-format chat template. Instead, we provide a dedicated `encoding` folder with Python scripts and test cases demonstrating how to encode messages in OpenAI-compatible format into input strings for the model, and how to parse the model's text output. Please refer to the [`encoding`](encoding/README.md) folder for full documentation.

A brief example:

```python
from encoding_dsv4 import encode_messages, parse_message_from_completion_text

messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "Hello! I am DeepSeek.", "reasoning_content": "thinking..."},
    {"role": "user", "content": "1+1=?"}
]

# messages -> string
prompt = encode_messages(messages, thinking_mode="thinking")

# string -> tokens
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V4-Pro")
tokens = tokenizer.encode(prompt)
```

## How to Run Locally

Please refer to the [inference](inference/README.md) folder for detailed instructions on running DeepSeek-V4 locally, including model weight conversion and interactive chat demos.

For local deployment, we recommend setting the sampling parameters to `temperature = 1.0, top_p = 1.0`. For the Think Max reasoning mode, we recommend setting the context window to at least **384K** tokens.
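
The recommended sampling settings can be expressed as a request payload for an OpenAI-compatible endpoint. A minimal sketch; the model identifier below is an assumption for illustration, not an official name:

```python
# Hedged sketch: an OpenAI-compatible chat request body using the recommended
# sampling parameters (temperature = 1.0, top_p = 1.0).
payload = {
    "model": "deepseek-v4-pro",  # hypothetical identifier
    "messages": [{"role": "user", "content": "hello"}],
    "temperature": 1.0,
    "top_p": 1.0,
}
```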

## License

This repository and the model weights are licensed under the [MIT License](LICENSE).

## Citation

```
@misc{deepseekai2026deepseekv4,
      title={DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence},
      author={DeepSeek-AI},
      year={2026},
}
```

## Contact

If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).
assets/dsv4_performance.png ADDED

Git LFS Details

  • SHA256: 8fd472981a4c8d40c1845c51c5e8961fc4ef3ac22e7ec77801f534c239c1b30f
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
config.json ADDED
@@ -0,0 +1,111 @@
{
  "architectures": [
    "DeepseekV4ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hc_eps": 1e-06,
  "hc_mult": 4,
  "hc_sinkhorn_iters": 20,
  "head_dim": 512,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "index_head_dim": 128,
  "index_n_heads": 64,
  "index_topk": 512,
  "initializer_range": 0.02,
  "max_position_embeddings": 1048576,
  "model_type": "deepseek_v4",
  "moe_intermediate_size": 2048,
  "n_routed_experts": 256,
  "n_shared_experts": 1,
  "norm_topk_prob": true,
  "num_attention_heads": 64,
  "num_experts_per_tok": 6,
  "num_hidden_layers": 4,
  "num_hash_layers": 3,
  "num_key_value_heads": 1,
  "num_nextn_predict_layers": 1,
  "o_groups": 8,
  "o_lora_rank": 1024,
  "q_lora_rank": 1024,
  "qk_rope_head_dim": 64,
  "quantization_config": {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "scale_fmt": "ue8m0",
    "weight_block_size": [
      128,
      128
    ]
  },
  "rms_norm_eps": 1e-06,
  "rope_scaling": {
    "beta_fast": 32,
    "beta_slow": 1,
    "factor": 16,
    "original_max_position_embeddings": 65536,
    "type": "yarn"
  },
  "rope_theta": 10000,
  "routed_scaling_factor": 1.5,
  "scoring_func": "sqrtsoftplus",
  "sliding_window": 128,
  "swiglu_limit": 10.0,
  "tie_word_embeddings": false,
  "topk_method": "noaux_tc",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.57.1",
  "use_cache": true,
  "vocab_size": 129280,
  "compress_rope_theta": 160000,
  "compress_ratios": [
    0,
    0,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    128,
    4,
    0
  ]
}
encoding/README.md ADDED
@@ -0,0 +1,156 @@
# DeepSeek-V4 Encoding

This document describes the prompt encoding format used by DeepSeek-V4 series models. The encoding handles multi-turn conversations, tool calling, extended thinking (reasoning), and quick instruction tasks.

A self-contained reference implementation is provided in `encoding_dsv4.py`.

## Quick Start

```python
from encoding_dsv4 import encode_messages, parse_message_from_completion_text

# Encode a conversation
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2+2?"},
]
prompt = encode_messages(messages, thinking_mode="thinking")
# => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|><think>"

# Parse model output back to structured message
completion = "Simple arithmetic.</think>2 + 2 = 4.<|end▁of▁sentence|>"
parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
# => {"role": "assistant", "reasoning_content": "Simple arithmetic.", "content": "2 + 2 = 4.", "tool_calls": []}
```

> **Note:** The `parse_message_from_completion_text` function is designed to handle well-formatted model output only. It does not attempt to correct or recover from malformed output that the model might occasionally generate. For production use, additional error handling is recommended.

## Message Format

### Special Tokens

| Token | Purpose |
|-------|---------|
| `<\|begin▁of▁sentence\|>` | Beginning of sequence (BOS) |
| `<\|end▁of▁sentence\|>` | End of assistant turn (EOS) |
| `<\|User\|>` | User turn prefix |
| `<\|Assistant\|>` | Assistant turn prefix |
| `<\|latest_reminder\|>` | Latest reminder (date, locale, etc.) |
| `<think>` / `</think>` | Reasoning block delimiters |
| `\|DSML\|` | DSML markup token |

### Roles

The encoding supports the following message roles: `system`, `user`, `assistant`, `tool`, `latest_reminder`, and `developer`.

> **Note on the `developer` role:** The `developer` role is used exclusively in the internal search agent pipeline. It is not needed for general-purpose chat or tool-calling tasks, and the official API does not accept messages with this role.

### Basic Chat

A simple multi-turn conversation is encoded as:

```
<|begin▁of▁sentence|>{system_prompt}
<|User|>{user_message}<|Assistant|></think>{response}<|end▁of▁sentence|>
<|User|>{user_message_2}<|Assistant|></think>{response_2}<|end▁of▁sentence|>
```

- The BOS token is prepended at the very beginning of the conversation.
- In **chat mode** (`thinking_mode="chat"`), `</think>` is placed right after `<|Assistant|>` to immediately close the thinking block, so the model generates content directly.
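
The chat-mode template above can be sketched in plain Python. `encode_chat` below is our own illustrative helper, not part of the reference implementation; following the Quick Start example, it places no newlines between segments:

```python
# Minimal sketch of chat-mode prompt assembly. Token strings are copied from
# the special-token table above; the helper name and signature are ours.
BOS = "<|begin▁of▁sentence|>"
EOS = "<|end▁of▁sentence|>"
USER = "<|User|>"
ASSISTANT = "<|Assistant|>"

def encode_chat(system_prompt, turns):
    """turns: list of (user_message, assistant_response) pairs; pass None as
    the last assistant_response to leave the prompt open for generation."""
    out = BOS + system_prompt
    for user, assistant in turns:
        out += USER + user + ASSISTANT + "</think>"
        if assistant is not None:
            out += assistant + EOS
    return out

prompt = encode_chat("You are a helpful assistant.", [("What is 2+2?", None)])
# => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|></think>"
```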
60
+
61
+ ### Interleaved Thinking Mode
62
+
63
+ In **thinking mode** (`thinking_mode="thinking"`), the model produces explicit reasoning inside `<think>...</think>` blocks before responding.
64
+
65
+ ```
66
+ <|begin▁of▁sentence|>{system_prompt}
67
+ <|User|>{message}<|Assistant|><think>{reasoning}</think>{response}<|end▁of▁sentence|>
68
+ ```
69
+
70
+ The `drop_thinking` parameter (default `True`) controls whether reasoning from earlier turns is preserved:
71
+
72
+ - **Without tools**: `drop_thinking` takes effect. Reasoning content from assistant turns **before** the last user message is stripped. Only the final assistant turn retains its `<think>...</think>` block.
73
+ - **With tools** (on system or developer message): `drop_thinking` is automatically disabled. All turns retain their reasoning, because tool-calling conversations require full context for the model to track multi-step reasoning across tool calls.
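
The `drop_thinking` behavior without tools can be sketched as follows; the function name is ours, and this is a simplification of the reference implementation's logic, not a copy of it:

```python
# Sketch: strip reasoning_content from assistant turns that precede the last
# user message, leaving later turns untouched (the no-tools drop_thinking case).
def drop_earlier_thinking(messages):
    last_user = max(i for i, m in enumerate(messages) if m["role"] == "user")
    out = []
    for i, m in enumerate(messages):
        m = dict(m)  # shallow copy so the caller's messages are not mutated
        if m["role"] == "assistant" and i < last_user:
            m.pop("reasoning_content", None)
        out.append(m)
    return out

msgs = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello", "reasoning_content": "greet"},
    {"role": "user", "content": "1+1=?"},
]
cleaned = drop_earlier_thinking(msgs)
# cleaned[1] keeps its content but loses reasoning_content
```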

### Tool Calling (DSML Format)

Tools are defined on the `system` or `developer` message via the `tools` field (OpenAI-compatible format). When tools are present, the following schema block is injected into the system/user prompt:

```
## Tools

You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:

<|DSML|tool_calls>
<|DSML|invoke name="$TOOL_NAME">
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
...
</|DSML|invoke>
<|DSML|invoke name="$TOOL_NAME2">
...
</|DSML|invoke>
</|DSML|tool_calls>

String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.

If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.

Otherwise, output directly after </think> with tool calls or final response.

### Available Tool Schemas

{tool_definitions_json}

You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
```

An actual tool call in the assistant turn looks like:

```xml
<|DSML|tool_calls>
<|DSML|invoke name="function_name">
<|DSML|parameter name="param" string="true">string_value</|DSML|parameter>
<|DSML|parameter name="count" string="false">5</|DSML|parameter>
</|DSML|invoke>
</|DSML|tool_calls><|end▁of▁sentence|>
```

- `string="true"`: the parameter value is a raw string.
- `string="false"`: the parameter value is JSON (number, boolean, array, object).
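
The `string="true|false"` rule can be expressed as a small helper; `render_parameter` is our own illustrative name, not part of the reference implementation:

```python
import json

# Sketch of the parameter-serialization rule above: raw strings get
# string="true"; everything else is JSON-encoded with string="false".
def render_parameter(name, value):
    if isinstance(value, str):
        return f'<|DSML|parameter name="{name}" string="true">{value}</|DSML|parameter>'
    return f'<|DSML|parameter name="{name}" string="false">{json.dumps(value)}</|DSML|parameter>'

render_parameter("param", "string_value")
# => '<|DSML|parameter name="param" string="true">string_value</|DSML|parameter>'
render_parameter("count", 5)
# => '<|DSML|parameter name="count" string="false">5</|DSML|parameter>'
```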

Tool execution results are wrapped in `<tool_result>` tags within user messages:

```
<|User|><tool_result>{result_json}</tool_result><|Assistant|><think>...
```

When multiple tool results are present, they are sorted by the order of the corresponding `tool_calls` in the preceding assistant message.
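
The ordering rule can be sketched as follows, assuming OpenAI-compatible field names (`tool_calls[*].id`, `tool_call_id`); the helper name is ours:

```python
# Sketch: reorder tool-result messages to match the order of the tool_calls
# in the preceding assistant message.
def order_tool_results(assistant_msg, tool_msgs):
    order = {tc["id"]: i for i, tc in enumerate(assistant_msg["tool_calls"])}
    return sorted(tool_msgs, key=lambda m: order[m["tool_call_id"]])

assistant = {"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]}
results = [
    {"role": "tool", "tool_call_id": "b", "content": "2"},
    {"role": "tool", "tool_call_id": "a", "content": "1"},
]
ordered = order_tool_results(assistant, results)
# => results reordered so tool_call_id "a" comes before "b"
```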

### Reasoning Effort

When `reasoning_effort="max"` is set, a special prefix is prepended at the very beginning of the prompt (before the system message) to instruct the model to maximize its reasoning depth:

```
Reasoning Effort: Absolute maximum with no shortcuts permitted.
You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.
Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.
```

### Quick Instruction Special Tokens

Quick instruction tokens are used for auxiliary classification and generation tasks. They are appended to messages via the `"task"` field to trigger specialized model behavior for a single-token or short-form output.

| Special Token | Description | Format |
|:---|:---|:---|
| `<\|action\|>` | Determines whether the user prompt requires a web search or can be answered directly. | `...<\|User\|>{prompt}<\|Assistant\|><think><\|action\|>` |
| `<\|title\|>` | Generates a concise conversation title after the first assistant response. | `...<\|Assistant\|>{response}<\|end▁of▁sentence\|><\|title\|>` |
| `<\|query\|>` | Generates search queries for the user prompt. | `...<\|User\|>{prompt}<\|query\|>` |
| `<\|authority\|>` | Classifies the user prompt's demand for source authoritativeness. | `...<\|User\|>{prompt}<\|authority\|>` |
| `<\|domain\|>` | Identifies the domain of the user prompt. | `...<\|User\|>{prompt}<\|domain\|>` |
| `<\|extracted_url\|>` `<\|read_url\|>` | Determines whether each URL in the user prompt should be fetched and read. | `...<\|User\|>{prompt}<\|extracted_url\|>{url}<\|read_url\|>` |

Usage in message format:

- **`action`** on a user message: the `<|action|>` token is placed after the assistant prefix and thinking token, triggering a routing decision (e.g., "Search" or "Answer").
- **Other tasks** (`query`, `authority`, `domain`, `read_url`) on a user message: the task token is appended directly after the user content.
- **`title`** on an assistant message: the `<|title|>` token is appended after the assistant's EOS. The next assistant message provides the generated title.
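
Per the table above, the `query` task simply appends its token after the user content. A minimal sketch, with a helper name of our own choosing:

```python
# Sketch of a quick-instruction prompt for the `query` task: the task token
# is appended directly after the user content.
USER = "<|User|>"
QUERY = "<|query|>"

def encode_query_task(user_prompt):
    return USER + user_prompt + QUERY

encode_query_task("latest MoE papers")
# => "<|User|>latest MoE papers<|query|>"
```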
encoding/encoding_dsv4.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek-V4 Encoding
3
+
4
+ A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
5
+ with tool calling, thinking mode, and quick instruction task support.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Union, Optional, Tuple
9
+ import copy
10
+ import json
11
+ import re
12
+
13
+ # ============================================================
14
+ # Special Tokens
15
+ # ============================================================
16
+
17
+ bos_token: str = "<|begin▁of▁sentence|>"
18
+ eos_token: str = "<|end▁of▁sentence|>"
19
+ thinking_start_token: str = "<think>"
20
+ thinking_end_token: str = "</think>"
21
+ dsml_token: str = "|DSML|"
22
+
23
+ USER_SP_TOKEN = "<|User|>"
24
+ ASSISTANT_SP_TOKEN = "<|Assistant|>"
25
+ LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
26
+
27
+ # Task special tokens for internal classification tasks
28
+ DS_TASK_SP_TOKENS = {
29
+ "action": "<|action|>",
30
+ "query": "<|query|>",
31
+ "authority": "<|authority|>",
32
+ "domain": "<|domain|>",
33
+ "title": "<|title|>",
34
+ "read_url": "<|read_url|>",
35
+ }
36
+ VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
37
+
38
+ # ============================================================
39
+ # Templates
40
+ # ============================================================
41
+
42
+ system_msg_template: str = "{content}"
43
+ user_msg_template: str = "{content}"
44
+ latest_reminder_msg_template: str = "{content}"
45
+ assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
46
+ assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
47
+ thinking_template: str = "{reasoning_content}"
48
+
49
+ response_format_template: str = (
50
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
51
+ )
52
+ tool_call_template: str = (
53
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
54
+ )
55
+ tool_calls_template = (
56
+ "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
57
+ )
58
+ tool_calls_block_name: str = "tool_calls"
59
+
60
+ tool_output_template: str = (
61
+ "<tool_result>{content}</tool_result>"
62
+ )
63
+
64
+ REASONING_EFFORT_MAX = (
65
+ "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
66
+ "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
67
+ "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
68
+ )
69
+
70
+ TOOLS_TEMPLATE = """## Tools
71
+
72
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
73
+
74
+ <{dsml_token}tool_calls>
75
+ <{dsml_token}invoke name="$TOOL_NAME">
76
+ <{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
77
+ ...
78
+ </{dsml_token}invoke>
79
+ <{dsml_token}invoke name="$TOOL_NAME2">
80
+ ...
81
+ </{dsml_token}invoke>
82
+ </{dsml_token}tool_calls>
83
+
84
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
85
+
86
+ If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
87
+
88
+ Otherwise, output directly after {thinking_end_token} with tool calls or final response.
89
+
90
+ ### Available Tool Schemas
91
+
92
+ {tool_schemas}
93
+
94
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
95
+ """
96
+
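As a quick illustration of the templates above, the following self-contained sketch (token and template literals copied from the definitions above; the `get_weather`/`Beijing` values are made up) renders a complete tool-call block:

```python
# Copies of the module-level token and templates defined above.
dsml_token = "|DSML|"
tool_call_template = "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
tool_calls_template = "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"

# Render one invoke with a single pre-encoded string parameter (example values).
invoke = tool_call_template.format(
    dsml_token=dsml_token,
    name="get_weather",
    arguments='<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>',
)
block = tool_calls_template.format(
    dsml_token=dsml_token, tc_block_name="tool_calls", tool_calls=invoke
)
print(block)
```

The result is the exact `<|DSML|tool_calls>` block shape shown in TOOLS_TEMPLATE.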
97
+ # ============================================================
98
+ # Utility Functions
99
+ # ============================================================
100
+
101
+ def to_json(value: Any) -> str:
102
+ """Serialize a value to JSON string."""
103
+ try:
104
+ return json.dumps(value, ensure_ascii=False)
105
+ except Exception:
106
+ return json.dumps(value, ensure_ascii=True)
107
+
108
+
109
+ def tools_from_openai_format(tools):
110
+ """Extract function definitions from OpenAI-format tool list."""
111
+ return [tool["function"] for tool in tools]
112
+
113
+
114
+ def tool_calls_from_openai_format(tool_calls):
115
+ """Convert OpenAI-format tool calls to internal format."""
116
+ return [
117
+ {
118
+ "name": tool_call["function"]["name"],
119
+ "arguments": tool_call["function"]["arguments"],
120
+ }
121
+ for tool_call in tool_calls
122
+ ]
123
+
124
+
125
+ def tool_calls_to_openai_format(tool_calls):
126
+ """Convert internal tool calls to OpenAI format."""
127
+ return [
128
+ {
129
+ "type": "function",
130
+ "function": {
131
+ "name": tool_call["name"],
132
+ "arguments": tool_call["arguments"],
133
+ }
134
+ }
135
+ for tool_call in tool_calls
136
+ ]
137
+
138
+
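A round-trip sketch of the two converters above (the bodies are copied from `tool_calls_from_openai_format` / `tool_calls_to_openai_format`; the `search` call is an invented example):

```python
# Mirror of tool_calls_from_openai_format: keep only name + arguments.
def from_openai(tool_calls):
    return [{"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}
            for tc in tool_calls]

# Mirror of tool_calls_to_openai_format: wrap back into the OpenAI shape.
def to_openai(tool_calls):
    return [{"type": "function",
             "function": {"name": tc["name"], "arguments": tc["arguments"]}}
            for tc in tool_calls]

openai_calls = [{"type": "function",
                 "function": {"name": "search", "arguments": '{"queries": "foo"}'}}]
internal = from_openai(openai_calls)
assert internal == [{"name": "search", "arguments": '{"queries": "foo"}'}]
assert to_openai(internal) == openai_calls  # lossless round trip (ids excluded)
```

Note the internal format deliberately drops the OpenAI `id` field; ordering by id is handled separately in `sort_tool_results_by_call_order`.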
139
+ def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
140
+ """
141
+ Encode tool call arguments into DSML parameter format.
142
+
143
+ Args:
144
+ tool_call: Dict with "name" and "arguments" (JSON string) keys.
145
+
146
+ Returns:
147
+ DSML-formatted parameter string.
148
+ """
149
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
150
+ P_dsml_strs = []
151
+
152
+ try:
153
+ arguments = json.loads(tool_call["arguments"])
154
+ except Exception:
155
+ arguments = {"arguments": tool_call["arguments"]}
156
+
157
+ for k, v in arguments.items():
158
+ p_dsml_str = p_dsml_template.format(
159
+ dsml_token=dsml_token,
160
+ key=k,
161
+ is_str="true" if isinstance(v, str) else "false",
162
+ value=v if isinstance(v, str) else to_json(v),
163
+ )
164
+ P_dsml_strs.append(p_dsml_str)
165
+
166
+ return "\n".join(P_dsml_strs)
167
+
168
+
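A self-contained sketch of the encoding performed by `encode_arguments_to_dsml` (the `encode_args`, `DSML`, and `PARAM` names are local to this example; the logic mirrors the function above):

```python
import json

DSML = "|DSML|"
PARAM = '<{t}parameter name="{k}" string="{s}">{v}</{t}parameter>'

def encode_args(arguments_json: str) -> str:
    # Strings pass through verbatim with string="true"; everything else is
    # serialized as JSON with string="false", matching the spec in TOOLS_TEMPLATE.
    args = json.loads(arguments_json)
    lines = []
    for k, v in args.items():
        is_str = isinstance(v, str)
        lines.append(PARAM.format(
            t=DSML, k=k,
            s="true" if is_str else "false",
            v=v if is_str else json.dumps(v, ensure_ascii=False),
        ))
    return "\n".join(lines)

encoded = encode_args('{"location": "Beijing", "num_results": 3}')
print(encoded)
```

The string parameter is emitted raw, while the integer is emitted as JSON text with `string="false"`.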
169
+ def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
170
+ """
171
+ Decode DSML parameters back to a tool call dict.
172
+
173
+ Args:
174
+ tool_name: Name of the tool.
175
+ tool_args: Dict mapping param_name -> (value, is_string_flag).
176
+
177
+ Returns:
178
+ Dict with "name" and "arguments" (JSON string) keys.
179
+ """
180
+ def _decode_value(key: str, value: str, string: str):
181
+ if string == "true":
182
+ value = to_json(value)
183
+ return f"{to_json(key)}: {value}"
184
+
185
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
186
+ return dict(name=tool_name, arguments=tool_args_json)
187
+
188
+
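The inverse direction can be sketched the same way (the `decode` helper is local to this example; its logic mirrors `decode_dsml_to_arguments` above, rebuilding a JSON arguments string from `(value, string_flag)` pairs):

```python
import json

def decode(tool_name, tool_args):
    def one(key, value, string):
        # string="true" values must be re-quoted; string="false" values are
        # already JSON text and are spliced in verbatim.
        if string == "true":
            value = json.dumps(value, ensure_ascii=False)
        return f"{json.dumps(key)}: {value}"
    body = ", ".join(one(k, v, s) for k, (v, s) in tool_args.items())
    return {"name": tool_name, "arguments": "{" + body + "}"}

call = decode("get_weather", {"location": ("Beijing", "true"), "days": ("3", "false")})
assert json.loads(call["arguments"]) == {"location": "Beijing", "days": 3}
```

Together with the encoder above this makes the DSML parameter format round-trippable for well-formed JSON values.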
189
+ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
190
+ """
191
+ Render tool schemas into the system prompt format.
192
+
193
+ Args:
194
+ tools: List of tool schema dicts (each with name, description, parameters).
195
+
196
+ Returns:
197
+ Formatted tools section string.
198
+ """
199
+ tools_json = [to_json(t) for t in tools]
200
+
201
+ return TOOLS_TEMPLATE.format(
202
+ tool_schemas="\n".join(tools_json),
203
+ dsml_token=dsml_token,
204
+ thinking_start_token=thinking_start_token,
205
+ thinking_end_token=thinking_end_token,
206
+ )
207
+
208
+
209
+ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
210
+ """Find the index of the last user/developer message."""
211
+ last_user_index = -1
212
+ for idx in range(len(messages) - 1, -1, -1):
213
+ if messages[idx].get("role") in ["user", "developer"]:
214
+ last_user_index = idx
215
+ break
216
+ return last_user_index
217
+
218
+
219
+ # ============================================================
220
+ # Message Rendering
221
+ # ============================================================
222
+
223
+ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
224
+ """
225
+ Render a single message at the given index into its encoded string form.
226
+
227
+ This is the core function that converts each message in the conversation
228
+ into the DeepSeek-V4 format.
229
+
230
+ Args:
231
+ index: Index of the message to render.
232
+ messages: Full list of messages in the conversation.
233
+ thinking_mode: Either "chat" or "thinking".
234
+ drop_thinking: Whether to drop reasoning content from earlier turns.
235
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
236
+
237
+ Returns:
238
+ Encoded string for this message.
239
+ """
240
+ assert 0 <= index < len(messages)
241
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
242
+
243
+ prompt = ""
244
+ msg = messages[index]
245
+ last_user_idx = find_last_user_index(messages)
246
+
247
+ role = msg.get("role")
248
+ content = msg.get("content")
249
+ tools = msg.get("tools")
250
+ response_format = msg.get("response_format")
251
+ tool_calls = msg.get("tool_calls")
252
+ reasoning_content = msg.get("reasoning_content")
253
+ wo_eos = msg.get("wo_eos", False)
254
+
255
+ if tools:
256
+ tools = tools_from_openai_format(tools)
257
+ if tool_calls:
258
+ tool_calls = tool_calls_from_openai_format(tool_calls)
259
+
260
+ # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
261
+ assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
262
+ if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
263
+ prompt += REASONING_EFFORT_MAX
264
+
265
+ if role == "system":
266
+ prompt += system_msg_template.format(content=content or "")
267
+ if tools:
268
+ prompt += "\n\n" + render_tools(tools)
269
+ if response_format:
270
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
271
+
272
+ elif role == "developer":
273
+ assert content, f"Invalid message for role `{role}`: {msg}"
274
+
275
+ content_developer = USER_SP_TOKEN
276
+ content_developer += content
277
+
278
+ if tools:
279
+ content_developer += "\n\n" + render_tools(tools)
280
+ if response_format:
281
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
282
+
283
+ prompt += user_msg_template.format(content=content_developer)
284
+
285
+ elif role == "user":
286
+ prompt += USER_SP_TOKEN
287
+
288
+ # Handle content blocks (tool results mixed with text)
289
+ content_blocks = msg.get("content_blocks")
290
+ if content_blocks:
291
+ parts = []
292
+ for block in content_blocks:
293
+ block_type = block.get("type")
294
+ if block_type == "text":
295
+ parts.append(block.get("text", ""))
296
+ elif block_type == "tool_result":
297
+ tool_content = block.get("content", "")
298
+ if isinstance(tool_content, list):
299
+ text_parts = []
300
+ for b in tool_content:
301
+ if b.get("type") == "text":
302
+ text_parts.append(b.get("text", ""))
303
+ else:
304
+ text_parts.append(f"[Unsupported {b.get('type')}]")
305
+ tool_content = "\n\n".join(text_parts)
306
+ parts.append(tool_output_template.format(content=tool_content))
307
+ else:
308
+ parts.append(f"[Unsupported {block_type}]")
309
+ prompt += "\n\n".join(parts)
310
+ else:
311
+ prompt += content or ""
312
+
313
+ elif role == "latest_reminder":
314
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
315
+
316
+ elif role == "tool":
317
+ raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
318
+
319
+ elif role == "assistant":
320
+ thinking_part = ""
321
+ tc_content = ""
322
+
323
+ if tool_calls:
324
+ tc_list = [
325
+ tool_call_template.format(
326
+ dsml_token=dsml_token,
327
+ name=tc.get("name"),
328
+ arguments=encode_arguments_to_dsml(tc)
329
+ )
330
+ for tc in tool_calls
331
+ ]
332
+ tc_content += '\n\n' + tool_calls_template.format(
333
+ dsml_token=dsml_token,
334
+ tool_calls="\n".join(tc_list),
335
+ tc_block_name=tool_calls_block_name,
336
+ )
337
+
338
+ summary_content = content or ""
339
+ rc = reasoning_content or ""
340
+
341
+ # Check if previous message has a task - if so, this is a task output (no thinking)
342
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
343
+
344
+ if thinking_mode == "thinking" and not prev_has_task:
345
+ if not drop_thinking or index > last_user_idx:
346
+ thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
347
+ else:
348
+ thinking_part = ""
349
+
350
+ if wo_eos:
351
+ prompt += assistant_msg_wo_eos_template.format(
352
+ reasoning=thinking_part,
353
+ content=summary_content,
354
+ tool_calls=tc_content,
355
+ )
356
+ else:
357
+ prompt += assistant_msg_template.format(
358
+ reasoning=thinking_part,
359
+ content=summary_content,
360
+ tool_calls=tc_content,
361
+ )
362
+ else:
363
+ raise NotImplementedError(f"Unknown role: {role}")
364
+
365
+ # Append transition tokens based on what follows
366
+ if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
367
+ return prompt
368
+
369
+ task = messages[index].get("task")
370
+ if task is not None:
371
+ # Task special token for internal classification tasks
372
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
373
+ task_sp_token = DS_TASK_SP_TOKENS[task]
374
+
375
+ if task != "action":
376
+ # Non-action tasks: append task sp token directly after the message
377
+ prompt += task_sp_token
378
+ else:
379
+ # Action task: append Assistant + thinking token + action sp token
380
+ prompt += ASSISTANT_SP_TOKEN
381
+ prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
382
+ prompt += task_sp_token
383
+
384
+ elif messages[index].get("role") in ["user", "developer"]:
385
+ # Normal generation: append Assistant + thinking token
386
+ prompt += ASSISTANT_SP_TOKEN
387
+ if not drop_thinking and thinking_mode == "thinking":
388
+ prompt += thinking_start_token
389
+ elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
390
+ prompt += thinking_start_token
391
+ else:
392
+ prompt += thinking_end_token
393
+
394
+ return prompt
395
+
396
+
397
+ # ============================================================
398
+ # Preprocessing
399
+ # ============================================================
400
+
401
+ def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
402
+ """
403
+ Merge tool messages into the preceding user message using content_blocks format.
404
+
405
+ DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
406
+ are encoded as <tool_result> blocks within user messages.
407
+
408
+ This function converts a standard OpenAI-format conversation (with separate
409
+ "tool" role messages) into V4 format where tool results are merged into
410
+ user messages.
411
+
412
+ Args:
413
+ messages: List of message dicts in OpenAI format.
414
+
415
+ Returns:
416
+ Processed message list with tool messages merged into user messages.
417
+ """
418
+ merged: List[Dict[str, Any]] = []
419
+
420
+ for msg in messages:
421
+ msg = copy.deepcopy(msg)
422
+ role = msg.get("role")
423
+
424
+ if role == "tool":
425
+ # Convert tool message to a user message with tool_result block
426
+ tool_block = {
427
+ "type": "tool_result",
428
+ "tool_use_id": msg.get("tool_call_id", ""),
429
+ "content": msg.get("content", ""),
430
+ }
431
+ # Merge into previous message if it's already a user (merged tool)
432
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
433
+ merged[-1]["content_blocks"].append(tool_block)
434
+ else:
435
+ merged.append({
436
+ "role": "user",
437
+ "content_blocks": [tool_block],
438
+ })
439
+ elif role == "user":
440
+ text_block = {"type": "text", "text": msg.get("content", "")}
441
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
442
+ merged[-1]["content_blocks"].append(text_block)
443
+ else:
444
+ new_msg = {
445
+ "role": "user",
446
+ "content": msg.get("content", ""),
447
+ "content_blocks": [text_block],
448
+ }
449
+ # Preserve extra fields (task, wo_eos, mask, etc.)
450
+ for key in ("task", "wo_eos", "mask"):
451
+ if key in msg:
452
+ new_msg[key] = msg[key]
453
+ merged.append(new_msg)
454
+ else:
455
+ merged.append(msg)
456
+
457
+ return merged
458
+
459
+
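A minimal sketch of the merge rule above (message contents are invented; the branching mirrors the `role == "tool"` path of `merge_tool_messages`):

```python
# An assistant turn that called a tool, followed by the OpenAI-style tool reply.
msgs = [
    {"role": "assistant", "tool_calls": [{"id": "call_001", "type": "function",
        "function": {"name": "get_weather", "arguments": "{}"}}]},
    {"role": "tool", "tool_call_id": "call_001", "content": "22C, sunny"},
]
merged = []
for m in msgs:
    if m["role"] == "tool":
        # The tool message becomes a tool_result block on a user message.
        block = {"type": "tool_result", "tool_use_id": m["tool_call_id"],
                 "content": m["content"]}
        if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
            merged[-1]["content_blocks"].append(block)
        else:
            merged.append({"role": "user", "content_blocks": [block]})
    else:
        merged.append(m)
assert merged[1]["role"] == "user"
assert merged[1]["content_blocks"][0]["tool_use_id"] == "call_001"
```

The standalone `tool` role disappears; `render_message` then emits the block as `<tool_result>...</tool_result>` inside the user turn.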
460
+ def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
461
+ """
462
+ Sort tool_result blocks within user messages by the order of tool_calls
463
+ in the preceding assistant message.
464
+
465
+ Args:
466
+ messages: Preprocessed message list (after merge_tool_messages).
467
+
468
+ Returns:
469
+ Message list with sorted tool result blocks.
470
+ """
471
+ last_tool_call_order: Dict[str, int] = {}
472
+
473
+ for msg in messages:
474
+ role = msg.get("role")
475
+ if role == "assistant" and msg.get("tool_calls"):
476
+ last_tool_call_order = {}
477
+ for idx, tc in enumerate(msg["tool_calls"]):
478
+ tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
479
+ if tc_id:
480
+ last_tool_call_order[tc_id] = idx
481
+
482
+ elif role == "user" and msg.get("content_blocks"):
483
+ tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
484
+ if len(tool_blocks) > 1 and last_tool_call_order:
485
+ sorted_blocks = sorted(
486
+ tool_blocks,
487
+ key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
488
+ )
489
+ sorted_idx = 0
490
+ new_blocks = []
491
+ for block in msg["content_blocks"]:
492
+ if block.get("type") == "tool_result":
493
+ new_blocks.append(sorted_blocks[sorted_idx])
494
+ sorted_idx += 1
495
+ else:
496
+ new_blocks.append(block)
497
+ msg["content_blocks"] = new_blocks
498
+
499
+ return messages
500
+
501
+
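The reordering idea reduces to a sort key lookup, sketched here in isolation (ids and blocks are invented; this mirrors the `sorted(...)` call above):

```python
# Index of each id within the preceding assistant's tool_calls list.
order = {"call_a": 0, "call_b": 1}

# Tool results arrived out of call order.
blocks = [{"type": "tool_result", "tool_use_id": "call_b"},
          {"type": "tool_result", "tool_use_id": "call_a"}]

# Unknown ids fall back to 0, matching the .get(..., 0) default above.
blocks = sorted(blocks, key=lambda b: order.get(b.get("tool_use_id", ""), 0))
assert [b["tool_use_id"] for b in blocks] == ["call_a", "call_b"]
```

Because `sorted` is stable, text blocks and unmatched results keep their relative positions.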
502
+ # ============================================================
503
+ # Main Encoding Function
504
+ # ============================================================
505
+
506
+ def encode_messages(
507
+ messages: List[Dict[str, Any]],
508
+ thinking_mode: str,
509
+ context: Optional[List[Dict[str, Any]]] = None,
510
+ drop_thinking: bool = True,
511
+ add_default_bos_token: bool = True,
512
+ reasoning_effort: Optional[str] = None,
513
+ ) -> str:
514
+ """
515
+ Encode a list of messages into the DeepSeek-V4 prompt format.
516
+
517
+ This is the main entry point for encoding conversations. It handles:
518
+ - BOS token insertion
519
+ - Thinking mode with optional reasoning content dropping
520
+ - Tool message merging into user messages
521
+ - Multi-turn conversation context
522
+
523
+ Args:
524
+ messages: List of message dicts to encode.
525
+ thinking_mode: Either "chat" or "thinking".
526
+ context: Optional preceding context messages (already encoded prefix).
527
+ drop_thinking: If True, drop reasoning_content from earlier assistant turns
528
+ (only keep reasoning for messages after the last user message).
529
+ add_default_bos_token: Whether to prepend BOS token at conversation start.
530
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
531
+
532
+ Returns:
533
+ The encoded prompt string.
534
+ """
535
+ context = context if context else []
536
+
537
+ # Preprocess: merge tool messages and sort tool results
538
+ messages = merge_tool_messages(messages)
539
+ messages = sort_tool_results_by_call_order(context + messages)[len(context):]
540
+ if context:
541
+ context = merge_tool_messages(context)
542
+ context = sort_tool_results_by_call_order(context)
543
+
544
+ full_messages = context + messages
545
+
546
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
547
+
548
+ # Resolve drop_thinking: if any message has tools defined, don't drop thinking
549
+ effective_drop_thinking = drop_thinking
550
+ if any(m.get("tools") for m in full_messages):
551
+ effective_drop_thinking = False
552
+
553
+ if thinking_mode == "thinking" and effective_drop_thinking:
554
+ full_messages = _drop_thinking_messages(full_messages)
555
+ # After dropping, recalculate how many messages to render
556
+ # (context may have shrunk too)
557
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
558
+ context_len = len(full_messages) - num_to_render
559
+ else:
560
+ num_to_render = len(messages)
561
+ context_len = len(context)
562
+
563
+ for idx in range(num_to_render):
564
+ prompt += render_message(
565
+ idx + context_len,
566
+ full_messages,
567
+ thinking_mode=thinking_mode,
568
+ drop_thinking=effective_drop_thinking,
569
+ reasoning_effort=reasoning_effort,
570
+ )
571
+
572
+ return prompt
573
+
574
+
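For orientation, a sketch of the prompt layout that the encoding above yields for a minimal single-turn conversation in thinking mode (the message contents here are invented; the token order follows the `render_message` transitions):

```python
# Illustrative layout only: BOS + system content + <|User|> + user content,
# then the generation prefix <|Assistant|><think> appended after the last user.
prompt = (
    "<|begin▁of▁sentence|>"
    "You are a helpful assistant."
    "<|User|>Hello"
    "<|Assistant|><think>"
)
assert prompt.startswith("<|begin▁of▁sentence|>")
assert prompt.endswith("<|Assistant|><think>")
```

In chat mode the trailing `<think>` would instead be `</think>`, per the transition logic in `render_message`.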
575
+ def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
576
+ """
577
+ Drop reasoning_content and non-essential messages before the last user message.
578
+
579
+ Behavior:
580
+ - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
581
+ - Messages at or after the last user index are always kept.
582
+ - Assistant messages before the last user get reasoning_content removed.
583
+ - Developer messages before the last user are dropped entirely.
584
+ """
585
+ last_user_idx = find_last_user_index(messages)
586
+ result = []
587
+ keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
588
+
589
+ for idx, msg in enumerate(messages):
590
+ role = msg.get("role")
591
+ if role in keep_roles or idx >= last_user_idx:
592
+ result.append(msg)
593
+ elif role == "assistant":
594
+ msg = copy.copy(msg)
595
+ msg.pop("reasoning_content", None)
596
+ result.append(msg)
597
+ # developer and other roles before last_user_idx are dropped
598
+
599
+ return result
600
+
601
+
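The drop rule can be sketched on a toy history (contents invented; the branching mirrors `_drop_thinking_messages` with `last_user` precomputed):

```python
msgs = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello", "reasoning_content": "greet back"},
    {"role": "user", "content": "bye"},
]
last_user = 2  # index of the final user message
kept = []
for i, m in enumerate(msgs):
    if m["role"] in {"user", "system", "tool", "latest_reminder"} or i >= last_user:
        kept.append(m)
    elif m["role"] == "assistant":
        # Earlier assistant turns keep their content but lose the thinking.
        m = dict(m)
        m.pop("reasoning_content", None)
        kept.append(m)
assert kept[1] == {"role": "assistant", "content": "hello"}
```

Only reasoning at or after the last user turn survives, which is what `drop_thinking=True` relies on in `encode_messages`.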
602
+ # ============================================================
603
+ # Parsing (Decoding model output)
604
+ # ============================================================
605
+
606
+ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
607
+ """
608
+ Read text from index until one of the stop strings is found.
609
+
610
+ Returns:
611
+ Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
612
+ """
613
+ min_pos = len(text)
614
+ matched_stop = None
615
+
616
+ for s in stop:
617
+ pos = text.find(s, index)
618
+ if pos != -1 and pos < min_pos:
619
+ min_pos = pos
620
+ matched_stop = s
621
+
622
+ if matched_stop:
623
+ content = text[index:min_pos]
624
+ return min_pos + len(matched_stop), content, matched_stop
625
+ else:
626
+ content = text[index:]
627
+ return len(text), content, None
628
+
629
+
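A self-contained copy of the scanning helper above, demonstrating the `(new_index, content, matched_stop)` contract on a tiny input:

```python
def read_until_stop(index, text, stop):
    # Scan for the earliest occurrence of any stop string at or after index.
    min_pos, matched = len(text), None
    for s in stop:
        pos = text.find(s, index)
        if pos != -1 and pos < min_pos:
            min_pos, matched = pos, s
    if matched:
        # new_index points just past the matched stop string.
        return min_pos + len(matched), text[index:min_pos], matched
    return len(text), text[index:], None

i, content, stop = read_until_stop(0, "abc</think>rest", ["</think>"])
assert (content, stop) == ("abc", "</think>")
assert "abc</think>rest"[i:] == "rest"
```

The stop token itself is consumed, so successive calls walk the text without re-matching the same delimiter.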
630
+ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
631
+ """
632
+ Parse DSML tool calls from text starting at the given index.
633
+
634
+ Args:
635
+ index: Starting position in text.
636
+ text: The full text to parse.
637
+
638
+ Returns:
639
+ Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
640
+ Each tool call dict has "name" and "arguments" keys.
641
+ """
642
+ tool_calls: List[Dict[str, Any]] = []
643
+ stop_token = None
644
+ tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
645
+
646
+ while index < len(text):
647
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
649
+ if content != ">\n":
650
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{content}'")
650
+
651
+ if stop_token == tool_calls_end_token:
652
+ break
653
+
654
+ if stop_token is None:
655
+ raise ValueError("Missing special token in tool calls")
656
+
657
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
658
+
659
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
660
+ if len(p_tool_name) != 1:
661
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
662
+ tool_name = p_tool_name[0]
663
+
664
+ tool_args: Dict[str, Tuple[str, str]] = {}
665
+ while stop_token == f"<{dsml_token}parameter":
666
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
667
+
668
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
669
+ if len(param_kv) != 1:
670
+ raise ValueError(f"Parameter format error: '{param_content}'")
671
+ param_name, string, param_value = param_kv[0]
672
+
673
+ if param_name in tool_args:
674
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
675
+ tool_args[param_name] = (param_value, string)
676
+
677
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
678
+ if content != ">\n":
679
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
680
+
681
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
682
+ tool_calls.append(tool_call)
683
+
684
+ return index, stop_token, tool_calls
685
+
686
+
687
+ def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
688
+ """
689
+ Parse a model completion text into a structured assistant message.
690
+
691
+ This function takes the raw text output from the model (a single assistant turn)
692
+ and extracts:
693
+ - reasoning_content (thinking block)
694
+ - content (summary/response)
695
+ - tool_calls (if any)
696
+
697
+ NOTE: This function only accepts correctly formatted completions; it raises
699
+ ValueError (or AssertionError) on malformed output.
699
+
700
+ Args:
701
+ text: The raw completion text (including EOS token).
702
+ thinking_mode: Either "chat" or "thinking".
703
+
704
+ Returns:
705
+ Dict with keys: "role", "content", "reasoning_content", "tool_calls".
706
+ tool_calls are in OpenAI format.
707
+ """
708
+ summary_content, reasoning_content, tool_calls = "", "", []
709
+ index, stop_token = 0, None
710
+ tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
711
+
712
+ is_thinking = thinking_mode == "thinking"
713
+ is_tool_calling = False
714
+
715
+ if is_thinking:
716
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
717
+ reasoning_content = content_delta
718
+ assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
719
+
720
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
721
+ summary_content = content_delta
722
+ if stop_token == tool_calls_start_token:
723
+ is_tool_calling = True
724
+ else:
725
+ assert stop_token == eos_token, "Invalid format: missing EOS token"
726
+
727
+ if is_tool_calling:
728
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
729
+
730
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
731
+ assert not tool_ends_text, "Unexpected content after tool calls"
732
+
733
+ assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
734
+
735
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
736
+ assert sp_token not in summary_content and sp_token not in reasoning_content, \
737
+ f"Unexpected special token '{sp_token}' in content"
738
+
739
+ return {
740
+ "role": "assistant",
741
+ "content": summary_content,
742
+ "reasoning_content": reasoning_content,
743
+ "tool_calls": tool_calls_to_openai_format(tool_calls)
744
+ }
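To make the expected input concrete, here is a hypothetical thinking-mode completion with one tool call, laid out per the templates above (the reasoning text and tool arguments are invented; only plain string checks are shown, not the parser itself):

```python
completion = (
    "The user wants the weather; call get_weather.</think>"
    "\n\n<|DSML|tool_calls>\n"
    '<|DSML|invoke name="get_weather">\n'
    '<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>\n'
    "</|DSML|invoke>\n"
    "</|DSML|tool_calls>"
    "<|end▁of▁sentence|>"
)
# Thinking block ends at </think>; the tool-calls block follows "\n\n" and the
# whole completion is terminated by the EOS token.
reasoning, rest = completion.split("</think>", 1)
assert reasoning == "The user wants the weather; call get_weather."
assert rest.startswith("\n\n<|DSML|tool_calls>")
assert rest.endswith("</|DSML|tool_calls><|end▁of▁sentence|>")
```

Feeding a string of this shape to `parse_message_from_completion_text(..., thinking_mode="thinking")` yields the reasoning, an empty summary, and one OpenAI-format tool call.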
encoding/test_encoding_dsv4.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Test suite for DeepSeek-V4 Encoding.
3
+
4
+ Run: python test_encoding_dsv4.py
5
+ """
6
+
7
+ import json
8
+ import os
9
+
10
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
11
+
12
+ TESTS_DIR = os.path.join(os.path.dirname(__file__), "tests")
13
+
14
+
15
+ def test_case_1():
16
+ """Thinking mode with tool calls (multi-turn, tool results merged into user)."""
17
+ with open(os.path.join(TESTS_DIR, "test_input_1.json")) as f:
18
+ td = json.load(f)
19
+ messages = td["messages"]
20
+ messages[0]["tools"] = td["tools"]
21
+ gold = open(os.path.join(TESTS_DIR, "test_output_1.txt")).read()
22
+ prompt = encode_messages(messages, thinking_mode="thinking")
23
+ assert prompt == gold
24
+
25
+ # Parse: assistant turn with tool call
26
+ marker = "<|Assistant|><think>"
27
+ first_start = prompt.find(marker) + len(marker)
28
+ first_end = prompt.find("<|User|>", first_start)
29
+ parsed_tc = parse_message_from_completion_text(prompt[first_start:first_end], thinking_mode="thinking")
30
+ assert parsed_tc["reasoning_content"] == "The user wants to know the weather in Beijing. I should use the get_weather tool."
31
+ assert parsed_tc["content"] == ""
32
+ assert len(parsed_tc["tool_calls"]) == 1
33
+ assert parsed_tc["tool_calls"][0]["function"]["name"] == "get_weather"
34
+ assert json.loads(parsed_tc["tool_calls"][0]["function"]["arguments"]) == {"location": "Beijing", "unit": "celsius"}
35
+
36
+ # Parse: final assistant turn with content
37
+ last_start = prompt.rfind(marker) + len(marker)
38
+ parsed_final = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
39
+ assert parsed_final["reasoning_content"] == "Got the weather data. Let me format a nice response."
40
+ assert "22°C" in parsed_final["content"]
41
+ assert parsed_final["tool_calls"] == []
42
+
43
+ print(" [PASS] case 1: thinking with tools (encode + parse)")
44
+
45
+
46
+ def test_case_2():
47
+ """Thinking mode without tools (drop_thinking removes earlier reasoning)."""
48
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_2.json")))
49
+ gold = open(os.path.join(TESTS_DIR, "test_output_2.txt")).read()
50
+ prompt = encode_messages(messages, thinking_mode="thinking")
51
+ assert prompt == gold
52
+
53
+ # Parse: last assistant turn
54
+ marker = "<|Assistant|><think>"
55
+ last_start = prompt.rfind(marker) + len(marker)
56
+ parsed = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
57
+ assert parsed["reasoning_content"] == "The user asks about the capital of France. It is Paris."
58
+ assert parsed["content"] == "The capital of France is Paris."
59
+ assert parsed["tool_calls"] == []
60
+
61
+ # Verify drop_thinking: first assistant's reasoning should be absent
62
+ assert "The user said hello" not in prompt
63
+
64
+ print(" [PASS] case 2: thinking without tools (encode + parse)")
65
+
66
+
67
+ def test_case_3():
68
+ """Interleaved thinking + search (developer with tools, latest_reminder)."""
69
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_3.json")))
70
+ gold = open(os.path.join(TESTS_DIR, "test_output_3.txt")).read()
71
+ assert encode_messages(messages, thinking_mode="thinking") == gold
72
+ print(" [PASS] case 3: interleaved thinking + search")
73
+
74
+
75
+ def test_case_4():
76
+ """Quick instruction task with latest_reminder (chat mode, action task)."""
77
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_4.json")))
78
+ gold = open(os.path.join(TESTS_DIR, "test_output_4.txt")).read()
79
+ assert encode_messages(messages, thinking_mode="chat") == gold
80
+ print(" [PASS] case 4: quick instruction task")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ print("Running DeepSeek-V4 Encoding Tests...\n")
85
+ test_case_1()
86
+ test_case_2()
87
+ test_case_3()
88
+ test_case_4()
89
+ print("\nAll 4 tests passed!")
encoding/tests/test_input_1.json ADDED
@@ -0,0 +1,81 @@
1
+ {
2
+ "tools": [
3
+ {
4
+ "type": "function",
5
+ "function": {
6
+ "name": "get_weather",
7
+ "description": "Get the weather for a specific location",
8
+ "parameters": {
9
+ "type": "object",
10
+ "properties": {
11
+ "location": {
12
+ "type": "string",
13
+ "description": "The city name"
14
+ },
15
+ "unit": {
16
+ "type": "string",
17
+ "enum": ["celsius", "fahrenheit"],
18
+ "description": "Temperature unit"
19
+ }
20
+ },
21
+ "required": ["location"]
22
+ }
23
+ }
24
+ },
25
+ {
26
+ "type": "function",
27
+ "function": {
28
+ "name": "search",
29
+ "description": "Search the web for information",
30
+ "parameters": {
31
+ "type": "object",
32
+ "properties": {
33
+ "query": {
34
+ "type": "string",
35
+ "description": "Search query"
36
+ },
37
+ "num_results": {
38
+ "type": "integer",
39
+ "description": "Number of results to return"
40
+ }
41
+ },
42
+ "required": ["query"]
43
+ }
44
+ }
45
+ }
46
+ ],
47
+ "messages": [
48
+ {
49
+ "role": "system",
50
+ "content": "You are a helpful assistant."
51
+ },
52
+ {
53
+ "role": "user",
54
+ "content": "What's the weather in Beijing?"
55
+ },
56
+ {
57
+ "role": "assistant",
58
+ "reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
59
+ "tool_calls": [
60
+ {
61
+ "id": "call_001",
62
+ "type": "function",
63
+ "function": {
64
+ "name": "get_weather",
65
+ "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
66
+ }
67
+ }
68
+ ]
69
+ },
70
+ {
71
+ "role": "tool",
72
+ "tool_call_id": "call_001",
73
+ "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
74
+ },
75
+ {
76
+ "role": "assistant",
77
+ "reasoning_content": "Got the weather data. Let me format a nice response.",
78
+ "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
79
+ }
80
+ ]
81
+ }
encoding/tests/test_input_2.json ADDED
@@ -0,0 +1,24 @@
+ [
+   {
+     "role": "system",
+     "content": "You are a helpful assistant."
+   },
+   {
+     "role": "user",
+     "content": "Hello"
+   },
+   {
+     "role": "assistant",
+     "reasoning_content": "The user said hello, I should greet back.",
+     "content": "Hi there! How can I help you?"
+   },
+   {
+     "role": "user",
+     "content": "What is the capital of France?"
+   },
+   {
+     "role": "assistant",
+     "reasoning_content": "The user asks about the capital of France. It is Paris.",
+     "content": "The capital of France is Paris."
+   }
+ ]
encoding/tests/test_input_3.json ADDED
@@ -0,0 +1,159 @@
+ [
+   {
+     "role": "system",
+     "content": "该助手为DeepSeek,由深度求索公司创造。"
+   },
+   {
+     "role": "latest_reminder",
+     "content": "2026-02-21,星期六,广州,App,中文"
+   },
+   {
+     "role": "developer",
+     "content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
+     "tools": [
+       {
+         "type": "function",
+         "function": {
+           "name": "search",
+           "description": "Web search. Split multiple queries with '||'.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "queries": {
+                 "type": "string",
+                 "description": "query1||query2"
+               }
+             },
+             "required": [
+               "queries"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       },
+       {
+         "type": "function",
+         "function": {
+           "name": "open",
+           "description": "Batch open IDs (format 【{id}†...】) or URLs.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "open_list": {
+                 "type": "array",
+                 "items": {
+                   "type": "object",
+                   "properties": {
+                     "id": {
+                       "description": "ID or URL",
+                       "anyOf": [
+                         {
+                           "type": "integer"
+                         },
+                         {
+                           "type": "string"
+                         }
+                       ],
+                       "default": -1
+                     },
+                     "cursor": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     },
+                     "loc": {
+                       "type": "integer",
+                       "description": "Start line",
+                       "default": -1
+                     },
+                     "num_lines": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     },
+                     "view_source": {
+                       "type": "boolean",
+                       "description": "",
+                       "default": false
+                     }
+                   },
+                   "additionalProperties": false
+                 },
+                 "description": ""
+               }
+             },
+             "required": [
+               "open_list"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       },
+       {
+         "type": "function",
+         "function": {
+           "name": "find",
+           "description": "Find exact text pattern in pages.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "find_list": {
+                 "type": "array",
+                 "items": {
+                   "type": "object",
+                   "properties": {
+                     "pattern": {
+                       "type": "string",
+                       "description": ""
+                     },
+                     "cursor": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     }
+                   },
+                   "required": [
+                     "pattern"
+                   ],
+                   "additionalProperties": false
+                 },
+                 "description": ""
+               }
+             },
+             "required": [
+               "find_list"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       }
+     ]
+   },
+   {
+     "role": "assistant",
+     "content": "",
+     "reasoning_content": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
+     "tool_calls": [
+       {
+         "type": "function",
+         "function": {
+           "name": "search",
+           "arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
+         }
+       }
+     ]
+   },
+   {
+     "role": "tool",
+     "content": "[0]"
+   },
+   {
+     "role": "assistant",
+     "content": "请及时就医。",
+     "reasoning_content": "现在开始组织回答。",
+     "tool_calls": []
+   }
+ ]
encoding/tests/test_input_4.json ADDED
@@ -0,0 +1,28 @@
+ [
+   {
+     "role": "system",
+     "content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
+   },
+   {
+     "role": "latest_reminder",
+     "content": "2024-11-15,上海市,App,中文"
+   },
+   {
+     "role": "user",
+     "content": "热海大滚锅是世界著名温泉吗"
+   },
+   {
+     "role": "assistant",
+     "content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n  - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n  - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n  - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n  - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n  - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n  - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
+     "mask": 1
+   },
+   {
+     "role": "user",
+     "content": "世界著名温泉有哪些",
+     "task": "action"
+   },
+   {
+     "role": "assistant",
+     "content": "Search"
+   }
+ ]
encoding/tests/test_output_1.txt ADDED
@@ -0,0 +1,36 @@
+ <|begin▁of▁sentence|>You are a helpful assistant.
+
+ ## Tools
+
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
+
+ <|DSML|tool_calls>
+ <|DSML|invoke name="$TOOL_NAME">
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
+ ...
+ </|DSML|invoke>
+ <|DSML|invoke name="$TOOL_NAME2">
+ ...
+ </|DSML|invoke>
+ </|DSML|tool_calls>
+
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
+
+ Otherwise, output directly after </think> with tool calls or final response.
+
+ ### Available Tool Schemas
+
+ {"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
+ {"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
+
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+ <|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
+
+ <|DSML|tool_calls>
+ <|DSML|invoke name="get_weather">
+ <|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
+ <|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
+ </|DSML|invoke>
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
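The template in this expected output maps each structured tool call onto an `<|DSML|invoke>` block, with `string="true"` for raw string parameters and `string="false"` plus JSON encoding for everything else. A minimal Python sketch of that rendering step (the helper name `render_dsml_tool_call` is hypothetical, not part of the repo's encoding module):

```python
import json

def render_dsml_tool_call(name, arguments):
    """Hypothetical sketch: render one tool call into the DSML block format
    shown in the template above. `arguments` maps parameter name -> value."""
    lines = ["<|DSML|tool_calls>", f'<|DSML|invoke name="{name}">']
    for key, value in arguments.items():
        if isinstance(value, str):
            # String parameters are emitted as-is with string="true".
            lines.append(f'<|DSML|parameter name="{key}" string="true">{value}</|DSML|parameter>')
        else:
            # All other types are JSON-encoded with string="false".
            lines.append(f'<|DSML|parameter name="{key}" string="false">{json.dumps(value)}</|DSML|parameter>')
    lines.append("</|DSML|invoke>")
    lines.append("</|DSML|tool_calls>")
    return "\n".join(lines)
```

This reproduces the `get_weather` invocation in the transcript above when called with string arguments; the actual encoder in `encoding_dsv4` may differ in details such as whitespace.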
encoding/tests/test_output_2.txt ADDED
@@ -0,0 +1 @@
+ <|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
encoding/tests/test_output_3.txt ADDED
@@ -0,0 +1,38 @@
+ <|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
+
+ CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
+
+ ## Tools
+
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
+
+ <|DSML|tool_calls>
+ <|DSML|invoke name="$TOOL_NAME">
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
+ ...
+ </|DSML|invoke>
+ <|DSML|invoke name="$TOOL_NAME2">
+ ...
+ </|DSML|invoke>
+ </|DSML|tool_calls>
+
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
+
+ Otherwise, output directly after </think> with tool calls or final response.
+
+ ### Available Tool Schemas
+
+ {"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+ {"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+ {"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+ <|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
+
+ <|DSML|tool_calls>
+ <|DSML|invoke name="search">
+ <|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
+ </|DSML|invoke>
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
encoding/tests/test_output_4.txt ADDED
@@ -0,0 +1,29 @@
+ <|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
+ 今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
+
+ 为了让你快速了解,我整理了一个简要的表格:
+
+ | 维度 | 热海大滚锅的地位与特点 |
+ | :--- | :--- |
+ | **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
+ | **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
+ | **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
+ | **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
+
+ ### 💡 游玩攻略与温馨提示
+
+ 如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
+
+ - **门票与开放时间**:
+   - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
+   - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
+
+ - **特色体验**:
+   - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
+   - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
+
+ - **注意事项**:
+   - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
+   - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
+
+ 希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
generation_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 1,
+   "do_sample": true,
+   "temperature": 1.0,
+   "top_p": 1.0,
+   "transformers_version": "4.46.3"
+ }
inference/README.md ADDED
@@ -0,0 +1,26 @@
+ # Inference code for DeepSeek models
+
+ First, convert the Hugging Face model weight files to the format used by this project.
+ ```bash
+ export EXPERTS=256
+ export MP=4
+ export CONFIG=config.json
+ python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
+ ```
+
+ Then chat with the DeepSeek model interactively:
+ ```bash
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
+ ```
+
+ Or run batch inference from a file:
+ ```bash
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
+ ```
+
+ Or run multi-node inference:
+ ```bash
+ torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
+ ```
+
+ If you want to use fp8 experts, remove `"expert_dtype": "fp4"` from `config.json` and pass `--expert-dtype fp8` to `convert.py`.
inference/config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "vocab_size": 129280,
+   "dim": 4096,
+   "moe_inter_dim": 2048,
+   "n_layers": 43,
+   "n_hash_layers": 3,
+   "n_heads": 64,
+   "n_routed_experts": 256,
+   "n_shared_experts": 1,
+   "n_activated_experts": 6,
+   "score_func": "sqrtsoftplus",
+   "route_scale": 1.5,
+   "swiglu_limit": 10.0,
+   "q_lora_rank": 1024,
+   "head_dim": 512,
+   "rope_head_dim": 64,
+   "o_groups": 8,
+   "o_lora_rank": 1024,
+   "window_size": 128,
+   "original_seq_len": 65536,
+   "rope_theta": 10000,
+   "rope_factor": 16,
+   "beta_fast": 32,
+   "beta_slow": 1,
+   "index_n_heads": 64,
+   "index_head_dim": 128,
+   "index_topk": 512,
+   "hc_mult": 4,
+   "hc_sinkhorn_iters": 20,
+   "dtype": "fp8",
+   "scale_fmt": "ue8m0",
+   "expert_dtype": "fp4",
+   "compress_rope_theta": 160000,
+   "compress_ratios": [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
+ }
inference/convert.py ADDED
@@ -0,0 +1,168 @@
+ import os
+ import shutil
+ from argparse import ArgumentParser
+ from glob import glob
+ from tqdm import tqdm, trange
+
+ import torch
+ from safetensors.torch import safe_open, save_file
+
+
+ FP4_TABLE = torch.tensor([
+     0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+     0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
+ ], dtype=torch.float32)
+
+
+ def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Casts a tensor from e2m1fn to e4m3fn losslessly.
+     """
+     assert x.dtype == torch.int8
+     assert x.ndim == 2
+     out_dim, in_dim = x.size()
+     in_dim *= 2
+     fp8_block_size = 128
+     fp4_block_size = 32
+     assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
+     assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
+
+     x = x.view(torch.uint8)
+     low = x & 0x0F
+     high = (x >> 4) & 0x0F
+     x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
+
+     # max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
+     # 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
+     MAX_OFFSET_BITS = 6
+
+     bOut = out_dim // fp8_block_size
+     bIn = in_dim // fp8_block_size
+     # bOut, bIn, 128, 128
+     x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
+     # bOut, bIn, 128*4
+     scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
+     # bOut, bIn, 1
+     scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
+     # bOut, bIn, 128*4
+     offset = scale / scale_max_offset_bits
+     # bOut, bIn, 128, 128
+     offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
+     x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
+     return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
+
+
+ mapping = {
+     "embed_tokens": ("embed", 0),
+     "input_layernorm": ("attn_norm", None),
+     "post_attention_layernorm": ("ffn_norm", None),
+     "q_proj": ("wq", 0),
+     "q_a_proj": ("wq_a", None),
+     "q_a_layernorm": ("q_norm", None),
+     "q_b_proj": ("wq_b", 0),
+     "kv_a_proj_with_mqa": ("wkv_a", None),
+     "kv_a_layernorm": ("kv_norm", None),
+     "kv_b_proj": ("wkv_b", 0),
+     "o_proj": ("wo", 1),
+     "gate_proj": ("w1", 0),
+     "down_proj": ("w2", 1),
+     "up_proj": ("w3", 0),
+     "lm_head": ("head", 0),
+
+     "embed": ("embed", 0),
+     "wq_b": ("wq_b", 0),
+     "wo_a": ("wo_a", 0),
+     "wo_b": ("wo_b", 1),
+     "head": ("head", 0),
+     "attn_sink": ("attn_sink", 0),
+     "weights_proj": ("weights_proj", 0),
+ }
+
+
+ def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
+     """
+     Converts and saves model checkpoint files into a specified format.
+
+     Args:
+         hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+         save_path (str): Path to the directory where the converted checkpoint files will be saved.
+         n_experts (int): Total number of experts in the model.
+         mp (int): Model parallelism factor.
+         expert_dtype (str): Target dtype for routed expert weights ("fp8" or "fp4").
+
+     Returns:
+         None
+     """
+     torch.set_num_threads(8)
+     n_local_experts = n_experts // mp
+     state_dicts = [{} for _ in range(mp)]
+
+     for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+         with safe_open(file_path, framework="pt", device="cpu") as f:
+             for name in f.keys():
+                 param: torch.Tensor = f.get_tensor(name)
+                 if name.startswith("model."):
+                     name = name[len("model."):]
+                 if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
+                     continue
+                 name = name.replace("self_attn", "attn")
+                 name = name.replace("mlp", "ffn")
+                 name = name.replace("weight_scale_inv", "scale")
+                 name = name.replace("e_score_correction_bias", "bias")
+                 if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]):  # without .weight
+                     key = name.split(".")[-1]
+                 else:
+                     key = name.split(".")[-2]
+                 if key in mapping:
+                     new_key, dim = mapping[key]
+                 else:
+                     new_key, dim = key, None
+                 name = name.replace(key, new_key)
+                 for i in range(mp):
+                     new_param = param
+                     if "experts" in name and "shared_experts" not in name:
+                         idx = int(name.split(".")[-3])
+                         if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                             continue
+                     elif dim is not None:
+                         assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                         shard_size = param.size(dim) // mp
+                         new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                     state_dicts[i][name] = new_param
+
+     os.makedirs(save_path, exist_ok=True)
+
+     for i in trange(mp):
+         names = list(state_dicts[i].keys())
+         for name in names:
+             if name.endswith("wo_a.weight"):
+                 weight = state_dicts[i][name]
+                 scale = state_dicts[i].pop(name.replace("weight", "scale"))
+                 weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
+                 state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
+             elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
+                 if expert_dtype == "fp8":
+                     scale_name = name.replace("weight", "scale")
+                     weight = state_dicts[i].pop(name)
+                     scale = state_dicts[i].pop(scale_name)
+                     state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
+                 else:
+                     state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
+         save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+
+     for file in ["tokenizer.json", "tokenizer_config.json"]:
+         old_file_path = os.path.join(hf_ckpt_path, file)
+         new_file_path = os.path.join(save_path, file)
+         if os.path.exists(old_file_path):
+             shutil.copyfile(old_file_path, new_file_path)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--hf-ckpt-path", type=str, required=True)
+     parser.add_argument("--save-path", type=str, required=True)
+     parser.add_argument("--n-experts", type=int, required=True)
+     parser.add_argument("--model-parallel", type=int, required=True)
+     parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
+     args = parser.parse_args()
+     assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
+     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
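The nibble unpacking at the top of `cast_e2m1fn_to_e4m3fn` works by indexing a 16-entry value table with each 4-bit e2m1 code (sign bit in the nibble's MSB). A torch-free sketch of that decode step, mirroring `FP4_TABLE` (illustrative only, not part of the repo):

```python
# e2m1fn value table: indices 0-7 are positive codes, 8-15 their negated
# counterparts, matching FP4_TABLE in convert.py.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def decode_fp4_byte(byte):
    """Decode one packed byte into two e2m1 values (low nibble first),
    the same low/high split used by cast_e2m1fn_to_e4m3fn."""
    low = byte & 0x0F
    high = (byte >> 4) & 0x0F
    return FP4_VALUES[low], FP4_VALUES[high]
```

For example, the byte `0xB2` packs code 2 (1.0) in the low nibble and code 11 (-1.5) in the high nibble.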
inference/generate.py ADDED
@@ -0,0 +1,155 @@
+ import os
+ import json
+ import sys
+ from argparse import ArgumentParser
+ from typing import List
+
+ import torch
+ import torch.distributed as dist
+ from transformers import AutoTokenizer
+ from safetensors.torch import load_model
+
+ from model import Transformer, ModelArgs
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ encoding_dir = os.path.join(current_dir, '../encoding')
+ sys.path.insert(0, os.path.abspath(encoding_dir))
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
+
+
+ def sample(logits, temperature: float = 1.0):
+     """Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
+     since it avoids the GPU-to-CPU sync in torch.multinomial."""
+     logits = logits / max(temperature, 1e-5)
+     probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+     return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+
+
+ @torch.inference_mode()
+ def generate(
+     model: Transformer,
+     prompt_tokens: List[List[int]],
+     max_new_tokens: int,
+     eos_id: int,
+     temperature: float = 1.0
+ ) -> List[List[int]]:
+     """Batch generation with left-padded prompts.
+
+     The first forward pass processes [min_prompt_len:] tokens (prefill phase).
+     Subsequent passes generate one token at a time (decode phase). For positions
+     still within a prompt, the ground-truth token overrides the model's prediction.
+     """
+     prompt_lens = [len(t) for t in prompt_tokens]
+     assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
+     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
+     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
+     for i, t in enumerate(prompt_tokens):
+         tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
+     prev_pos = 0
+     finished = torch.tensor([False] * len(prompt_tokens))
+     prompt_mask = tokens != -1
+     for cur_pos in range(min(prompt_lens), total_len):
+         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+         if temperature > 0:
+             next_token = sample(logits, temperature)
+         else:
+             next_token = logits.argmax(dim=-1)
+         next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+         tokens[:, cur_pos] = next_token
+         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
+         prev_pos = cur_pos
+         if finished.all():
+             break
+     completion_tokens = []
+     for i, toks in enumerate(tokens.tolist()):
+         toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+         if eos_id in toks:
+             toks = toks[:toks.index(eos_id)]
+             toks.append(eos_id)
+         completion_tokens.append(toks)
+     return completion_tokens
+
+
+ def main(
+     ckpt_path: str,
+     config: str,
+     input_file: str = "",
+     interactive: bool = True,
+     max_new_tokens: int = 100,
+     temperature: float = 1.0,
+ ) -> None:
+     world_size = int(os.getenv("WORLD_SIZE", "1"))
+     rank = int(os.getenv("RANK", "0"))
+     local_rank = int(os.getenv("LOCAL_RANK", "0"))
+     if world_size > 1:
+         dist.init_process_group("nccl")
+     global print
+     if rank != 0:
+         print = lambda *_, **__: None
+     torch.cuda.set_device(local_rank)
+     torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+     torch.set_default_dtype(torch.bfloat16)
+     torch.set_num_threads(8)
+     torch.manual_seed(33377335)
+     with open(config) as f:
+         args = ModelArgs(**json.load(f))
+     if interactive:
+         args.max_batch_size = 1
+     print(args)
+     with torch.device("cuda"):
+         model = Transformer(args)
+     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+     print("load model")
+     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
+     torch.set_default_device("cuda")
+     print("I'm DeepSeek 👋")
+
+     if interactive:
+         messages = []
+         while True:
+             if world_size == 1:
+                 prompt = input(">>> ")
+             elif rank == 0:
+                 prompt = input(">>> ")
+                 objects = [prompt]
+                 dist.broadcast_object_list(objects, 0)
+             else:
+                 objects = [None]
+                 dist.broadcast_object_list(objects, 0)
+                 prompt = objects[0]
+             if prompt == "/exit":
+                 break
+             elif prompt == "/clear":
+                 messages.clear()
+                 continue
+             messages.append({"role": "user", "content": prompt})
+             prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
+             completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
+             completion = tokenizer.decode(completion_tokens[0])
+             print(completion)
+             messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
+     else:
+         with open(input_file) as f:
+             prompts = f.read().split("\n\n")
+         prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
+         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
+         completions = tokenizer.batch_decode(completion_tokens)
+         for prompt, completion in zip(prompts, completions):
+             print("Prompt:", prompt)
+             print("Completion:", completion)
+             print()
+
+     if world_size > 1:
+         dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--ckpt-path", type=str, required=True)
+     parser.add_argument("--config", type=str, required=True)
+     parser.add_argument("--input-file", type=str, default="")
+     parser.add_argument("--interactive", action="store_true")
+     parser.add_argument("--max-new-tokens", type=int, default=300)
+     parser.add_argument("--temperature", type=float, default=0.6)
+     args = parser.parse_args()
+     assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
+     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
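The `sample` function above uses the "exponential race" form of the Gumbel-max trick: dividing each probability by an i.i.d. Exponential(1) draw and taking the argmax yields an exact sample from the categorical distribution, with no GPU-to-CPU sync. A pure-Python sketch of the same identity (no torch, illustrative only) checked empirically:

```python
import random

def sample_categorical(probs, rng):
    # argmax_i of probs[i] / E_i with E_i ~ Exp(1) is distributed
    # Categorical(probs) -- the same identity sample() exploits via
    # probs.div_(torch.empty_like(probs).exponential_(1)).argmax().
    ratios = [p / rng.expovariate(1.0) for p in probs]
    return max(range(len(probs)), key=ratios.__getitem__)

rng = random.Random(0)
probs = [0.1, 0.3, 0.6]
counts = [0, 0, 0]
for _ in range(50_000):
    counts[sample_categorical(probs, rng)] += 1
freqs = [c / 50_000 for c in counts]  # empirical frequencies, close to probs
```

Compared with `random.choices` or `torch.multinomial`, this formulation is branch-free and vectorizes into a single fused elementwise-divide plus argmax on the GPU.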
inference/kernel.py ADDED
@@ -0,0 +1,536 @@
+ import torch
+ import tilelang
+ import tilelang.language as T
+ from typing import Tuple, Optional
+
+
+ tilelang.set_log_level("WARNING")
+
+ pass_configs = {
+     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ }
+
+ FP8 = "float8_e4m3"
+ FP4 = "float4_e2m1fn"
+ FE8M0 = "float8_e8m0fnu"
+ BF16 = "bfloat16"
+ FP32 = "float32"
+ INT32 = "int32"
+
+
+ def fast_log2_ceil(x):
+     """Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
+     bits_x = T.reinterpret("uint32", x)
+     exp_x = (bits_x >> 23) & 0xFF
+     man_bits = bits_x & ((1 << 23) - 1)
+     return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
+
+
+ def fast_pow2(x):
+     """Compute 2^x for integer x via IEEE 754 bit manipulation."""
+     bits_x = (x + 127) << 23
+     return T.reinterpret("float32", bits_x)
+
+
+ def fast_round_scale(amax, fp8_max_inv):
+     return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
+
+
+ @tilelang.jit(pass_configs=pass_configs)
+ def act_quant_kernel(
+     N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
+     round_scale=False, inplace=False
+ ):
+     """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
+     M = T.symbolic("M")
+     fp8_min = -448.0
+     fp8_max = 448.0
+     fp8_max_inv = 1 / fp8_max
+     num_stages = 0 if round_scale or inplace else 2
+     blk_m = 32
+     group_size = block_size
+     # Internal computation in FP32; scale_dtype controls output storage format.
+     compute_dtype = FP32
+     out_dtype = in_dtype if inplace else out_dtype
+
+     @T.prim_func
+     def act_quant_kernel_(
+         X: T.Tensor[(M, N), in_dtype],
+         Y: T.Tensor[(M, N), out_dtype],
+         S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
+     ):
+         with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
+             pid_m,
+             pid_n,
+         ):
+             x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
+             x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
+             amax_local = T.alloc_fragment((blk_m,), compute_dtype)
+             s_local = T.alloc_fragment((blk_m,), compute_dtype)
+             y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
+             y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
+
+             for _ in T.Pipelined(1, num_stages=num_stages):
+                 T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
+                 T.copy(x_shared, x_local)
+                 T.reduce_absmax(x_local, amax_local, dim=1)
+                 for i in T.Parallel(blk_m):
+                     amax_local[i] = T.max(amax_local[i], 1e-4)
+                     if round_scale:
+                         s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
+                     else:
+                         s_local[i] = amax_local[i] * fp8_max_inv
+                 if inplace:
+                     for i, j in T.Parallel(blk_m, group_size):
+                         y_local[i, j] = T.Cast(
+                             out_dtype,
+                             T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
+                                 x_local[i, j] / s_local[i], fp8_min, fp8_max
+                             ))) * s_local[i],
+                         )
+                 else:
+                     for i, j in T.Parallel(blk_m, group_size):
+                         y_local[i, j] = T.clamp(
+                             x_local[i, j] / s_local[i], fp8_min, fp8_max
+                         )
+                 for i in T.Parallel(blk_m):
+                     S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
+                 T.copy(y_local, y_shared)
+                 T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
101
+
102
+ return act_quant_kernel_
103
+
104
+
105
+ def act_quant(
106
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
107
+ scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
108
+ ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
109
+ """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
110
+ When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
111
+ N = x.size(-1)
112
+ assert N % block_size == 0
113
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
114
+ z = x.contiguous()
115
+ y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
116
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
117
+ kernel = act_quant_kernel(
118
+ N, block_size, scale_dtype=tl_dtype,
119
+ round_scale=scale_fmt is not None, inplace=inplace,
120
+ )
121
+ kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
122
+ if inplace:
123
+ x.copy_(y)
124
+ return x
125
+ return y, s
126
+
127
+
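The scaling scheme that `act_quant` implements can be written down in plain Python, independent of the GPU kernel and FP8 storage: each 128-wide block is scaled so its absmax lands at the FP8 limit of 448. A minimal sketch (no FP8 rounding modeled, hypothetical helper name):

```python
FP8_MAX = 448.0

def block_quant(row, block=128):
    """Per-block scaling: scale = amax / 448, then clamp to the FP8 range."""
    scales, q = [], []
    for i in range(0, len(row), block):
        chunk = row[i:i + block]
        amax = max(max(abs(v) for v in chunk), 1e-4)  # same floor as the kernel
        s = amax / FP8_MAX
        scales.append(s)
        q.extend(min(max(v / s, -FP8_MAX), FP8_MAX) for v in chunk)
    return q, scales

row = [float(i - 64) for i in range(128)]
q, scales = block_quant(row)
assert len(scales) == 1
assert abs(max(abs(v) for v in q) - FP8_MAX) < 1e-9
# Without FP8 rounding, dequantization q * s recovers the inputs exactly here.
assert all(abs(qv * scales[0] - v) < 1e-9 for qv, v in zip(q, row))
```

With `round_scale=True` the kernel additionally snaps `s` to the next power of two, which is what makes an E8M0 scale representation lossless.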
128
+ @tilelang.jit(pass_configs=pass_configs)
129
+ def fp4_quant_kernel(
130
+ N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
131
+ ):
132
+ """Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
133
+ M = T.symbolic("M")
134
+ fp4_max = 6.0
135
+ fp4_max_inv = 1.0 / fp4_max
136
+ blk_m = 32
137
+ group_size = block_size
138
+ compute_dtype = FP32
139
+ out_dtype = in_dtype if inplace else FP4
140
+
141
+ @T.prim_func
142
+ def fp4_quant_kernel_(
143
+ X: T.Tensor[(M, N), in_dtype],
144
+ Y: T.Tensor[(M, N), out_dtype],
145
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
146
+ ):
147
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
148
+ pid_m,
149
+ pid_n,
150
+ ):
151
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
152
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
153
+ amax_local = T.alloc_fragment((blk_m,), compute_dtype)
154
+ s_local = T.alloc_fragment((blk_m,), compute_dtype)
155
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
156
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
157
+
158
+ for _ in T.Pipelined(1, num_stages=2):
159
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
160
+ T.copy(x_shared, x_local)
161
+ T.reduce_absmax(x_local, amax_local, dim=1)
162
+ for i in T.Parallel(blk_m):
163
+ amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
164
+ s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
165
+ if inplace:
166
+ for i, j in T.Parallel(blk_m, group_size):
167
+ y_local[i, j] = T.Cast(
168
+ out_dtype,
169
+ T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
170
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
171
+ ))) * s_local[i],
172
+ )
173
+ else:
174
+ for i, j in T.Parallel(blk_m, group_size):
175
+ y_local[i, j] = T.clamp(
176
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
177
+ )
178
+ for i in T.Parallel(blk_m):
179
+ S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
180
+ T.copy(y_local, y_shared)
181
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
182
+
183
+ return fp4_quant_kernel_
184
+
185
+
186
+ def fp4_act_quant(
187
+ x: torch.Tensor, block_size: int = 32, inplace: bool = False,
188
+ ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
189
+ """Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
190
+ N = x.size(-1)
191
+ assert N % block_size == 0
192
+ z = x.contiguous()
193
+ y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
194
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
195
+ kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
196
+ kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
197
+ if inplace:
198
+ x.copy_(y)
199
+ return x
200
+ return y, s
201
+
202
+
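The `fp4_max = 6.0` bound used above follows from the E2M1 format itself: with 1 sign, 2 exponent, and 1 mantissa bit there are only eight magnitudes. A pure-Python decoder of one 4-bit code (a sketch with a hypothetical name, not the packed `float4_e2m1fn_x2` layout) makes this concrete:

```python
def decode_e2m1(nibble: int) -> float:
    """Decode one FP4 E2M1 code (low 4 bits) to its float value."""
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:                                 # subnormal: 0 or 0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

magnitudes = sorted({abs(decode_e2m1(n)) for n in range(16)})
assert magnitudes == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
assert decode_e2m1(0b0111) == 6.0 and decode_e2m1(0b1111) == -6.0
```

Because every representable magnitude is a small dyadic value, a power-of-two (E8M0) scale per 32-element block is the natural pairing, which is why the kernel always uses `fast_round_scale` here.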
203
+ @tilelang.jit(pass_configs=pass_configs)
204
+ def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
205
+ assert out_dtype in [BF16, FP32]
206
+
207
+ M = T.symbolic("M")
208
+ group_size = 128
209
+ block_M = 32
210
+ block_N = 128
211
+ block_K = 128
212
+
213
+ @T.prim_func
214
+ def fp8_gemm_kernel_(
215
+ A: T.Tensor[(M, K), FP8],
216
+ B: T.Tensor[(N, K), FP8],
217
+ C: T.Tensor[(M, N), out_dtype],
218
+ scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
219
+ scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
220
+ ):
221
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
222
+ bx,
223
+ by,
224
+ ):
225
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
226
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
227
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
228
+ Scale_C_shared = T.alloc_shared((block_M), FP32)
229
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
230
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
231
+
232
+ # Swizzle the block launch order to improve L2 cache reuse
233
+ T.use_swizzle(panel_size=10)
234
+ T.clear(C_local)
235
+ T.clear(C_local_accum)
236
+
237
+ K_iters = T.ceildiv(K, block_K)
238
+ for k in T.Pipelined(K_iters, num_stages=4):
239
+ T.copy(A[by * block_M, k * block_K], A_shared)
240
+ T.copy(B[bx * block_N, k * block_K], B_shared)
241
+ # Cast scales to FP32 for computation; scales_b has one value per block_N group
242
+ Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
243
+ for i in T.Parallel(block_M):
244
+ Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
245
+
246
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
247
+ # Separate accumulator for scale-corrected results (2x accumulation precision)
248
+ for i, j in T.Parallel(block_M, block_N):
249
+ C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
250
+ T.clear(C_local)
251
+ T.copy(C_local_accum, C_shared)
252
+ T.copy(C_shared, C[by * block_M, bx * block_N])
253
+
254
+ return fp8_gemm_kernel_
255
+
256
+
257
+ def fp8_gemm(
258
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
259
+ scale_dtype: torch.dtype = torch.float32,
260
+ ) -> torch.Tensor:
261
+ """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
262
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
263
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
264
+ "Scaling factor tensors must be contiguous"
265
+ )
266
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
267
+ K = a.size(-1)
268
+ M = a.numel() // K
269
+ N = b.size(0)
270
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
271
+ kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
272
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
273
+ return c
274
+
275
+
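The scale bookkeeping in `fp8_gemm` reduces, for a single output element, to a per-K-block corrected dot product: each block's partial sum is multiplied by the product of the activation scale and the weight scale for that block. A plain-Python sketch (scalar data, hypothetical names, not the kernel's tiling):

```python
def scaled_dot(a_q, a_s, b_q, b_s, block=2):
    """Dot product of quantized vectors with one (a_scale, b_scale) pair per K-block."""
    total = 0.0
    for blk in range(len(a_q) // block):
        part = sum(a_q[blk * block + i] * b_q[blk * block + i] for i in range(block))
        total += part * a_s[blk] * b_s[blk]
    return total

a_q, a_s = [1.0, 2.0, 3.0, 4.0], [0.5, 0.25]   # dequantized a[i] = a_q[i] * a_s[i // 2]
b_q, b_s = [2.0, 2.0, 2.0, 2.0], [1.0, 2.0]
direct = sum(a_q[i] * a_s[i // 2] * b_q[i] * b_s[i // 2] for i in range(4))
assert scaled_dot(a_q, a_s, b_q, b_s) == direct == 10.0
```

Deferring the scale multiply to the block boundary is what lets the inner product run entirely in FP8 while the accumulator stays in FP32.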
276
+ @tilelang.jit(pass_configs=pass_configs)
277
+ def sparse_attn_kernel(h: int, d: int, scale=None):
278
+ """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
279
+ For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
280
+ with numerically stable running max/sum, and includes a learnable attn_sink bias."""
281
+ b = T.symbolic("b")
282
+ m = T.symbolic("m")
283
+ n = T.symbolic("n")
284
+ topk = T.symbolic("topk")
285
+ if scale is None:
286
+ scale = (1.0 / d) ** 0.5
287
+
288
+ num_stages = 2
289
+ threads = 256
290
+ block = 64
291
+ num_blocks = tilelang.cdiv(topk, block)
292
+
293
+ @T.prim_func
294
+ def sparse_attn_kernel_(
295
+ q: T.Tensor[(b, m, h, d), BF16],
296
+ kv: T.Tensor[(b, n, d), BF16],
297
+ o: T.Tensor[(b, m, h, d), BF16],
298
+ attn_sink: T.Tensor[(h,), FP32],
299
+ topk_idxs: T.Tensor[(b, m, topk), INT32],
300
+ ):
301
+ with T.Kernel(m, b, threads=threads) as (bx, by):
302
+ q_shared = T.alloc_shared((h, d), BF16)
303
+ kv_shared = T.alloc_shared((block, d), BF16)
304
+ o_shared = T.alloc_shared((h, d), BF16)
305
+ acc_s_cast = T.alloc_shared((h, block), BF16)
306
+
307
+ idxs = T.alloc_fragment(block, INT32)
308
+ acc_s = T.alloc_fragment((h, block), FP32)
309
+ acc_o = T.alloc_fragment((h, d), FP32)
310
+ scores_max = T.alloc_fragment(h, FP32)
311
+ scores_max_prev = T.alloc_fragment(h, FP32)
312
+ scores_scale = T.alloc_fragment(h, FP32)
313
+ scores_sum = T.alloc_fragment(h, FP32)
314
+ sum_exp = T.alloc_fragment(h, FP32)
315
+
316
+ T.clear(acc_o)
317
+ T.clear(sum_exp)
318
+ T.fill(scores_max, -T.infinity(FP32))
319
+ T.copy(q[by, bx, :, :], q_shared)
320
+
321
+ for t in T.Pipelined(num_blocks, num_stages=num_stages):
322
+ for i in T.Parallel(block):
323
+ idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
324
+ for i, j in T.Parallel(block, d):
325
+ kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
326
+ for i, j in T.Parallel(h, block):
327
+ acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
328
+ T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
329
+ for i, j in T.Parallel(h, block):
330
+ acc_s[i, j] *= scale
331
+ T.copy(scores_max, scores_max_prev)
332
+ T.reduce_max(acc_s, scores_max, dim=1, clear=False)
333
+ for i in T.Parallel(h):
334
+ scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
335
+ for i, j in T.Parallel(h, block):
336
+ acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
337
+ T.reduce_sum(acc_s, scores_sum, dim=1)
338
+ for i in T.Parallel(h):
339
+ sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
340
+ T.copy(acc_s, acc_s_cast)
341
+ for i, j in T.Parallel(h, d):
342
+ acc_o[i, j] *= scores_scale[i]
343
+ T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
344
+
345
+ for i in T.Parallel(h):
346
+ sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
347
+ for i, j in T.Parallel(h, d):
348
+ acc_o[i, j] /= sum_exp[i]
349
+ T.copy(acc_o, o_shared)
350
+ T.copy(o_shared, o[by, bx, :, :])
351
+
352
+ return sparse_attn_kernel_
353
+
354
+
355
+ def sparse_attn(
356
+ q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
357
+ ) -> torch.Tensor:
358
+ b, s, h, d = q.size()
359
+ # Pad heads to 16 for kernel efficiency (stripped after)
360
+ if h < 16:
361
+ q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
362
+ attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
363
+ o = torch.empty_like(q)
364
+ kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
365
+ kernel(q, kv, o, attn_sink, topk_idxs)
366
+ if h < 16:
367
+ o = o.narrow(2, 0, h).contiguous()
368
+ return o
369
+
370
+
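The running max/sum recurrence in the attention kernel is the standard FlashAttention-style online softmax. Processing the scores block by block, rescaling earlier contributions by `exp(m_old - m_new)` whenever the running max grows, must reproduce the one-pass softmax exactly. A pure-Python sketch (hypothetical name, attention sink omitted):

```python
import math

def online_softmax(scores, block=2):
    m, s, acc = -math.inf, 0.0, []   # running max, running sum, cached exponentials
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_new = max(m, max(chunk))
        scale = math.exp(m - m_new)          # rescale everything seen so far
        s = s * scale + sum(math.exp(x - m_new) for x in chunk)
        acc = [a * scale for a in acc] + [math.exp(x - m_new) for x in chunk]
        m = m_new
    return [a / s for a in acc]

scores = [0.1, 2.0, -1.0, 3.0, 0.5, 0.5]
ref_den = sum(math.exp(x - max(scores)) for x in scores)
ref = [math.exp(x - max(scores)) / ref_den for x in scores]
assert all(abs(a - b) < 1e-12 for a, b in zip(online_softmax(scores), ref))
```

The kernel's `attn_sink` term is just one more `exp(sink - m)` added to the final denominator, which uniformly shrinks all weights without contributing to the output.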
371
+ @tilelang.jit(pass_configs=pass_configs)
372
+ def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
373
+ n = T.symbolic("n")
374
+ mix_hc = (2 + hc) * hc
375
+ threads = 64
376
+
377
+ @T.prim_func
378
+ def hc_split_sinkhorn_kernel_(
379
+ mixes: T.Tensor[(n, mix_hc), FP32],
380
+ hc_scale: T.Tensor[(3,), FP32],
381
+ hc_base: T.Tensor[(mix_hc,), FP32],
382
+ pre: T.Tensor[(n, hc), FP32],
383
+ post: T.Tensor[(n, hc), FP32],
384
+ comb: T.Tensor[(n, hc, hc), FP32],
385
+ ):
386
+ with T.Kernel(n, threads=threads) as i:
387
+ mixes_shared = T.alloc_shared(mix_hc, FP32)
388
+ comb_frag = T.alloc_fragment((hc, hc), FP32)
389
+ T.copy(mixes[i, :], mixes_shared)
390
+
391
+ for j in T.Parallel(hc):
392
+ pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
393
+ for j in T.Parallel(hc):
394
+ post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
395
+ for j, k in T.Parallel(hc, hc):
396
+ comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
397
+
398
+ row_sum = T.alloc_fragment(hc, FP32)
399
+ col_sum = T.alloc_fragment(hc, FP32)
400
+
401
+ # comb = comb.softmax(-1) + eps
402
+ row_max = T.alloc_fragment(hc, FP32)
403
+ T.reduce_max(comb_frag, row_max, dim=1)
404
+ for j, k in T.Parallel(hc, hc):
405
+ comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
406
+ T.reduce_sum(comb_frag, row_sum, dim=1)
407
+ for j, k in T.Parallel(hc, hc):
408
+ comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
409
+
410
+ # comb = comb / (comb.sum(-2) + eps)
411
+ T.reduce_sum(comb_frag, col_sum, dim=0)
412
+ for j, k in T.Parallel(hc, hc):
413
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
414
+
415
+ for _ in T.serial(sinkhorn_iters - 1):
416
+ # comb = comb / (comb.sum(-1) + eps)
417
+ T.reduce_sum(comb_frag, row_sum, dim=1)
418
+ for j, k in T.Parallel(hc, hc):
419
+ comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
420
+ # comb = comb / (comb.sum(-2) + eps)
421
+ T.reduce_sum(comb_frag, col_sum, dim=0)
422
+ for j, k in T.Parallel(hc, hc):
423
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
424
+
425
+ T.copy(comb_frag, comb[i, :, :])
426
+
427
+ return hc_split_sinkhorn_kernel_
428
+
429
+
430
+ def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
431
+ b, s, _ = mixes.size()
432
+ pre = mixes.new_empty(b, s, hc_mult)
433
+ post = mixes.new_empty(b, s, hc_mult)
434
+ comb = mixes.new_empty(b, s, hc_mult, hc_mult)
435
+ kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
436
+ kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
437
+ pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
438
+ return pre, post, comb
439
+
440
+
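The Sinkhorn loop in the kernel alternates row and column normalization of a positive matrix, driving it toward doubly stochastic (all row and column sums near 1). A plain-Python sketch of the same iteration (hypothetical name, softmax initialization omitted):

```python
def sinkhorn(mat, iters=20, eps=1e-6):
    """Alternate row/column normalization; converges to a doubly stochastic matrix."""
    n = len(mat)
    for _ in range(iters):
        for row in mat:                                   # row normalization
            s = sum(row) + eps
            for k in range(n):
                row[k] /= s
        for k in range(n):                                # column normalization
            s = sum(mat[j][k] for j in range(n)) + eps
            for j in range(n):
                mat[j][k] /= s
    return mat

m = sinkhorn([[1.0, 4.0], [2.0, 1.0]])
assert all(abs(sum(row) - 1.0) < 1e-3 for row in m)
assert all(abs(sum(m[j][k] for j in range(2)) - 1.0) < 1e-3 for k in range(2))
```

The `eps` terms match the kernel's guards against division by zero; they perturb the fixed point only negligibly.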
441
+ @tilelang.jit(pass_configs=pass_configs)
442
+ def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
443
+ """FP8 act x FP4 weight GEMM kernel.
444
+
445
+ C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T
446
+
447
+ Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
448
+ Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale
449
+
450
+ B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
451
+ The FP4 values are packed along the K (last) dimension.
452
+
453
+ Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
454
+ cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
455
+ Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
456
+ """
457
+ M = T.symbolic("M")
458
+ act_group_size = 128
459
+ weight_group_size = 32
460
+ block_M = 32
461
+ block_N = 128
462
+ block_K = 32 # matches weight_group_size for simple scale handling
463
+ n_sub = act_group_size // block_K # 4 sub-blocks per act scale group
464
+
465
+ @T.prim_func
466
+ def fp4_gemm_kernel_(
467
+ A: T.Tensor[(M, K), FP8],
468
+ B: T.Tensor[(N, K), FP4],
469
+ C: T.Tensor[(M, N), out_dtype],
470
+ scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
471
+ scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
472
+ ):
473
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
474
+ bx,
475
+ by,
476
+ ):
477
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
478
+ B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
479
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
480
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
481
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
482
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
483
+ scale_a_frag = T.alloc_fragment((block_M,), FP32)
484
+ scale_b_frag = T.alloc_fragment((block_N,), FP32)
485
+
486
+ T.use_swizzle(panel_size=10)
487
+ T.clear(C_local)
488
+ T.clear(C_local_accum)
489
+
490
+ K_iters = T.ceildiv(K, block_K)
491
+ for k in T.Pipelined(K_iters, num_stages=2):
492
+ T.copy(A[by * block_M, k * block_K], A_shared)
493
+ T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
494
+ # FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
495
+ for i, j in T.Parallel(block_N, block_K):
496
+ B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))
497
+
498
+ # Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
499
+ for i in T.Parallel(block_N):
500
+ scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])
501
+
502
+ # Act scale: per 128 on K, indexed by k // 4
503
+ for i in T.Parallel(block_M):
504
+ scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])
505
+
506
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
507
+
508
+ for i, j in T.Parallel(block_M, block_N):
509
+ C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
510
+ T.clear(C_local)
511
+
512
+ T.copy(C_local_accum, C_shared)
513
+ T.copy(C_shared, C[by * block_M, bx * block_N])
514
+
515
+ return fp4_gemm_kernel_
516
+
517
+
518
+ def fp4_gemm(
519
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
520
+ scale_dtype: torch.dtype = torch.float32,
521
+ ) -> torch.Tensor:
522
+ """C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
523
+ A has per-128 act scale; B has per-32 E8M0 weight scale.
524
+ B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
525
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
526
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
527
+ "Scaling factor tensors must be contiguous"
528
+ )
529
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
530
+ K = a.size(-1)
531
+ M = a.numel() // K
532
+ N = b.size(0)
533
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
534
+ kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
535
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
536
+ return c
inference/model.py ADDED
@@ -0,0 +1,828 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Optional, Literal
4
+ from functools import lru_cache
5
+ from contextlib import contextmanager
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+
12
+ from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
13
+
14
+
15
+ world_size = 1
16
+ rank = 0
17
+ block_size = 128
18
+ fp4_block_size = 32
19
+ default_dtype = torch.bfloat16
20
+ scale_fmt = None
21
+ scale_dtype = torch.float32
22
+
23
+
24
+ @contextmanager
25
+ def set_dtype(dtype):
26
+ """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
27
+ prev = torch.get_default_dtype()
28
+ torch.set_default_dtype(dtype)
29
+ try:
30
+ yield
31
+ finally:
32
+ torch.set_default_dtype(prev)
33
+
34
+ @dataclass
35
+ class ModelArgs:
36
+ """Model hyperparameters. Field names match the config JSON keys."""
37
+ max_batch_size: int = 4
38
+ max_seq_len: int = 4096
39
+ dtype: Literal["bf16", "fp8"] = "fp8"
40
+ scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
41
+ expert_dtype: Literal[None, "fp4"] = None
42
+ scale_dtype: Literal["fp32", "fp8"] = "fp8"
43
+ vocab_size: int = 129280
44
+ dim: int = 4096
45
+ moe_inter_dim: int = 4096
46
+ n_layers: int = 7
47
+ n_hash_layers: int = 0
48
+ n_mtp_layers: int = 1
49
+ n_heads: int = 64
50
+ # moe
51
+ n_routed_experts: int = 8
52
+ n_shared_experts: int = 1
53
+ n_activated_experts: int = 2
54
+ score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
55
+ route_scale: float = 1.
56
+ swiglu_limit: float = 0.
57
+ # mqa
58
+ q_lora_rank: int = 1024
59
+ head_dim: int = 512
60
+ rope_head_dim: int = 64
61
+ norm_eps: float = 1e-6
62
+ o_groups: int = 8
63
+ o_lora_rank: int = 1024
64
+ window_size: int = 128
65
+ compress_ratios: Tuple[int, ...] = (0, 0, 4, 128, 4, 128, 4, 0)
66
+ # yarn
67
+ compress_rope_theta: float = 40000.0
68
+ original_seq_len: int = 0
69
+ rope_theta: float = 10000.0
70
+ rope_factor: float = 40
71
+ beta_fast: int = 32
72
+ beta_slow: int = 1
73
+ # index
74
+ index_n_heads: int = 64
75
+ index_head_dim: int = 128
76
+ index_topk: int = 512
77
+ # hc
78
+ hc_mult: int = 4
79
+ hc_sinkhorn_iters: int = 20
80
+ hc_eps: float = 1e-6
81
+
82
+
83
+ class ParallelEmbedding(nn.Module):
84
+ """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
85
+ Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
86
+ def __init__(self, vocab_size: int, dim: int):
87
+ super().__init__()
88
+ self.vocab_size = vocab_size
89
+ self.dim = dim
90
+ assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
91
+ self.part_vocab_size = (vocab_size // world_size)
92
+ self.vocab_start_idx = rank * self.part_vocab_size
93
+ self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
94
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ if world_size > 1:
98
+ mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
99
+ x = x - self.vocab_start_idx
100
+ x[mask] = 0
101
+ y = F.embedding(x, self.weight)
102
+ if world_size > 1:
103
+ y[mask] = 0
104
+ dist.all_reduce(y)
105
+ return y
106
+
107
+
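The zero-masking trick in `ParallelEmbedding.forward` works because each rank looks up only the tokens inside its vocab shard and contributes zeros for the rest, so summing the partial results reconstructs every embedding. A toy single-process sketch with scalar "embeddings" and the `all_reduce` replaced by an explicit sum (hypothetical names and sizes):

```python
def shard_lookup(tokens, table, start, end):
    """One rank's lookup: zeros for tokens outside [start, end)."""
    return [table[t - start] if start <= t < end else 0.0 for t in tokens]

vocab = [float(i * 10) for i in range(8)]      # embedding value = 10 * token id
world = 2
part = len(vocab) // world
tokens = [0, 3, 5, 7]
partials = [
    shard_lookup(tokens, vocab[r * part:(r + 1) * part], r * part, (r + 1) * part)
    for r in range(world)
]
combined = [sum(col) for col in zip(*partials)]  # stands in for dist.all_reduce
assert combined == [0.0, 30.0, 50.0, 70.0]
```

Exactly one rank contributes a non-zero value per token, so the reduction is exact rather than an average.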
108
+ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
109
+ """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
110
+ For quantized weights, x is first quantized to FP8 via act_quant."""
111
+ assert bias is None
112
+
113
+ if weight.dtype == torch.float4_e2m1fn_x2:
114
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
115
+ return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
116
+ elif weight.dtype == torch.float8_e4m3fn:
117
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
118
+ return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
119
+ else:
120
+ return F.linear(x, weight)
121
+
122
+
123
+ class Linear(nn.Module):
124
+ """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
125
+
126
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
127
+ super().__init__()
128
+ self.in_features = in_features
129
+ self.out_features = out_features
130
+ dtype = dtype or default_dtype
131
+ if dtype == torch.float4_e2m1fn_x2:
132
+ # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
133
+ # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
134
+ self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
135
+ scale_out_features = out_features
136
+ scale_in_features = in_features // fp4_block_size
137
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
138
+ elif dtype == torch.float8_e4m3fn:
139
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
140
+ scale_out_features = (out_features + block_size - 1) // block_size
141
+ scale_in_features = (in_features + block_size - 1) // block_size
142
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
143
+ else:
144
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
145
+ self.register_parameter("scale", None)
146
+ if bias:
147
+ self.bias = nn.Parameter(torch.empty(out_features))
148
+ else:
149
+ self.register_parameter("bias", None)
150
+
151
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ return linear(x, self.weight, self.bias)
153
+
154
+
155
+ class ColumnParallelLinear(Linear):
156
+ """Shards output dim across TP ranks. No all-reduce needed on output."""
157
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
158
+ assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
159
+ self.part_out_features = out_features // world_size
160
+ super().__init__(in_features, self.part_out_features, bias, dtype)
161
+
162
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
163
+ return linear(x, self.weight, self.bias)
164
+
165
+
166
+ class RowParallelLinear(Linear):
167
+ """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
168
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
169
+ assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
170
+ self.part_in_features = in_features // world_size
171
+ super().__init__(self.part_in_features, out_features, bias, dtype)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ y = linear(x, self.weight, None)
175
+ if world_size > 1:
176
+ y = y.float()
177
+ dist.all_reduce(y)
178
+ if self.bias is not None:
179
+ y += self.bias
180
+ return y.type_as(x)
181
+
182
+
183
+ class RMSNorm(nn.Module):
184
+ def __init__(self, dim: int, eps: float = 1e-6):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.eps = eps
188
+ # rmsnorm weights in the checkpoint are stored in bf16, while the parameter here is kept in fp32 for convenience.
189
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
190
+
191
+ def forward(self, x: torch.Tensor):
192
+ dtype = x.dtype
193
+ x = x.float()
194
+ var = x.square().mean(-1, keepdim=True)
195
+ x = x * torch.rsqrt(var + self.eps)
196
+ return (self.weight * x).to(dtype)
197
+
198
+
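The `RMSNorm` forward above, stripped of tensors, is just division by the root-mean-square of the vector (plus `eps`) followed by the learned per-dimension weight. A scalar Python sketch (hypothetical name):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

y = rmsnorm([3.0, 4.0], [1.0, 1.0])
# With unit weights, the output has unit RMS regardless of the input's scale.
assert abs(math.sqrt(sum(v * v for v in y) / 2) - 1.0) < 1e-6
```

Note the module upcasts to fp32 before squaring; accumulating `x.square().mean()` in bf16 would lose most of the mantissa.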
199
+ @lru_cache(2)
200
+ def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
201
+ """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
202
+ When original_seq_len > 0, applies frequency interpolation with a smooth
203
+ linear ramp between beta_fast and beta_slow correction ranges."""
204
+
205
+ def find_correction_dim(num_rotations, dim, base, max_seq_len):
206
+ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
207
+
208
+ def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
209
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
210
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
211
+ return max(low, 0), min(high, dim-1)
212
+
213
+ def linear_ramp_factor(min, max, dim):
214
+ if min == max:
215
+ max += 0.001
216
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
217
+ ramp_func = torch.clamp(linear_func, 0, 1)
218
+ return ramp_func
219
+
220
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
221
+ if original_seq_len > 0:
222
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
223
+ smooth = 1 - linear_ramp_factor(low, high, dim // 2)
224
+ freqs = freqs / factor * (1 - smooth) + freqs * smooth
225
+
226
+ t = torch.arange(seqlen)
227
+ freqs = torch.outer(t, freqs)
228
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
229
+ return freqs_cis
230
+
231
+
232
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
233
+ """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
234
+ y = x
235
+ x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
236
+ if inverse:
237
+ freqs_cis = freqs_cis.conj()
238
+ if x.ndim == 3:
239
+ freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
240
+ else:
241
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
242
+ x = torch.view_as_real(x * freqs_cis).flatten(-2)
243
+ y.copy_(x)
244
+ return y
245
+
246
+
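`apply_rotary_emb` treats each adjacent pair of features as one complex number and multiplies it by a unit complex exponential; passing `inverse=True` multiplies by the conjugate, which undoes the rotation exactly. The same algebra on one feature pair, using Python's built-in complex type (hypothetical name):

```python
import cmath

def rot(pair, angle, inverse=False):
    """Rotate a 2-feature pair by `angle`, or de-rotate with the conjugate."""
    z = complex(pair[0], pair[1])
    f = cmath.exp(1j * angle)
    if inverse:
        f = f.conjugate()
    z *= f
    return [z.real, z.imag]

x = [0.3, -1.2]
roundtrip = rot(rot(x, 0.7), 0.7, inverse=True)
assert all(abs(a - b) < 1e-12 for a, b in zip(roundtrip, x))
```

Because `|f| == 1`, rotation preserves the pair's norm, so de-rotation is numerically stable as well as exact.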
247
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
248
+ """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
249
+ assert x.dtype == torch.bfloat16
250
+ from fast_hadamard_transform import hadamard_transform
251
+ return hadamard_transform(x, scale=x.size(-1) ** -0.5)
252
+
253
+
254
+ @lru_cache(1)
255
+ def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
256
+ if start_pos >= window_size - 1:
257
+ start_pos %= window_size
258
+ matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
259
+ elif start_pos > 0:
260
+ matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
261
+ else:
262
+ base = torch.arange(seqlen).unsqueeze(1)
263
+ matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
264
+ matrix = torch.where(matrix > base, -1, matrix)
265
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
266
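A toy run of the prefill branch above (`start_pos == 0`): row `t` holds the absolute positions of the last `window_size` tokens visible to token `t`, with not-yet-valid slots marked `-1`.

```python
import torch

window_size, seqlen = 4, 6
base = torch.arange(seqlen).unsqueeze(1)
matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
matrix = torch.where(matrix > base, -1, matrix)      # mask future positions
assert matrix[0].tolist() == [0, -1, -1, -1]         # token 0 sees only itself
assert matrix[5].tolist() == [2, 3, 4, 5]            # token 5 sees the last 4 tokens
```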
+
267
+
268
+ @lru_cache(2)
269
+ def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
270
+ if start_pos > 0:
271
+ matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
272
+ else:
273
+ matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
274
+ mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
275
+ matrix = torch.where(mask, -1, matrix + offset)
276
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
277
+
278
+
279
+ class Compressor(nn.Module):
280
+ """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
281
+ When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
282
+
283
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
284
+ super().__init__()
285
+ self.dim = args.dim
286
+ self.head_dim = head_dim
287
+ self.rope_head_dim = args.rope_head_dim
288
+ self.nope_head_dim = head_dim - args.rope_head_dim
289
+ self.compress_ratio = compress_ratio
290
+ self.overlap = compress_ratio == 4
291
+ self.rotate = rotate
292
+ coff = 1 + self.overlap
293
+
294
+ self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
295
+ # wkv and wgate are stored in bf16 in the checkpoint, while the parameters here are kept in fp32 for convenience.
296
+ # When overlap, the first half of dims is for overlapping compression, second half for normal.
297
+ self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
298
+ self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
299
+ self.norm = RMSNorm(self.head_dim, args.norm_eps)
300
+ self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
301
+ # State buffers for decode-phase incremental compression.
302
+ # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
303
+ self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
304
+ self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
305
+ self.freqs_cis: torch.Tensor = None
306
+
307
+ def overlap_transform(self, tensor: torch.Tensor, value=0):
308
+ # tensor: [b,s,r,2d]
309
+ b, s, _, _ = tensor.size()
310
+ ratio, d = self.compress_ratio, self.head_dim
311
+ new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
312
+ new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
313
+ new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
314
+ return new_tensor
315
+
316
+ def forward(self, x: torch.Tensor, start_pos: int):
317
+ assert self.kv_cache is not None
318
+ bsz, seqlen, _ = x.size()
319
+ ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
320
+ dtype = x.dtype
321
+ # compression needs fp32
322
+ x = x.float()
323
+ kv = self.wkv(x)
324
+ score = self.wgate(x)
325
+ if start_pos == 0:
326
+ should_compress = seqlen >= ratio
327
+ remainder = seqlen % ratio
328
+ cutoff = seqlen - remainder
329
+ offset = ratio if overlap else 0
330
+ if overlap and cutoff >= ratio:
331
+ self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
332
+ self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
333
+ if remainder > 0:
334
+ kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
335
+ self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
336
+ score = score[:, :cutoff]
337
+ kv = kv.unflatten(1, (-1, ratio))
338
+ score = score.unflatten(1, (-1, ratio)) + self.ape
339
+ if overlap:
340
+ kv = self.overlap_transform(kv, 0)
341
+ score = self.overlap_transform(score, float("-inf"))
342
+ kv = (kv * score.softmax(dim=2)).sum(dim=2)
343
+ else:
344
+ should_compress = (start_pos + 1) % self.compress_ratio == 0
345
+ score += self.ape[start_pos % ratio]
346
+ if overlap:
347
+ self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
348
+ self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
349
+ if should_compress:
350
+ kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
351
+ score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
352
+ kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
353
+ self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
354
+ self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
355
+ else:
356
+ self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
357
+ self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
358
+ if should_compress:
359
+ kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
360
+ if not should_compress:
361
+ return
362
+ kv = self.norm(kv.to(dtype))
363
+ if start_pos == 0:
364
+ freqs_cis = self.freqs_cis[:cutoff:ratio]
365
+ else:
366
+ freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
367
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
368
+ if self.rotate:
369
+ kv = rotate_activation(kv)
370
+ fp4_act_quant(kv, fp4_block_size, True)
371
+ else:
372
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
373
+ if start_pos == 0:
374
+ self.kv_cache[:bsz, :seqlen // ratio] = kv
375
+ else:
376
+ self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
377
+ return kv
378
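A shape-level sketch of the prefill compression path in the non-overlapping case: each group of `ratio` consecutive tokens is pooled into one vector with a softmax gate over the group plus the intra-group positional bias `ape`. Projections, normalization, RoPE, and the overlap/decode branches are omitted; the random tensors stand in for `wkv(x)` and `wgate(x)`.

```python
import torch

b, s, d, ratio = 2, 8, 16, 4
kv = torch.randn(b, s, d)                        # stand-in for self.wkv(x)
score = torch.randn(b, s, d)                     # stand-in for self.wgate(x)
ape = torch.randn(ratio, d)                      # learned bias per position-in-group
kv = kv.unflatten(1, (-1, ratio))                # [b, s // ratio, ratio, d]
score = score.unflatten(1, (-1, ratio)) + ape
compressed = (kv * score.softmax(dim=2)).sum(dim=2)   # gated pooling over the group
assert compressed.shape == (b, s // ratio, d)
```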
+
379
+
380
+ class Indexer(torch.nn.Module):
381
+ """Selects top-k compressed KV positions for sparse attention via learned scoring.
382
+ Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
383
+
384
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4):
385
+ super().__init__()
386
+ self.dim = args.dim
387
+ self.n_heads = args.index_n_heads
388
+ self.n_local_heads = args.index_n_heads // world_size
389
+ self.head_dim = args.index_head_dim
390
+ self.rope_head_dim = args.rope_head_dim
391
+ self.index_topk = args.index_topk
392
+ self.q_lora_rank = args.q_lora_rank
393
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
394
+ self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
395
+ self.softmax_scale = self.head_dim ** -0.5
396
+ self.compress_ratio = compress_ratio
397
+
398
+ self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
399
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
400
+ self.freqs_cis = None
401
+
402
+ def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
403
+ bsz, seqlen, _ = x.size()
404
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
405
+ ratio = self.compress_ratio
406
+ rd = self.rope_head_dim
407
+ end_pos = start_pos + seqlen
408
+ if self.compressor.kv_cache is None:
409
+ self.compressor.kv_cache = self.kv_cache
410
+ self.compressor.freqs_cis = self.freqs_cis
411
+ q = self.wq_b(qr)
412
+ q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
413
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
414
+ q = rotate_activation(q)
415
+ # use fp4 simulation for q and kv in indexer
416
+ fp4_act_quant(q, fp4_block_size, True)
417
+ self.compressor(x, start_pos)
418
+ weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
419
+ # QAT was performed here, so kv could also use the fp8 format, though the current implementation uses bf16
420
+ index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
421
+ index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
422
+ if world_size > 1:
423
+ dist.all_reduce(index_score)
424
+ if start_pos == 0:
425
+ mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
426
+ index_score += torch.where(mask, float("-inf"), 0)
427
+ topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
428
+ if start_pos == 0:
429
+ mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
430
+ topk_idxs = torch.where(mask, -1, topk_idxs + offset)
431
+ else:
432
+ topk_idxs += offset
433
+ return topk_idxs
434
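A shape-level sketch of the index scoring in `Indexer.forward`: ReLU'd per-head dot products between queries and the compressed keys, mixed by learned per-head weights, then top-k over compressed positions. Quantization, causal masking, and the all-reduce are omitted.

```python
import torch

b, s, h, d, t, topk = 2, 3, 4, 8, 6, 4
q = torch.randn(b, s, h, d)
kv = torch.randn(b, t, d)                        # stand-in for the compressed KV cache
weights = torch.randn(b, s, h)                   # stand-in for weights_proj(x)
index_score = torch.einsum("bshd,btd->bsht", q, kv)
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)   # [b, s, t]
topk_idxs = index_score.topk(min(topk, t), dim=-1)[1]
assert index_score.shape == (b, s, t) and topk_idxs.shape == (b, s, topk)
```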
+
435
+
436
+ class Attention(nn.Module):
437
+ """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
438
+ Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
439
+ def __init__(self, layer_id: int, args: ModelArgs):
440
+ super().__init__()
441
+ self.layer_id = layer_id
442
+ self.dim = args.dim
443
+ self.n_heads = args.n_heads
444
+ self.n_local_heads = args.n_heads // world_size
445
+ self.q_lora_rank = args.q_lora_rank
446
+ self.o_lora_rank = args.o_lora_rank
447
+ self.head_dim = args.head_dim
448
+ self.rope_head_dim = args.rope_head_dim
449
+ self.nope_head_dim = args.head_dim - args.rope_head_dim
450
+ self.n_groups = args.o_groups
451
+ self.n_local_groups = self.n_groups // world_size
452
+ self.window_size = args.window_size
453
+ self.compress_ratio = args.compress_ratios[layer_id]
454
+ self.eps = args.norm_eps
455
+
456
+ self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
457
+ self.wq_a = Linear(self.dim, self.q_lora_rank)
458
+ self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
459
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
460
+ self.wkv = Linear(self.dim, self.head_dim)
461
+ self.kv_norm = RMSNorm(self.head_dim, self.eps)
462
+ self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
463
+ self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
464
+ self.softmax_scale = self.head_dim ** -0.5
465
+
466
+ if self.compress_ratio:
467
+ self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
468
+ if self.compress_ratio == 4:
469
+ self.indexer = Indexer(args, self.compress_ratio)
470
+ else:
471
+ self.indexer = None
472
+
473
+ kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
474
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
475
+ if self.compress_ratio:
476
+ original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
477
+ else:
478
+ # disable YaRN and use base rope_theta in pure sliding-window attention
479
+ original_seq_len, rope_theta = 0, args.rope_theta
480
+ freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
481
+ rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
482
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
483
+
484
+ def forward(self, x: torch.Tensor, start_pos: int):
485
+ bsz, seqlen, _ = x.size()
486
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
487
+ win = self.window_size
488
+ ratio = self.compress_ratio
489
+ rd = self.rope_head_dim
490
+ if self.compress_ratio and self.compressor.kv_cache is None:
491
+ self.compressor.kv_cache = self.kv_cache[:, win:]
492
+ self.compressor.freqs_cis = self.freqs_cis
493
+ if self.indexer is not None:
494
+ self.indexer.freqs_cis = self.freqs_cis
495
+ # q
496
+ qr = q = self.q_norm(self.wq_a(x))
497
+ q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
498
+ q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
499
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
500
+
501
+ # win kv & topk_idxs
502
+ kv = self.wkv(x)
503
+ kv = self.kv_norm(kv)
504
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
505
+ # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
506
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
507
+ topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
508
+ if self.compress_ratio:
509
+ offset = kv.size(1) if start_pos == 0 else win
510
+ if self.indexer is not None:
511
+ compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
512
+ else:
513
+ compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
514
+ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
515
+ topk_idxs = topk_idxs.int()
516
+
517
+ # compress kv & attn
518
+ if start_pos == 0:
519
+ if seqlen <= win:
520
+ self.kv_cache[:bsz, :seqlen] = kv
521
+ else:
522
+ cutoff = seqlen % win
523
+ self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
524
+ if self.compress_ratio:
525
+ if (kv_compress := self.compressor(x, start_pos)) is not None:
526
+ kv = torch.cat([kv, kv_compress], dim=1)
527
+ # QAT was performed here, so kv could also use the fp8 format, though the current implementation uses bf16
528
+ o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
529
+ else:
530
+ self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
531
+ if self.compress_ratio:
532
+ self.compressor(x, start_pos)
533
+ o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
534
+ apply_rotary_emb(o[..., -rd:], freqs_cis, True)
535
+
536
+ # o
537
+ o = o.view(bsz, seqlen, self.n_local_groups, -1)
538
+ wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
539
+ # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
540
+ # but using BF16 for simplicity.
541
+ o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
542
+ x = self.wo_b(o.flatten(2))
543
+ return x
544
+
545
+
546
+ class Gate(nn.Module):
547
+ """MoE gating: computes expert routing scores and selects top-k experts.
548
+ Supports hash-based routing (first n_hash_layers) where expert indices are
549
+ predetermined per token ID, and score-based routing (remaining layers)."""
550
+ def __init__(self, layer_id: int, args: ModelArgs):
551
+ super().__init__()
552
+ self.dim = args.dim
553
+ self.topk = args.n_activated_experts
554
+ self.score_func = args.score_func
555
+ self.route_scale = args.route_scale
556
+ self.hash = layer_id < args.n_hash_layers
557
+ self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
558
+ if self.hash:
559
+ self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
560
+ self.bias = None
561
+ else:
562
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
563
+
564
+ def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
565
+ scores = linear(x.float(), self.weight.float())
566
+ if self.score_func == "softmax":
567
+ scores = scores.softmax(dim=-1)
568
+ elif self.score_func == "sigmoid":
569
+ scores = scores.sigmoid()
570
+ else:
571
+ scores = F.softplus(scores).sqrt()
572
+ original_scores = scores
573
+ # Bias shifts scores for expert selection (topk) but does not affect routing weights.
574
+ if self.bias is not None:
575
+ scores = scores + self.bias
576
+ if self.hash:
577
+ indices = self.tid2eid[input_ids]
578
+ else:
579
+ indices = scores.topk(self.topk, dim=-1)[1]
580
+ weights = original_scores.gather(1, indices)
581
+ if self.score_func != "softmax":
582
+ weights /= weights.sum(dim=-1, keepdim=True)
583
+ weights *= self.route_scale
584
+ return weights, indices
585
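A toy run of the score-based routing above, showing that the bias only shifts which experts win the top-k while the routing weights are gathered from the unbiased scores and renormalized (the non-softmax path). The forced bias value is illustrative.

```python
import torch

n_tokens, n_experts, topk = 4, 8, 2
scores = torch.rand(n_tokens, n_experts)         # e.g. post-sigmoid scores
bias = torch.zeros(n_experts)
bias[3] = 10.0                                   # push expert 3 into every top-k
indices = (scores + bias).topk(topk, dim=-1)[1]
weights = scores.gather(1, indices)              # weights come from unbiased scores
weights = weights / weights.sum(dim=-1, keepdim=True)
assert (indices == 3).any(dim=-1).all()          # expert 3 selected for every token
assert torch.allclose(weights.sum(-1), torch.ones(n_tokens))
```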
+
586
+
587
+ class Expert(nn.Module):
588
+ """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
589
+ def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
590
+ super().__init__()
591
+ self.w1 = Linear(dim, inter_dim, dtype=dtype)
592
+ self.w2 = Linear(inter_dim, dim, dtype=dtype)
593
+ self.w3 = Linear(dim, inter_dim, dtype=dtype)
594
+ self.swiglu_limit = swiglu_limit
595
+
596
+ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
597
+ dtype = x.dtype
598
+ gate = self.w1(x).float()
599
+ up = self.w3(x).float()
600
+ if self.swiglu_limit > 0:
601
+ up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
602
+ gate = torch.clamp(gate, max=self.swiglu_limit)
603
+ x = F.silu(gate) * up
604
+ if weights is not None:
605
+ x = weights * x
606
+ return self.w2(x.to(dtype))
607
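A sketch of the clamped SwiGLU above: `up` is clamped symmetrically, while `gate` is clamped only from above because `silu` already bounds it from below (its minimum is about -0.28). This bounds every activation by `limit ** 2` before `w2`.

```python
import torch
import torch.nn.functional as F

limit = 7.0
gate = torch.randn(4, 16) * 10
up = torch.randn(4, 16) * 10
up = up.clamp(min=-limit, max=limit)             # symmetric clamp
gate = gate.clamp(max=limit)                     # one-sided clamp
y = F.silu(gate) * up
assert float(y.abs().max()) <= limit * limit     # |silu(gate)| <= limit, |up| <= limit
```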
+
608
+
609
+ class MoE(nn.Module):
610
+ """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
611
+ Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
612
+ def __init__(self, layer_id: int, args: ModelArgs):
613
+ super().__init__()
614
+ self.layer_id = layer_id
615
+ self.dim = args.dim
616
+ assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
617
+ self.n_routed_experts = args.n_routed_experts
618
+ self.n_local_experts = args.n_routed_experts // world_size
619
+ self.n_activated_experts = args.n_activated_experts
620
+ self.experts_start_idx = rank * self.n_local_experts
621
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
622
+ self.gate = Gate(layer_id, args)
623
+ expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
624
+ self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
625
+ for i in range(self.n_routed_experts)])
626
+ assert args.n_shared_experts == 1
627
+ # the shared expert uses no swiglu_limit
628
+ self.shared_experts = Expert(args.dim, args.moe_inter_dim)
629
+
630
+ def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
631
+ shape = x.size()
632
+ x = x.view(-1, self.dim)
633
+ weights, indices = self.gate(x, input_ids.flatten())
634
+ y = torch.zeros_like(x, dtype=torch.float32)
635
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
636
+ for i in range(self.experts_start_idx, self.experts_end_idx):
637
+ if counts[i] == 0:
638
+ continue
639
+ expert = self.experts[i]
640
+ idx, top = torch.where(indices == i)
641
+ y[idx] += expert(x[idx], weights[idx, top, None])
642
+ if world_size > 1:
643
+ dist.all_reduce(y)
644
+ y += self.shared_experts(x)
645
+ return y.type_as(x).view(shape)
646
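A toy run of the dispatch loop above: `torch.where` on the `[tokens, topk]` index matrix gives, for expert `i`, which token rows selected it and in which top-k slot, so the matching routing weight can be gathered with advanced indexing.

```python
import torch

indices = torch.tensor([[0, 2], [1, 2], [0, 1]])          # 3 tokens, top-2 experts
weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
idx, top = torch.where(indices == 2)                      # tokens that chose expert 2
assert idx.tolist() == [0, 1] and top.tolist() == [1, 1]
w = weights[idx, top]                                     # their weights for expert 2
assert torch.allclose(w, torch.tensor([0.3, 0.4]))
```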
+
647
+
648
+ class Block(nn.Module):
649
+ """Transformer block with Hyper-Connections (HC) mixing.
650
+ Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
651
+ hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
652
+ hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
653
+ def __init__(self, layer_id: int, args: ModelArgs):
654
+ super().__init__()
655
+ self.layer_id = layer_id
656
+ self.norm_eps = args.norm_eps
657
+ self.attn = Attention(layer_id, args)
658
+ self.ffn = MoE(layer_id, args)
659
+ self.attn_norm = RMSNorm(args.dim, self.norm_eps)
660
+ self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
661
+ self.hc_mult = hc_mult = args.hc_mult
662
+ self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
663
+ self.hc_eps = args.hc_eps
664
+ mix_hc = (2 + hc_mult) * hc_mult
665
+ hc_dim = hc_mult * args.dim
666
+ with set_dtype(torch.float32):
667
+ self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
668
+ self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
669
+ self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
670
+ self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
671
+ self.hc_attn_scale = nn.Parameter(torch.empty(3))
672
+ self.hc_ffn_scale = nn.Parameter(torch.empty(3))
673
+
674
+ def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
675
+ # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
676
+ shape, dtype = x.size(), x.dtype
677
+ x = x.flatten(2).float()
678
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
679
+ mixes = F.linear(x, hc_fn) * rsqrt
680
+ pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
681
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
682
+ return y.to(dtype), post, comb
683
+
684
+ def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
685
+ # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
686
+ y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
687
+ return y.type_as(x)
688
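A shape-level sketch of the HC mixing in `hc_pre`/`hc_post`: `pre` reduces the `hc_mult` residual streams to one sublayer input, and `post` plus the combination matrix `comb` expand the sublayer output back while re-mixing the streams. The softmax below is only a stand-in for the Sinkhorn-derived weights.

```python
import torch

b, s, hc, d = 2, 3, 4, 8
x = torch.randn(b, s, hc, d)                        # hc residual streams
pre = torch.softmax(torch.randn(b, s, hc), dim=-1)  # stand-in for Sinkhorn pre-weights
y = torch.sum(pre.unsqueeze(-1) * x, dim=2)         # [b, s, d] sublayer input
post = torch.randn(b, s, hc)
comb = torch.randn(b, s, hc, hc)
out = post.unsqueeze(-1) * y.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * x.unsqueeze(-2), dim=2)
assert y.shape == (b, s, d) and out.shape == (b, s, hc, d)
```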
+
689
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
690
+ residual = x
691
+ x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
692
+ x = self.attn_norm(x)
693
+ x = self.attn(x, start_pos)
694
+ x = self.hc_post(x, residual, post, comb)
695
+
696
+ residual = x
697
+ x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
698
+ x = self.ffn_norm(x)
699
+ x = self.ffn(x, input_ids)
700
+ x = self.hc_post(x, residual, post, comb)
701
+ return x
702
+
703
+
704
+ class ParallelHead(nn.Module):
+ """Vocab-parallel LM head: HC-reduce -> RMSNorm -> logits, all-gathered across TP ranks."""
705
+
706
+ def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
707
+ super().__init__()
708
+ self.vocab_size = vocab_size
709
+ self.dim = dim
710
+ self.norm_eps = norm_eps
711
+ self.hc_eps = hc_eps
712
+ self.part_vocab_size = (vocab_size // world_size)
713
+ # lm_head is stored in bf16 in the checkpoint, while the parameter here is kept in fp32 to simplify the logits computation later.
714
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
715
+
716
+ def get_logits(self, x):
717
+ return F.linear(x[:, -1].float(), self.weight)
718
+
719
+ def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
720
+ # x: [b,s,hc,d]
721
+ x = self.hc_head(x, hc_fn, hc_scale, hc_base)
722
+ logits = self.get_logits(norm(x))
723
+ if world_size > 1:
724
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
725
+ dist.all_gather(all_logits, logits)
726
+ logits = torch.cat(all_logits, dim=-1)
727
+ return logits
728
+
729
+ def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
730
+ shape, dtype = x.size(), x.dtype
731
+ x = x.flatten(2).float()
732
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
733
+ mixes = F.linear(x, hc_fn) * rsqrt
734
+ pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
735
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
736
+ return y.to(dtype)
737
+
738
+
739
+ class MTPBlock(Block):
+ """Multi-Token Prediction block: mixes the embedding of the shifted input with the previous hidden state, runs one HC block, and emits logits via the shared head."""
740
+
741
+ def __init__(self, layer_id: int, args: ModelArgs):
742
+ super().__init__(layer_id, args)
743
+ self.e_proj = Linear(args.dim, args.dim)
744
+ self.h_proj = Linear(args.dim, args.dim)
745
+ self.enorm = RMSNorm(args.dim, args.norm_eps)
746
+ self.hnorm = RMSNorm(args.dim, args.norm_eps)
747
+ self.norm = RMSNorm(args.dim, args.norm_eps)
748
+ self.hc_mult = hc_mult = args.hc_mult
749
+ hc_dim = hc_mult * args.dim
750
+ with set_dtype(torch.float32):
751
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
752
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
753
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
754
+ self.embed: ParallelEmbedding = None
755
+ self.head: ParallelHead = None
756
+
757
+ @torch.inference_mode()
758
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
759
+ # x: [b,s,hc,d]
760
+ assert self.embed is not None and self.head is not None
761
+ e = self.embed(input_ids)
762
+ e = self.enorm(e)
763
+ x = self.hnorm(x)
764
+ x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
765
+ x = super().forward(x, start_pos, input_ids)
766
+ logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
767
+ return logits
768
+
769
+
770
+ class Transformer(nn.Module):
771
+ """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
772
+ Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
773
+ def __init__(self, args: ModelArgs):
774
+ global world_size, rank, default_dtype, scale_fmt, scale_dtype
775
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
776
+ rank = dist.get_rank() if dist.is_initialized() else 0
777
+ default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
778
+ scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
779
+ scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
780
+ super().__init__()
781
+ self.max_seq_len = args.max_seq_len
782
+ self.norm_eps = args.norm_eps
783
+ self.hc_eps = args.hc_eps
784
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
785
+ self.layers = torch.nn.ModuleList()
786
+ for layer_id in range(args.n_layers):
787
+ self.layers.append(Block(layer_id, args))
788
+ self.norm = RMSNorm(args.dim, self.norm_eps)
789
+ self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
790
+ self.mtp = torch.nn.ModuleList()
791
+ for layer_id in range(args.n_mtp_layers):
792
+ self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
793
+ self.mtp[-1].embed = self.embed
794
+ self.mtp[-1].head = self.head
795
+ self.hc_mult = hc_mult = args.hc_mult
796
+ hc_dim = hc_mult * args.dim
797
+ with set_dtype(torch.float32):
798
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
799
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
800
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
801
+
802
+ @torch.inference_mode()
803
+ def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
804
+ h = self.embed(input_ids)
805
+ # Expand to hc_mult copies for Hyper-Connections
806
+ h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
807
+ for layer in self.layers:
808
+ h = layer(h, start_pos, input_ids)
809
+ logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
810
+ return logits
811
+
812
+
813
+ if __name__ == "__main__":
814
+ torch.set_default_dtype(torch.bfloat16)
815
+ torch.set_default_device("cuda")
816
+ torch.manual_seed(0)
817
+ args = ModelArgs(n_hash_layers=0)
818
+ x = torch.randint(0, args.vocab_size, (2, 128))
819
+ model = Transformer(args)
820
+
821
+ print(model(x).size())
822
+ for i in range(128, 150):
823
+ print(i, model(x[:, 0:1], i).size())
824
+
825
+ h = torch.randn(2, 128, args.hc_mult, args.dim)
826
+ mtp = model.mtp[0]
827
+ print(mtp(h, 0, x).size())
828
+ print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
inference/requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch>=2.10.0
2
+ transformers>=5.0.0
3
+ safetensors>=0.7.0
4
+ fast_hadamard_transform
5
+ tilelang==0.1.8
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee34ae7d963ef6675aebbd3b764795490e0666fdaa26c0b714b1e6217163e951
3
+ size 5365469592
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8ce706afe8f381c6ab5fbf3794a60254ecea9ab15264b19ebc631b9bac1fcbf
3
+ size 5365527480
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c88aadfd9aa56b48048f556baee53f3a45cd008fb2f1fba9e6e62a0072edcb2
3
+ size 4647227756
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cfa83086d31b6129e4606fd7951a52aba55aff75e0aef2bf46894d8c0d2dc15
3
+ size 4653028916
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 1048576,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|end▁of▁sentence|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "PreTrainedTokenizerFast"
34
+ }