xinhe committed
Commit 0adcbb3 · verified · 1 Parent(s): 4e721c3

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +3 -0
  2. DeepSeek_V4.pdf +3 -0
  3. LICENSE +21 -0
  4. README.md +54 -0
  5. assets/dsv4_performance.png +3 -0
  6. config.json +140 -0
  7. encoding/README.md +156 -0
  8. encoding/encoding_dsv4.py +744 -0
  9. encoding/test_encoding_dsv4.py +89 -0
  10. encoding/tests/test_input_1.json +81 -0
  11. encoding/tests/test_input_2.json +24 -0
  12. encoding/tests/test_input_3.json +159 -0
  13. encoding/tests/test_input_4.json +28 -0
  14. encoding/tests/test_output_1.txt +36 -0
  15. encoding/tests/test_output_2.txt +1 -0
  16. encoding/tests/test_output_3.txt +38 -0
  17. encoding/tests/test_output_4.txt +29 -0
  18. generation_config.json +9 -0
  19. inference/README.md +25 -0
  20. inference/config.json +35 -0
  21. inference/config_w4a16.json +34 -0
  22. inference/convert.py +168 -0
  23. inference/convert_w4a16.py +246 -0
  24. inference/generate.py +159 -0
  25. inference/kernel.py +536 -0
  26. inference/model.py +992 -0
  27. inference/requirements.txt +6 -0
  28. model-00001-of-00064.safetensors +3 -0
  29. model-00002-of-00064.safetensors +3 -0
  30. model-00003-of-00064.safetensors +3 -0
  31. model-00004-of-00064.safetensors +3 -0
  32. model-00005-of-00064.safetensors +3 -0
  33. model-00006-of-00064.safetensors +3 -0
  34. model-00007-of-00064.safetensors +3 -0
  35. model-00008-of-00064.safetensors +3 -0
  36. model-00009-of-00064.safetensors +3 -0
  37. model-00010-of-00064.safetensors +3 -0
  38. model-00011-of-00064.safetensors +3 -0
  39. model-00012-of-00064.safetensors +3 -0
  40. model-00013-of-00064.safetensors +3 -0
  41. model-00014-of-00064.safetensors +3 -0
  42. model-00015-of-00064.safetensors +3 -0
  43. model-00016-of-00064.safetensors +3 -0
  44. model-00017-of-00064.safetensors +3 -0
  45. model-00018-of-00064.safetensors +3 -0
  46. model-00019-of-00064.safetensors +3 -0
  47. model-00020-of-00064.safetensors +3 -0
  48. model-00021-of-00064.safetensors +3 -0
  49. model-00022-of-00064.safetensors +3 -0
  50. model-00023-of-00064.safetensors +3 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
DeepSeek_V4.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa4a3490e2dcc03c9da61b04a8be471795e9966ebbbf292a3899fa62683a330e
+ size 4479901
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 DeepSeek
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,54 @@
+ ---
+ license: mit
+ library_name: transformers
+ base_model:
+ - deepseek-ai/DeepSeek-V4-Pro
+ ---
+
+ This model is an int4 model with group_size 128 of [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro) generated by [intel/auto-round](https://github.com/intel/auto-round) with RTN mode. Please follow the license of the original model.
+
+ ## How to Run Locally
+
+ **vLLM and SGLang are not currently supported: https://huggingface.co/Intel/DeepSeek-V4-Flash-W4A16-AutoRound/discussions/1**
+
+ Please refer to the [inference](inference/README.md) folder for detailed instructions on running DeepSeek-V4 locally, including model weight conversion and interactive chat demos.
+
+ For local deployment, we recommend setting the sampling parameters to `temperature = 1.0, top_p = 1.0`. For the Think Max reasoning mode, we recommend setting the context window to at least **384K** tokens.
+
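+ As a rough sketch of applying these settings with `transformers` (the path and prompt below are placeholders, and the custom architecture requires `trust_remote_code=True`; see the inference folder for the supported workflow):
+
+ ~~~python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_dir = "./DeepSeek-V4-Pro-W4A16"  # placeholder: local checkpoint path
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, device_map="auto")
+
+ inputs = tokenizer("What is 2+2?", return_tensors="pt").to(model.device)
+ outputs = model.generate(
+     **inputs,
+     do_sample=True,
+     temperature=1.0,  # recommended sampling parameters
+     top_p=1.0,
+     max_new_tokens=256,
+ )
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ~~~
+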
+ ## Generate the Model
+
+ This PR is required: [Support model_free WOQ quantization](https://github.com/intel/auto-round/pull/1699)
+
+ ~~~bash
+ auto-round deepseek-ai/DeepSeek-V4-Pro --model_free --output_dir "./DeepSeek-V4-Pro-W4A16"
+ ~~~
+
+ ## Ethical Considerations and Limitations
+
+ The model can produce factually incorrect output, and should not be relied on to produce factually accurate information. Because of the limitations of the pretrained model and the finetuning datasets, it is possible that this model could generate lewd, biased or otherwise offensive outputs.
+
+ Therefore, before deploying any applications of the model, developers should perform safety testing.
+
+ ## Caveats and Recommendations
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.
+
+ Here is a useful link to learn more about Intel's AI software:
+
+ - [Intel Neural Compressor](https://github.com/intel/neural-compressor)
+
+ ## Disclaimer
+
+ The license on this model does not constitute legal advice. We are not responsible for the actions of third parties who use this model. Please consult an attorney before using this model for commercial purposes.
+
+ ## Cite
+
+ @article{cheng2023optimize,
+   title={Optimize weight rounding via signed gradient descent for the quantization of llms},
+   author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
+   journal={arXiv preprint arXiv:2309.05516},
+   year={2023}
+ }
+
+ [arxiv](https://arxiv.org/abs/2309.05516) [github](https://github.com/intel/auto-round)
assets/dsv4_performance.png ADDED

Git LFS Details

  • SHA256: 8fd472981a4c8d40c1845c51c5e8961fc4ef3ac22e7ec77801f534c239c1b30f
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
config.json ADDED
@@ -0,0 +1,140 @@
+ {
+   "architectures": [
+     "DeepseekV4ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "eos_token_id": 1,
+   "hc_eps": 1e-06,
+   "hc_mult": 4,
+   "hc_sinkhorn_iters": 20,
+   "head_dim": 512,
+   "hidden_act": "silu",
+   "hidden_size": 7168,
+   "index_head_dim": 128,
+   "index_n_heads": 64,
+   "index_topk": 1024,
+   "initializer_range": 0.02,
+   "max_position_embeddings": 1048576,
+   "model_type": "deepseek_v4",
+   "moe_intermediate_size": 3072,
+   "n_routed_experts": 384,
+   "n_shared_experts": 1,
+   "norm_topk_prob": true,
+   "num_attention_heads": 128,
+   "num_experts_per_tok": 6,
+   "num_hidden_layers": 61,
+   "num_hash_layers": 3,
+   "num_key_value_heads": 1,
+   "num_nextn_predict_layers": 1,
+   "o_groups": 16,
+   "o_lora_rank": 1024,
+   "q_lora_rank": 1536,
+   "qk_rope_head_dim": 64,
+   "quantization_config": {
+     "quant_method": "auto-round",
+     "packing_format": "auto_round:auto_gptq",
+     "bits": 4,
+     "group_size": 128,
+     "sym": true,
+     "data_type": "int",
+     "iters": 0,
+     "model_free": true,
+     "autoround_version": "0.13.0",
+     "extra_config": {
+       "embed": {
+         "bits": 16,
+         "data_type": "float"
+       },
+       "head": {
+         "bits": 16,
+         "data_type": "float"
+       }
+     }
+   },
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": {
+     "beta_fast": 32,
+     "beta_slow": 1,
+     "factor": 16,
+     "original_max_position_embeddings": 65536,
+     "type": "yarn"
+   },
+   "rope_theta": 10000,
+   "routed_scaling_factor": 2.5,
+   "scoring_func": "sqrtsoftplus",
+   "sliding_window": 128,
+   "swiglu_limit": 10.0,
+   "tie_word_embeddings": false,
+   "topk_method": "noaux_tc",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.57.1",
+   "use_cache": true,
+   "vocab_size": 129280,
+   "compress_rope_theta": 160000,
+   "compress_ratios": [
+     128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128,
+     4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128,
+     4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128,
+     4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0
+   ]
+ }
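The `quantization_config` above declares symmetric int4 weights with `group_size` 128. As a minimal illustration of what that means (a toy sketch, not auto-round's actual kernel or packing):

```python
import numpy as np

def quantize_sym_int4(w: np.ndarray, group_size: int = 128):
    """Toy symmetric int4 quantizer: one scale per group of 128 weights."""
    groups = w.reshape(-1, group_size)
    # int4 range is [-8, 7]; scale each group so its max magnitude maps to 7
    scale = np.abs(groups).max(axis=1, keepdims=True) / 7 + 1e-12
    q = np.clip(np.round(groups / scale), -8, 7).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return (q * scale).reshape(-1)

row = np.random.randn(7168).astype(np.float32)  # hidden_size from this config
q, scale = quantize_sym_int4(row)
print("max reconstruction error:", np.abs(dequantize(q, scale) - row).max())
```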
encoding/README.md ADDED
@@ -0,0 +1,156 @@
+ # DeepSeek-V4 Encoding
+
+ This document describes the prompt encoding format used by DeepSeek-V4 series models. The encoding handles multi-turn conversations, tool calling, extended thinking (reasoning), and quick instruction tasks.
+
+ A self-contained reference implementation is provided in `encoding_dsv4.py`.
+
+ ## Quick Start
+
+ ```python
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
+
+ # Encode a conversation
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant."},
+     {"role": "user", "content": "What is 2+2?"},
+ ]
+ prompt = encode_messages(messages, thinking_mode="thinking")
+ # => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|><think>"
+
+ # Parse model output back to structured message
+ completion = "Simple arithmetic.</think>2 + 2 = 4.<|end▁of▁sentence|>"
+ parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
+ # => {"role": "assistant", "reasoning_content": "Simple arithmetic.", "content": "2 + 2 = 4.", "tool_calls": []}
+ ```
+
+ > **Note:** The `parse_message_from_completion_text` function is designed to handle well-formatted model output only. It does not attempt to correct or recover from malformed output that the model might occasionally generate. For production use, additional error handling is recommended.
+
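+ For example, a caller might guard against malformed completions like this (a minimal sketch; the parser raises `ValueError` or `AssertionError` on bad input):
+
+ ```python
+ try:
+     parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
+ except (ValueError, AssertionError):
+     # Fallback: treat the whole completion as plain content.
+     parsed = {"role": "assistant", "content": completion,
+               "reasoning_content": "", "tool_calls": []}
+ ```
+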
+ ## Message Format
+
+ ### Special Tokens
+
+ | Token | Purpose |
+ |-------|---------|
+ | `<|begin▁of▁sentence|>` | Beginning of sequence (BOS) |
+ | `<|end▁of▁sentence|>` | End of assistant turn (EOS) |
+ | `<|User|>` | User turn prefix |
+ | `<|Assistant|>` | Assistant turn prefix |
+ | `<|latest_reminder|>` | Latest reminder (date, locale, etc.) |
+ | `<think>` / `</think>` | Reasoning block delimiters |
+ | `|DSML|` | DSML markup token |
+
+ ### Roles
+
+ The encoding supports the following message roles: `system`, `user`, `assistant`, `tool`, `latest_reminder`, and `developer`.
+
+ > **Note on the `developer` role:** The `developer` role is used exclusively in the internal search agent pipeline. It is not needed for general-purpose chat or tool-calling tasks, and the official API does not accept messages with this role.
+
+ ### Basic Chat
+
+ A simple multi-turn conversation is encoded as:
+
+ ```
+ <|begin▁of▁sentence|>{system_prompt}
+ <|User|>{user_message}<|Assistant|></think>{response}<|end▁of▁sentence|>
+ <|User|>{user_message_2}<|Assistant|></think>{response_2}<|end▁of▁sentence|>
+ ```
+
+ - The BOS token is prepended at the very beginning of the conversation.
+ - In **chat mode** (`thinking_mode="chat"`), `</think>` is placed right after `<|Assistant|>` to immediately close the thinking block, so the model generates content directly; see the example below.
+
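+ For example, reusing the Quick Start messages:
+
+ ```python
+ prompt = encode_messages(messages, thinking_mode="chat")
+ # => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|></think>"
+ ```
+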
+ ### Interleaved Thinking Mode
+
+ In **thinking mode** (`thinking_mode="thinking"`), the model produces explicit reasoning inside `<think>...</think>` blocks before responding.
+
+ ```
+ <|begin▁of▁sentence|>{system_prompt}
+ <|User|>{message}<|Assistant|><think>{reasoning}</think>{response}<|end▁of▁sentence|>
+ ```
+
+ The `drop_thinking` parameter (default `True`) controls whether reasoning from earlier turns is preserved:
+
+ - **Without tools**: `drop_thinking` takes effect. Reasoning content from assistant turns **before** the last user message is stripped. Only the final assistant turn retains its `<think>...</think>` block.
+ - **With tools** (on system or developer message): `drop_thinking` is automatically disabled. All turns retain their reasoning, because tool-calling conversations require full context for the model to track multi-step reasoning across tool calls.
+
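+ A small example of `drop_thinking` in action (earlier reasoning is stripped; the final turn keeps its `<think>` block open for generation):
+
+ ```python
+ messages = [
+     {"role": "user", "content": "Hello"},
+     {"role": "assistant", "reasoning_content": "Greeting.", "content": "Hi!"},
+     {"role": "user", "content": "What is 2+2?"},
+ ]
+ prompt = encode_messages(messages, thinking_mode="thinking")
+ # => "<|begin▁of▁sentence|><|User|>Hello<|Assistant|></think>Hi!<|end▁of▁sentence|><|User|>What is 2+2?<|Assistant|><think>"
+ # "Greeting." was dropped because it precedes the last user message.
+ ```
+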
+ ### Tool Calling (DSML Format)
+
+ Tools are defined on the `system` or `developer` message via the `tools` field (OpenAI-compatible format). When tools are present, the following schema block is injected into the system/user prompt:
+
+ ```
+ ## Tools
+
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
+
+ <|DSML|tool_calls>
+ <|DSML|invoke name="$TOOL_NAME">
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
+ ...
+ </|DSML|invoke>
+ <|DSML|invoke name="$TOOL_NAME2">
+ ...
+ </|DSML|invoke>
+ </|DSML|tool_calls>
+
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
+
+ Otherwise, output directly after </think> with tool calls or final response.
+
+ ### Available Tool Schemas
+
+ {tool_definitions_json}
+
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+ ```
+
+ An actual tool call in the assistant turn looks like:
+
+ ```xml
+ <|DSML|tool_calls>
+ <|DSML|invoke name="function_name">
+ <|DSML|parameter name="param" string="true">string_value</|DSML|parameter>
+ <|DSML|parameter name="count" string="false">5</|DSML|parameter>
+ </|DSML|invoke>
+ </|DSML|tool_calls><|end▁of▁sentence|>
+ ```
+
+ - `string="true"`: the parameter value is a raw string.
+ - `string="false"`: the parameter value is JSON (number, boolean, array, object).
+
+ Tool execution results are wrapped in `<tool_result>` tags within user messages:
+
+ ```
+ <|User|><tool_result>{result_json}</tool_result><|Assistant|><think>...
+ ```
+
+ When multiple tool results are present, they are sorted by the order of the corresponding `tool_calls` in the preceding assistant message.
+
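+ A minimal end-to-end sketch with a hypothetical `get_weather` tool (encode the conversation, then parse a tool-calling completion):
+
+ ```python
+ tools = [{"type": "function", "function": {
+     "name": "get_weather",
+     "description": "Get the weather for a city",
+     "parameters": {"type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"]}}}]
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant.", "tools": tools},
+     {"role": "user", "content": "Weather in Beijing?"},
+     {"role": "assistant", "reasoning_content": "Need the tool.",
+      "tool_calls": [{"id": "call_0", "type": "function", "function": {
+          "name": "get_weather", "arguments": "{\"location\": \"Beijing\"}"}}]},
+     {"role": "tool", "tool_call_id": "call_0", "content": "{\"temperature\": 22}"},
+ ]
+ prompt = encode_messages(messages, thinking_mode="thinking")
+ # The tool message is merged into a user turn, so the prompt ends with:
+ # ...<|User|><tool_result>{"temperature": 22}</tool_result><|Assistant|><think>
+
+ completion = ('Got the result.</think>'
+               '\n\n<|DSML|tool_calls>\n<|DSML|invoke name="get_weather">\n'
+               '<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>\n'
+               '</|DSML|invoke>\n</|DSML|tool_calls><|end▁of▁sentence|>')
+ parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
+ # parsed["tool_calls"][0]["function"]
+ # => {"name": "get_weather", "arguments": '{"location": "Beijing"}'}
+ ```
+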
+ ### Reasoning Effort
+
+ When `reasoning_effort="max"` is set, a special prefix is prepended at the very beginning of the prompt (before the system message) to instruct the model to maximize its reasoning depth:
+
+ ```
+ Reasoning Effort: Absolute maximum with no shortcuts permitted.
+ You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.
+ Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.
+ ```
+
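+ Usage, reusing the Quick Start messages (in the reference implementation the prefix lands after the BOS token, before the system prompt):
+
+ ```python
+ prompt = encode_messages(messages, thinking_mode="thinking", reasoning_effort="max")
+ assert prompt.startswith("<|begin▁of▁sentence|>Reasoning Effort: Absolute maximum")
+ ```
+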
+ ### Quick Instruction Special Tokens
+
+ Quick instruction tokens are used for auxiliary classification and generation tasks. They are appended to messages via the `"task"` field to trigger specialized model behavior for a single-token or short-form output.
+
+ | Special Token | Description | Format |
+ |:---|:---|:---|
+ | `<|action|>` | Determines whether the user prompt requires a web search or can be answered directly. | `...<|User|>{prompt}<|Assistant|><think><|action|>` |
+ | `<|title|>` | Generates a concise conversation title after the first assistant response. | `...<|Assistant|>{response}<|end▁of▁sentence|><|title|>` |
+ | `<|query|>` | Generates search queries for the user prompt. | `...<|User|>{prompt}<|query|>` |
+ | `<|authority|>` | Classifies the user prompt's demand for source authoritativeness. | `...<|User|>{prompt}<|authority|>` |
+ | `<|domain|>` | Identifies the domain of the user prompt. | `...<|User|>{prompt}<|domain|>` |
+ | `<|extracted_url|>` `<|read_url|>` | Determines whether each URL in the user prompt should be fetched and read. | `...<|User|>{prompt}<|extracted_url|>{url}<|read_url|>` |
+
+ Usage in message format:
+
+ - **`action`** on a user message: the `<|action|>` token is placed after the assistant prefix and thinking token, triggering a routing decision (e.g., "Search" or "Answer").
+ - **Other tasks** (`query`, `authority`, `domain`, `read_url`) on a user message: the task token is appended directly after the user content.
+ - **`title`** on an assistant message: the `<|title|>` token is appended after the assistant's EOS. The next assistant message provides the generated title.
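+
+ For example (chat mode):
+
+ ```python
+ # Routing decision ("action"): token goes after <|Assistant|> and the closed think block.
+ msgs = [{"role": "user", "content": "Best hot springs in the world?", "task": "action"}]
+ encode_messages(msgs, thinking_mode="chat")
+ # => "<|begin▁of▁sentence|><|User|>Best hot springs in the world?<|Assistant|></think><|action|>"
+
+ # Query generation ("query"): token is appended directly after the user content.
+ msgs = [{"role": "user", "content": "latest LLM quantization papers", "task": "query"}]
+ encode_messages(msgs, thinking_mode="chat")
+ # => "<|begin▁of▁sentence|><|User|>latest LLM quantization papers<|query|>"
+ ```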
encoding/encoding_dsv4.py ADDED
@@ -0,0 +1,744 @@
+ """
+ DeepSeek-V4 Encoding
+
+ A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
+ with tool calling, thinking mode, and quick instruction task support.
+ """
+
+ from typing import Any, Dict, List, Union, Optional, Tuple
+ import copy
+ import json
+ import re
+
+ # ============================================================
+ # Special Tokens
+ # ============================================================
+
+ bos_token: str = "<|begin▁of▁sentence|>"
+ eos_token: str = "<|end▁of▁sentence|>"
+ thinking_start_token: str = "<think>"
+ thinking_end_token: str = "</think>"
+ dsml_token: str = "|DSML|"
+
+ USER_SP_TOKEN = "<|User|>"
+ ASSISTANT_SP_TOKEN = "<|Assistant|>"
+ LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
+
+ # Task special tokens for internal classification tasks
+ DS_TASK_SP_TOKENS = {
+     "action": "<|action|>",
+     "query": "<|query|>",
+     "authority": "<|authority|>",
+     "domain": "<|domain|>",
+     "title": "<|title|>",
+     "read_url": "<|read_url|>",
+ }
+ VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
+
+ # ============================================================
+ # Templates
+ # ============================================================
+
+ system_msg_template: str = "{content}"
+ user_msg_template: str = "{content}"
+ latest_reminder_msg_template: str = "{content}"
+ assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
+ assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
+ thinking_template: str = "{reasoning_content}"
+
+ response_format_template: str = (
+     "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
+ )
+ tool_call_template: str = (
+     "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
+ )
+ tool_calls_template = (
+     "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
+ )
+ tool_calls_block_name: str = "tool_calls"
+
+ tool_output_template: str = (
+     "<tool_result>{content}</tool_result>"
+ )
+
+ REASONING_EFFORT_MAX = (
+     "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
+     "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
+     "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
+ )
+
+ TOOLS_TEMPLATE = """## Tools
+
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
+
+ <{dsml_token}tool_calls>
+ <{dsml_token}invoke name="$TOOL_NAME">
+ <{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
+ ...
+ </{dsml_token}invoke>
+ <{dsml_token}invoke name="$TOOL_NAME2">
+ ...
+ </{dsml_token}invoke>
+ </{dsml_token}tool_calls>
+
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+ If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
+
+ Otherwise, output directly after {thinking_end_token} with tool calls or final response.
+
+ ### Available Tool Schemas
+
+ {tool_schemas}
+
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+ """
+
+ # ============================================================
+ # Utility Functions
+ # ============================================================
+
+ def to_json(value: Any) -> str:
+     """Serialize a value to JSON string."""
+     try:
+         return json.dumps(value, ensure_ascii=False)
+     except Exception:
+         return json.dumps(value, ensure_ascii=True)
+
+
+ def tools_from_openai_format(tools):
+     """Extract function definitions from OpenAI-format tool list."""
+     return [tool["function"] for tool in tools]
+
+
+ def tool_calls_from_openai_format(tool_calls):
+     """Convert OpenAI-format tool calls to internal format."""
+     return [
+         {
+             "name": tool_call["function"]["name"],
+             "arguments": tool_call["function"]["arguments"],
+         }
+         for tool_call in tool_calls
+     ]
+
+
+ def tool_calls_to_openai_format(tool_calls):
+     """Convert internal tool calls to OpenAI format."""
+     return [
+         {
+             "type": "function",
+             "function": {
+                 "name": tool_call["name"],
+                 "arguments": tool_call["arguments"],
+             }
+         }
+         for tool_call in tool_calls
+     ]
+
+
+ def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
+     """
+     Encode tool call arguments into DSML parameter format.
+
+     Args:
+         tool_call: Dict with "name" and "arguments" (JSON string) keys.
+
+     Returns:
+         DSML-formatted parameter string.
+     """
+     p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
+     p_dsml_strs = []
+
+     try:
+         arguments = json.loads(tool_call["arguments"])
+     except Exception:
+         arguments = {"arguments": tool_call["arguments"]}
+
+     for k, v in arguments.items():
+         p_dsml_str = p_dsml_template.format(
+             dsml_token=dsml_token,
+             key=k,
+             is_str="true" if isinstance(v, str) else "false",
+             value=v if isinstance(v, str) else to_json(v),
+         )
+         p_dsml_strs.append(p_dsml_str)
+
+     return "\n".join(p_dsml_strs)
+
+
+ def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
+     """
+     Decode DSML parameters back to a tool call dict.
+
+     Args:
+         tool_name: Name of the tool.
+         tool_args: Dict mapping param_name -> (value, is_string_flag).
+
+     Returns:
+         Dict with "name" and "arguments" (JSON string) keys.
+     """
+     def _decode_value(key: str, value: str, string: str):
+         if string == "true":
+             value = to_json(value)
+         return f"{to_json(key)}: {value}"
+
+     tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
+     return dict(name=tool_name, arguments=tool_args_json)
+
+
+ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
+     """
+     Render tool schemas into the system prompt format.
+
+     Args:
+         tools: List of tool schema dicts (each with name, description, parameters).
+
+     Returns:
+         Formatted tools section string.
+     """
+     tools_json = [to_json(t) for t in tools]
+
+     return TOOLS_TEMPLATE.format(
+         tool_schemas="\n".join(tools_json),
+         dsml_token=dsml_token,
+         thinking_start_token=thinking_start_token,
+         thinking_end_token=thinking_end_token,
+     )
+
+
+ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
+     """Find the index of the last user/developer message."""
+     last_user_index = -1
+     for idx in range(len(messages) - 1, -1, -1):
+         if messages[idx].get("role") in ["user", "developer"]:
+             last_user_index = idx
+             break
+     return last_user_index
+
+
+ # ============================================================
+ # Message Rendering
+ # ============================================================
+
+ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
+     """
+     Render a single message at the given index into its encoded string form.
+
+     This is the core function that converts each message in the conversation
+     into the DeepSeek-V4 format.
+
+     Args:
+         index: Index of the message to render.
+         messages: Full list of messages in the conversation.
+         thinking_mode: Either "chat" or "thinking".
+         drop_thinking: Whether to drop reasoning content from earlier turns.
+         reasoning_effort: Optional reasoning effort level ("max", "high", or None).
+
+     Returns:
+         Encoded string for this message.
+     """
+     assert 0 <= index < len(messages)
+     assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
+
+     prompt = ""
+     msg = messages[index]
+     last_user_idx = find_last_user_index(messages)
+
+     role = msg.get("role")
+     content = msg.get("content")
+     tools = msg.get("tools")
+     response_format = msg.get("response_format")
+     tool_calls = msg.get("tool_calls")
+     reasoning_content = msg.get("reasoning_content")
+     wo_eos = msg.get("wo_eos", False)
+
+     if tools:
+         tools = tools_from_openai_format(tools)
+     if tool_calls:
+         tool_calls = tool_calls_from_openai_format(tool_calls)
+
+     # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
+     assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
+     if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
+         prompt += REASONING_EFFORT_MAX
+
+     if role == "system":
+         prompt += system_msg_template.format(content=content or "")
+         if tools:
+             prompt += "\n\n" + render_tools(tools)
+         if response_format:
+             prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+     elif role == "developer":
+         assert content, f"Invalid message for role `{role}`: {msg}"
+
+         content_developer = USER_SP_TOKEN
+         content_developer += content
+
+         if tools:
+             content_developer += "\n\n" + render_tools(tools)
+         if response_format:
+             content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+         prompt += user_msg_template.format(content=content_developer)
+
+     elif role == "user":
+         prompt += USER_SP_TOKEN
+
+         # Handle content blocks (tool results mixed with text)
+         content_blocks = msg.get("content_blocks")
+         if content_blocks:
+             parts = []
+             for block in content_blocks:
+                 block_type = block.get("type")
+                 if block_type == "text":
+                     parts.append(block.get("text", ""))
+                 elif block_type == "tool_result":
+                     tool_content = block.get("content", "")
+                     if isinstance(tool_content, list):
+                         text_parts = []
+                         for b in tool_content:
+                             if b.get("type") == "text":
+                                 text_parts.append(b.get("text", ""))
+                             else:
+                                 text_parts.append(f"[Unsupported {b.get('type')}]")
+                         tool_content = "\n\n".join(text_parts)
+                     parts.append(tool_output_template.format(content=tool_content))
+                 else:
+                     parts.append(f"[Unsupported {block_type}]")
+             prompt += "\n\n".join(parts)
+         else:
+             prompt += content or ""
+
+     elif role == "latest_reminder":
+         prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
+
+     elif role == "tool":
+         raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
+
+     elif role == "assistant":
+         thinking_part = ""
+         tc_content = ""
+
+         if tool_calls:
+             tc_list = [
+                 tool_call_template.format(
+                     dsml_token=dsml_token,
+                     name=tc.get("name"),
+                     arguments=encode_arguments_to_dsml(tc)
+                 )
+                 for tc in tool_calls
+             ]
+             tc_content += '\n\n' + tool_calls_template.format(
+                 dsml_token=dsml_token,
+                 tool_calls="\n".join(tc_list),
+                 tc_block_name=tool_calls_block_name,
+             )
+
+         summary_content = content or ""
+         rc = reasoning_content or ""
+
+         # Check if previous message has a task - if so, this is a task output (no thinking)
+         prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
+
+         if thinking_mode == "thinking" and not prev_has_task:
+             if not drop_thinking or index > last_user_idx:
+                 thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
+             else:
+                 thinking_part = ""
+
+         if wo_eos:
+             prompt += assistant_msg_wo_eos_template.format(
+                 reasoning=thinking_part,
+                 content=summary_content,
+                 tool_calls=tc_content,
+             )
+         else:
+             prompt += assistant_msg_template.format(
+                 reasoning=thinking_part,
+                 content=summary_content,
+                 tool_calls=tc_content,
+             )
+     else:
+         raise NotImplementedError(f"Unknown role: {role}")
+
+     # Append transition tokens based on what follows
+     if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
+         return prompt
+
+     task = messages[index].get("task")
+     if task is not None:
+         # Task special token for internal classification tasks
+         assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
+         task_sp_token = DS_TASK_SP_TOKENS[task]
+
+         if task != "action":
+             # Non-action tasks: append task sp token directly after the message
+             prompt += task_sp_token
+         else:
+             # Action task: append Assistant + thinking token + action sp token
+             prompt += ASSISTANT_SP_TOKEN
+             prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
+             prompt += task_sp_token
+
+     elif messages[index].get("role") in ["user", "developer"]:
+         # Normal generation: append Assistant + thinking token
+         prompt += ASSISTANT_SP_TOKEN
+         if not drop_thinking and thinking_mode == "thinking":
+             prompt += thinking_start_token
+         elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
+             prompt += thinking_start_token
+         else:
+             prompt += thinking_end_token
+
+     return prompt
+
+
+ # ============================================================
+ # Preprocessing
+ # ============================================================
+
+ def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Merge tool messages into the preceding user message using content_blocks format.
+
+     DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
+     are encoded as <tool_result> blocks within user messages.
+
+     This function converts a standard OpenAI-format conversation (with separate
+     "tool" role messages) into V4 format where tool results are merged into
+     user messages.
+
+     Args:
+         messages: List of message dicts in OpenAI format.
+
+     Returns:
+         Processed message list with tool messages merged into user messages.
+     """
+     merged: List[Dict[str, Any]] = []
+
+     for msg in messages:
+         msg = copy.deepcopy(msg)
+         role = msg.get("role")
+
+         if role == "tool":
+             # Convert tool message to a user message with tool_result block
+             tool_block = {
+                 "type": "tool_result",
+                 "tool_use_id": msg.get("tool_call_id", ""),
+                 "content": msg.get("content", ""),
+             }
+             # Merge into previous message if it's already a user (merged tool)
+             if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
+                 merged[-1]["content_blocks"].append(tool_block)
+             else:
+                 merged.append({
+                     "role": "user",
+                     "content_blocks": [tool_block],
+                 })
+         elif role == "user":
+             text_block = {"type": "text", "text": msg.get("content", "")}
+             if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
+                 merged[-1]["content_blocks"].append(text_block)
+             else:
+                 new_msg = {
+                     "role": "user",
+                     "content": msg.get("content", ""),
+                     "content_blocks": [text_block],
+                 }
+                 # Preserve extra fields (task, wo_eos, mask, etc.)
+                 for key in ("task", "wo_eos", "mask"):
+                     if key in msg:
+                         new_msg[key] = msg[key]
+                 merged.append(new_msg)
+         else:
+             merged.append(msg)
+
+     return merged
+
+
+ def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Sort tool_result blocks within user messages by the order of tool_calls
+     in the preceding assistant message.
+
+     Args:
+         messages: Preprocessed message list (after merge_tool_messages).
+
+     Returns:
+         Message list with sorted tool result blocks.
+     """
+     last_tool_call_order: Dict[str, int] = {}
+
+     for msg in messages:
+         role = msg.get("role")
+         if role == "assistant" and msg.get("tool_calls"):
+             last_tool_call_order = {}
+             for idx, tc in enumerate(msg["tool_calls"]):
+                 tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
+                 if tc_id:
+                     last_tool_call_order[tc_id] = idx
+
+         elif role == "user" and msg.get("content_blocks"):
+             tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
+             if len(tool_blocks) > 1 and last_tool_call_order:
+                 sorted_blocks = sorted(
+                     tool_blocks,
+                     key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
+                 )
+                 sorted_idx = 0
+                 new_blocks = []
+                 for block in msg["content_blocks"]:
+                     if block.get("type") == "tool_result":
+                         new_blocks.append(sorted_blocks[sorted_idx])
+                         sorted_idx += 1
+                     else:
+                         new_blocks.append(block)
+                 msg["content_blocks"] = new_blocks
+
+     return messages
+
+
+ # ============================================================
+ # Main Encoding Function
+ # ============================================================
+
+ def encode_messages(
+     messages: List[Dict[str, Any]],
+     thinking_mode: str,
+     context: Optional[List[Dict[str, Any]]] = None,
+     drop_thinking: bool = True,
+     add_default_bos_token: bool = True,
+     reasoning_effort: Optional[str] = None,
+ ) -> str:
+     """
+     Encode a list of messages into the DeepSeek-V4 prompt format.
+
+     This is the main entry point for encoding conversations. It handles:
+     - BOS token insertion
+     - Thinking mode with optional reasoning content dropping
+     - Tool message merging into user messages
+     - Multi-turn conversation context
+
+     Args:
+         messages: List of message dicts to encode.
+         thinking_mode: Either "chat" or "thinking".
+         context: Optional preceding context messages (already encoded prefix).
+         drop_thinking: If True, drop reasoning_content from earlier assistant turns
+             (only keep reasoning for messages after the last user message).
+         add_default_bos_token: Whether to prepend BOS token at conversation start.
+         reasoning_effort: Optional reasoning effort level ("max", "high", or None).
+
+     Returns:
+         The encoded prompt string.
+     """
+     context = context if context else []
+
+     # Preprocess: merge tool messages and sort tool results
+     messages = merge_tool_messages(messages)
+     messages = sort_tool_results_by_call_order(context + messages)[len(context):]
+     if context:
+         context = merge_tool_messages(context)
+         context = sort_tool_results_by_call_order(context)
+
+     full_messages = context + messages
+
+     prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
+
+     # Resolve drop_thinking: if any message has tools defined, don't drop thinking
+     effective_drop_thinking = drop_thinking
+     if any(m.get("tools") for m in full_messages):
+         effective_drop_thinking = False
+
+     if thinking_mode == "thinking" and effective_drop_thinking:
+         full_messages = _drop_thinking_messages(full_messages)
+         # After dropping, recalculate how many messages to render
+         # (context may have shrunk too)
+         num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
+         context_len = len(full_messages) - num_to_render
+     else:
+         num_to_render = len(messages)
+         context_len = len(context)
+
+     for idx in range(num_to_render):
+         prompt += render_message(
+             idx + context_len,
+             full_messages,
+             thinking_mode=thinking_mode,
+             drop_thinking=effective_drop_thinking,
+             reasoning_effort=reasoning_effort,
+         )
+
+     return prompt
+
+
+ def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Drop reasoning_content and non-essential messages before the last user message.
+
+     Behavior:
+     - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
+     - Messages at or after the last user index are always kept.
+     - Assistant messages before the last user get reasoning_content removed.
+     - Developer messages before the last user are dropped entirely.
+     """
+     last_user_idx = find_last_user_index(messages)
+     result = []
+     keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
+
+     for idx, msg in enumerate(messages):
+         role = msg.get("role")
+         if role in keep_roles or idx >= last_user_idx:
+             result.append(msg)
+         elif role == "assistant":
+             msg = copy.copy(msg)
+             msg.pop("reasoning_content", None)
+             result.append(msg)
+         # developer and other roles before last_user_idx are dropped
+
+     return result
+
+
+ # ============================================================
+ # Parsing (Decoding model output)
+ # ============================================================
+
+ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
+     """
+     Read text from index until one of the stop strings is found.
+
+     Returns:
+         Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
+     """
+     min_pos = len(text)
+     matched_stop = None
+
+     for s in stop:
+         pos = text.find(s, index)
+         if pos != -1 and pos < min_pos:
+             min_pos = pos
+             matched_stop = s
+
+     if matched_stop:
+         content = text[index:min_pos]
+         return min_pos + len(matched_stop), content, matched_stop
+     else:
+         content = text[index:]
+         return len(text), content, None
+
+
+ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
+     """
+     Parse DSML tool calls from text starting at the given index.
+
+     Args:
+         index: Starting position in text.
+         text: The full text to parse.
+
+     Returns:
+         Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
+         Each tool call dict has "name" and "arguments" keys.
+     """
+     tool_calls: List[Dict[str, Any]] = []
+     stop_token = None
+     tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
+
+     while index < len(text):
+         index, gap, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
+         if gap != ">\n":
+             raise ValueError(f"Tool call format error: expected '>\\n' but got '{gap}'")
+
+         if stop_token == tool_calls_end_token:
+             break
+
+         if stop_token is None:
+             raise ValueError("Missing special token in tool calls")
+
+         index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
+
+         p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
+         if len(p_tool_name) != 1:
+             raise ValueError(f"Tool name format error: '{tool_name_content}'")
+         tool_name = p_tool_name[0]
+
+         tool_args: Dict[str, Tuple[str, str]] = {}
+         while stop_token == f"<{dsml_token}parameter":
+             index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
+
+             param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
+             if len(param_kv) != 1:
+                 raise ValueError(f"Parameter format error: '{param_content}'")
+             param_name, string, param_value = param_kv[0]
+
+             if param_name in tool_args:
+                 raise ValueError(f"Duplicate parameter name: '{param_name}'")
+             tool_args[param_name] = (param_value, string)
+
+             index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
+             if content != ">\n":
+                 raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
+
+         tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
+         tool_calls.append(tool_call)
+
+     return index, stop_token, tool_calls
+
+
+ def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
+     """
+     Parse a model completion text into a structured assistant message.
+
+     This function takes the raw text output from the model (a single assistant turn)
+     and extracts:
+     - reasoning_content (thinking block)
+     - content (summary/response)
+     - tool_calls (if any)
+
+     NOTE: This function is designed to parse only correctly formatted strings and
+     will raise ValueError for malformed output.
+
+     Args:
+         text: The raw completion text (including EOS token).
+         thinking_mode: Either "chat" or "thinking".
+
+     Returns:
+         Dict with keys: "role", "content", "reasoning_content", "tool_calls".
+         tool_calls are in OpenAI format.
+     """
+     summary_content, reasoning_content, tool_calls = "", "", []
+     index, stop_token = 0, None
+     tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
+
+     is_thinking = thinking_mode == "thinking"
+     is_tool_calling = False
+
+     if is_thinking:
+         index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
+         reasoning_content = content_delta
+         assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
+
+     index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
+     summary_content = content_delta
+     if stop_token == tool_calls_start_token:
+         is_tool_calling = True
+     else:
+         assert stop_token == eos_token, "Invalid format: missing EOS token"
+
+     if is_tool_calling:
+         index, stop_token, tool_calls = parse_tool_calls(index, text)
+
+         index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
+         assert not tool_ends_text, "Unexpected content after tool calls"
+
+     assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
+
+     for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
+         assert sp_token not in summary_content and sp_token not in reasoning_content, \
+             f"Unexpected special token '{sp_token}' in content"
+
+     return {
+         "role": "assistant",
+         "content": summary_content,
+         "reasoning_content": reasoning_content,
+         "tool_calls": tool_calls_to_openai_format(tool_calls)
+     }
encoding/test_encoding_dsv4.py ADDED
@@ -0,0 +1,89 @@
+ """
+ Test suite for DeepSeek-V4 Encoding.
+
+ Run: python test_encoding_dsv4.py
+ """
+
+ import json
+ import os
+
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
+
+ TESTS_DIR = os.path.join(os.path.dirname(__file__), "tests")
+
+
+ def test_case_1():
+     """Thinking mode with tool calls (multi-turn, tool results merged into user)."""
+     with open(os.path.join(TESTS_DIR, "test_input_1.json")) as f:
+         td = json.load(f)
+     messages = td["messages"]
+     messages[0]["tools"] = td["tools"]
+     gold = open(os.path.join(TESTS_DIR, "test_output_1.txt")).read()
+     prompt = encode_messages(messages, thinking_mode="thinking")
+     assert prompt == gold
+
+     # Parse: assistant turn with tool call
+     marker = "<|Assistant|><think>"
+     first_start = prompt.find(marker) + len(marker)
+     first_end = prompt.find("<|User|>", first_start)
+     parsed_tc = parse_message_from_completion_text(prompt[first_start:first_end], thinking_mode="thinking")
+     assert parsed_tc["reasoning_content"] == "The user wants to know the weather in Beijing. I should use the get_weather tool."
+     assert parsed_tc["content"] == ""
+     assert len(parsed_tc["tool_calls"]) == 1
+     assert parsed_tc["tool_calls"][0]["function"]["name"] == "get_weather"
+     assert json.loads(parsed_tc["tool_calls"][0]["function"]["arguments"]) == {"location": "Beijing", "unit": "celsius"}
+
+     # Parse: final assistant turn with content
+     last_start = prompt.rfind(marker) + len(marker)
+     parsed_final = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
+     assert parsed_final["reasoning_content"] == "Got the weather data. Let me format a nice response."
+     assert "22°C" in parsed_final["content"]
+     assert parsed_final["tool_calls"] == []
+
+     print(" [PASS] case 1: thinking with tools (encode + parse)")
+
+
+ def test_case_2():
+     """Thinking mode without tools (drop_thinking removes earlier reasoning)."""
+     messages = json.load(open(os.path.join(TESTS_DIR, "test_input_2.json")))
+     gold = open(os.path.join(TESTS_DIR, "test_output_2.txt")).read()
+     prompt = encode_messages(messages, thinking_mode="thinking")
+     assert prompt == gold
+
+     # Parse: last assistant turn
+     marker = "<|Assistant|><think>"
+     last_start = prompt.rfind(marker) + len(marker)
+     parsed = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
+     assert parsed["reasoning_content"] == "The user asks about the capital of France. It is Paris."
+     assert parsed["content"] == "The capital of France is Paris."
+     assert parsed["tool_calls"] == []
+
+     # Verify drop_thinking: first assistant's reasoning should be absent
+     assert "The user said hello" not in prompt
+
+     print(" [PASS] case 2: thinking without tools (encode + parse)")
+
+
+ def test_case_3():
+     """Interleaved thinking + search (developer with tools, latest_reminder)."""
+     messages = json.load(open(os.path.join(TESTS_DIR, "test_input_3.json")))
+     gold = open(os.path.join(TESTS_DIR, "test_output_3.txt")).read()
+     assert encode_messages(messages, thinking_mode="thinking") == gold
+     print(" [PASS] case 3: interleaved thinking + search")
+
+
+ def test_case_4():
+     """Quick instruction task with latest_reminder (chat mode, action task)."""
+     messages = json.load(open(os.path.join(TESTS_DIR, "test_input_4.json")))
+     gold = open(os.path.join(TESTS_DIR, "test_output_4.txt")).read()
+     assert encode_messages(messages, thinking_mode="chat") == gold
+     print(" [PASS] case 4: quick instruction task")
+
+
+ if __name__ == "__main__":
+     print("Running DeepSeek-V4 Encoding Tests...\n")
+     test_case_1()
+     test_case_2()
+     test_case_3()
+     test_case_4()
+     print("\nAll 4 tests passed!")
encoding/tests/test_input_1.json ADDED
@@ -0,0 +1,81 @@
+ {
+   "tools": [
+     {
+       "type": "function",
+       "function": {
+         "name": "get_weather",
+         "description": "Get the weather for a specific location",
+         "parameters": {
+           "type": "object",
+           "properties": {
+             "location": {
+               "type": "string",
+               "description": "The city name"
+             },
+             "unit": {
+               "type": "string",
+               "enum": ["celsius", "fahrenheit"],
+               "description": "Temperature unit"
+             }
+           },
+           "required": ["location"]
+         }
+       }
+     },
+     {
+       "type": "function",
+       "function": {
+         "name": "search",
+         "description": "Search the web for information",
+         "parameters": {
+           "type": "object",
+           "properties": {
+             "query": {
+               "type": "string",
+               "description": "Search query"
+             },
+             "num_results": {
+               "type": "integer",
+               "description": "Number of results to return"
+             }
+           },
+           "required": ["query"]
+         }
+       }
+     }
+   ],
+   "messages": [
+     {
+       "role": "system",
+       "content": "You are a helpful assistant."
+     },
+     {
+       "role": "user",
+       "content": "What's the weather in Beijing?"
+     },
+     {
+       "role": "assistant",
+       "reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
+       "tool_calls": [
+         {
+           "id": "call_001",
+           "type": "function",
+           "function": {
+             "name": "get_weather",
+             "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
+           }
+         }
+       ]
+     },
+     {
+       "role": "tool",
+       "tool_call_id": "call_001",
+       "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
+     },
+     {
+       "role": "assistant",
+       "reasoning_content": "Got the weather data. Let me format a nice response.",
+       "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
+     }
+   ]
+ }
encoding/tests/test_input_2.json ADDED
@@ -0,0 +1,24 @@
+ [
+   {
+     "role": "system",
+     "content": "You are a helpful assistant."
+   },
+   {
+     "role": "user",
+     "content": "Hello"
+   },
+   {
+     "role": "assistant",
+     "reasoning_content": "The user said hello, I should greet back.",
+     "content": "Hi there! How can I help you?"
+   },
+   {
+     "role": "user",
+     "content": "What is the capital of France?"
+   },
+   {
+     "role": "assistant",
+     "reasoning_content": "The user asks about the capital of France. It is Paris.",
+     "content": "The capital of France is Paris."
+   }
+ ]
encoding/tests/test_input_3.json ADDED
@@ -0,0 +1,159 @@
+ [
+   {
+     "role": "system",
+     "content": "该助手为DeepSeek,由深度求索公司创造。"
+   },
+   {
+     "role": "latest_reminder",
+     "content": "2026-02-21,星期六,广州,App,中文"
+   },
+   {
+     "role": "developer",
+     "content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
+     "tools": [
+       {
+         "type": "function",
+         "function": {
+           "name": "search",
+           "description": "Web search. Split multiple queries with '||'.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "queries": {
+                 "type": "string",
+                 "description": "query1||query2"
+               }
+             },
+             "required": [
+               "queries"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       },
+       {
+         "type": "function",
+         "function": {
+           "name": "open",
+           "description": "Batch open IDs (format 【{id}†...】) or URLs.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "open_list": {
+                 "type": "array",
+                 "items": {
+                   "type": "object",
+                   "properties": {
+                     "id": {
+                       "description": "ID or URL",
+                       "anyOf": [
+                         {
+                           "type": "integer"
+                         },
+                         {
+                           "type": "string"
+                         }
+                       ],
+                       "default": -1
+                     },
+                     "cursor": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     },
+                     "loc": {
+                       "type": "integer",
+                       "description": "Start line",
+                       "default": -1
+                     },
+                     "num_lines": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     },
+                     "view_source": {
+                       "type": "boolean",
+                       "description": "",
+                       "default": false
+                     }
+                   },
+                   "additionalProperties": false
+                 },
+                 "description": ""
+               }
+             },
+             "required": [
+               "open_list"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       },
+       {
+         "type": "function",
+         "function": {
+           "name": "find",
+           "description": "Find exact text pattern in pages.",
+           "parameters": {
+             "type": "object",
+             "properties": {
+               "find_list": {
+                 "type": "array",
+                 "items": {
+                   "type": "object",
+                   "properties": {
+                     "pattern": {
+                       "type": "string",
+                       "description": ""
+                     },
+                     "cursor": {
+                       "type": "integer",
+                       "description": "",
+                       "default": -1
+                     }
+                   },
+                   "required": [
+                     "pattern"
+                   ],
+                   "additionalProperties": false
+                 },
+                 "description": ""
+               }
+             },
+             "required": [
+               "find_list"
+             ],
+             "additionalProperties": false,
+             "$schema": "http://json-schema.org/draft-07/schema#"
+           }
+         }
+       }
+     ]
+   },
+   {
+     "role": "assistant",
+     "content": "",
+     "reasoning_content": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
+     "tool_calls": [
+       {
+         "type": "function",
+         "function": {
+           "name": "search",
+           "arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
+         }
+       }
+     ]
+   },
+   {
+     "role": "tool",
+     "content": "[0]"
+   },
+   {
+     "role": "assistant",
+     "content": "请及时就医。",
+     "reasoning_content": "现在开始组织回答。",
+     "tool_calls": []
+   }
+ ]
encoding/tests/test_input_4.json ADDED
@@ -0,0 +1,28 @@
1
+ [
2
+ {
3
+ "role": "system",
4
+ "content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
5
+ },
6
+ {
7
+ "role": "latest_reminder",
8
+ "content": "2024-11-15,上海市,App,中文"
9
+ },
10
+ {
11
+ "role": "user",
12
+ "content": "热海大滚锅是世界著名温泉吗"
13
+ },
14
+ {
15
+ "role": "assistant",
16
+ "content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
17
+ "mask": 1
18
+ },
19
+ {
20
+ "role": "user",
21
+ "content": "世界著名温泉有哪些",
22
+ "task": "action"
23
+ },
24
+ {
25
+ "role": "assistant",
26
+ "content": "Search"
27
+ }
28
+ ]
encoding/tests/test_output_1.txt ADDED
@@ -0,0 +1,36 @@
1
+ <|begin▁of▁sentence|>You are a helpful assistant.
2
+
3
+ ## Tools
4
+
5
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
6
+
7
+ <|DSML|tool_calls>
8
+ <|DSML|invoke name="$TOOL_NAME">
9
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
10
+ ...
11
+ </|DSML|invoke>
12
+ <|DSML|invoke name="$TOOL_NAME2">
13
+ ...
14
+ </|DSML|invoke>
15
+ </|DSML|tool_calls>
16
+
17
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
18
+
19
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
20
+
21
+ Otherwise, output directly after </think> with tool calls or final response.
22
+
23
+ ### Available Tool Schemas
24
+
25
+ {"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
26
+ {"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
27
+
28
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
29
+ <|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
30
+
31
+ <|DSML|tool_calls>
32
+ <|DSML|invoke name="get_weather">
33
+ <|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
34
+ <|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
35
+ </|DSML|invoke>
36
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
encoding/tests/test_output_2.txt ADDED
@@ -0,0 +1 @@
1
+ <|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
encoding/tests/test_output_3.txt ADDED
@@ -0,0 +1,38 @@
1
+ <|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
2
+
3
+ CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
4
+
5
+ ## Tools
6
+
7
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
8
+
9
+ <|DSML|tool_calls>
10
+ <|DSML|invoke name="$TOOL_NAME">
11
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
12
+ ...
13
+ </|DSML|invoke>
14
+ <|DSML|invoke name="$TOOL_NAME2">
15
+ ...
16
+ </|DSML|invoke>
17
+ </|DSML|tool_calls>
18
+
19
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
20
+
21
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
22
+
23
+ Otherwise, output directly after </think> with tool calls or final response.
24
+
25
+ ### Available Tool Schemas
26
+
27
+ {"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
28
+ {"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
29
+ {"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
30
+
31
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
32
+ <|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
33
+
34
+ <|DSML|tool_calls>
35
+ <|DSML|invoke name="search">
36
+ <|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
37
+ </|DSML|invoke>
38
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
encoding/tests/test_output_4.txt ADDED
@@ -0,0 +1,29 @@
1
+ <|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
2
+ 今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
3
+
4
+ 为了让你快速了解,我整理了一个简要的表格:
5
+
6
+ | 维度 | 热海大滚锅的地位与特点 |
7
+ | :--- | :--- |
8
+ | **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
9
+ | **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
10
+ | **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
11
+ | **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
12
+
13
+ ### 💡 游玩攻略与温馨提示
14
+
15
+ 如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
16
+
17
+ - **门票与开放时间**:
18
+ - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
19
+ - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
20
+
21
+ - **特色体验**:
22
+ - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
23
+ - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
24
+
25
+ - **注意事项**:
26
+ - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
27
+ - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
28
+
29
+ 希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
generation_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "do_sample": true,
6
+ "temperature": 1.0,
7
+ "top_p": 1.0,
8
+ "transformers_version": "4.46.3"
9
+ }
inference/README.md ADDED
@@ -0,0 +1,25 @@
1
+ # Inference code for DeepSeek models
2
+
3
+ First, convert the Hugging Face model weight files to the format used by this project.
4
+ ```bash
5
+ export EXPERTS=384
6
+ export MP=8
7
+ export CONFIG=config_w4a16.json
8
+ python convert_w4a16.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
9
+ ```
10
+
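+ The example above targets the W4A16 checkpoint; `convert.py` handles the FP8/FP4 checkpoints with the same flags plus an `--expert-dtype` choice. A sketch of that invocation, paired with `config.json`:
+ ```bash
+ python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP} --expert-dtype fp8
+ ```
+ 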
11
+ Then chat with the DeepSeek model at will!
12
+ ```bash
13
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
14
+ ```
15
+
16
+ Or run batch inference from a file.
17
+ ```bash
18
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
19
+ ```
20
+
21
+ Or run multi-node inference.
22
+ ```bash
23
+ torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
24
+ ```
25
+
inference/config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "vocab_size": 129280,
3
+ "dim": 7168,
4
+ "moe_inter_dim": 3072,
5
+ "n_layers": 61,
6
+ "n_hash_layers": 3,
7
+ "n_heads": 128,
8
+ "n_routed_experts": 384,
9
+ "n_shared_experts": 1,
10
+ "n_activated_experts": 6,
11
+ "score_func": "sqrtsoftplus",
12
+ "route_scale": 2.5,
13
+ "swiglu_limit": 10.0,
14
+ "q_lora_rank": 1536,
15
+ "head_dim": 512,
16
+ "rope_head_dim": 64,
17
+ "o_groups": 16,
18
+ "o_lora_rank": 1024,
19
+ "window_size": 128,
20
+ "original_seq_len": 65536,
21
+ "rope_theta": 10000,
22
+ "rope_factor": 16,
23
+ "beta_fast": 32,
24
+ "beta_slow": 1,
25
+ "index_n_heads": 64,
26
+ "index_head_dim": 128,
27
+ "index_topk": 1024,
28
+ "hc_mult": 4,
29
+ "hc_sinkhorn_iters": 20,
30
+ "dtype": "fp8",
31
+ "scale_fmt": "ue8m0",
32
+ "expert_dtype": "fp4",
33
+ "compress_rope_theta": 160000,
34
+ "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
35
+ }
inference/config_w4a16.json ADDED
@@ -0,0 +1,34 @@
1
+
2
+ {
3
+ "vocab_size": 129280,
4
+ "dim": 7168,
5
+ "moe_inter_dim": 3072,
6
+ "n_layers": 61,
7
+ "n_hash_layers": 3,
8
+ "n_heads": 128,
9
+ "n_routed_experts": 384,
10
+ "n_shared_experts": 1,
11
+ "n_activated_experts": 6,
12
+ "score_func": "sqrtsoftplus",
13
+ "route_scale": 2.5,
14
+ "swiglu_limit": 10.0,
15
+ "q_lora_rank": 1536,
16
+ "head_dim": 512,
17
+ "rope_head_dim": 64,
18
+ "o_groups": 16,
19
+ "o_lora_rank": 1024,
20
+ "window_size": 128,
21
+ "original_seq_len": 65536,
22
+ "rope_theta": 10000,
23
+ "rope_factor": 16,
24
+ "beta_fast": 32,
25
+ "beta_slow": 1,
26
+ "index_n_heads": 64,
27
+ "index_head_dim": 128,
28
+ "index_topk": 1024,
29
+ "hc_mult": 4,
30
+ "hc_sinkhorn_iters": 20,
31
+ "dtype": "w4a16",
32
+ "compress_rope_theta": 160000,
33
+ "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
34
+ }
inference/convert.py ADDED
@@ -0,0 +1,168 @@
1
+ import os
2
+ import shutil
3
+ from argparse import ArgumentParser
4
+ from glob import glob
5
+ from tqdm import tqdm, trange
6
+
7
+ import torch
8
+ from safetensors.torch import safe_open, save_file
9
+
10
+
11
+ FP4_TABLE = torch.tensor([
12
+ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
13
+ 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
14
+ ], dtype=torch.float32)
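+ # E2M1 nibble layout: the low 3 bits of each 4-bit code index the magnitude
+ # {0, 0.5, 1, 1.5, 2, 3, 4, 6} and bit 3 is the sign, hence the mirrored
+ # negative half of the table.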
15
+
16
+
17
+ def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
18
+ """
19
+ Casts a tensor from e2m1fn to e4m3fn losslessly.
20
+ """
21
+ assert x.dtype == torch.int8
22
+ assert x.ndim == 2
23
+ out_dim, in_dim = x.size()
24
+ in_dim *= 2
25
+ fp8_block_size = 128
26
+ fp4_block_size = 32
27
+ assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
28
+ assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
29
+
30
+ x = x.view(torch.uint8)
31
+ low = x & 0x0F
32
+ high = (x >> 4) & 0x0F
33
+ x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
34
+
35
+ # max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
36
+ # 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
37
+ MAX_OFFSET_BITS = 6
38
+
39
+ bOut = out_dim // fp8_block_size
40
+ bIn = in_dim // fp8_block_size
41
+ # bOut, bIn, 128, 128
42
+ x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
43
+ # bOut, bIn, 128*4
44
+ scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
45
+ # bOut, bIn, 1
46
+ scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
47
+ # bOut, bIn, 128*4
48
+ offset = scale / scale_max_offset_bits
49
+ # bOut, bIn, 128, 128
50
+ offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
51
+ x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
52
+ return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
53
+
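+ # Round-trip sketch (illustrative, not executed here): with
+ #   y, s2 = cast_e2m1fn_to_e4m3fn(packed, scale)
+ # dequantising y by its 128x128 block scales recovers the original weights:
+ #   y.float().view(out//128, 128, in//128, 128) * s2.float()[:, None, :, None]
+ # equals FP4_TABLE[codes] times the original 1x32 scales, since y = fp4 * offset
+ # and offset = scale / s2.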
54
+
55
+ mapping = {
56
+ "embed_tokens": ("embed", 0),
57
+ "input_layernorm": ("attn_norm", None),
58
+ "post_attention_layernorm": ("ffn_norm", None),
59
+ "q_proj": ("wq", 0),
60
+ "q_a_proj": ("wq_a", None),
61
+ "q_a_layernorm": ("q_norm", None),
62
+ "q_b_proj": ("wq_b", 0),
63
+ "kv_a_proj_with_mqa": ("wkv_a", None),
64
+ "kv_a_layernorm": ("kv_norm", None),
65
+ "kv_b_proj": ("wkv_b", 0),
66
+ "o_proj": ("wo", 1),
67
+ "gate_proj": ("w1", 0),
68
+ "down_proj": ("w2", 1),
69
+ "up_proj": ("w3", 0),
70
+ "lm_head": ("head", 0),
71
+
72
+ "embed": ("embed", 0),
73
+ "wq_b": ("wq_b", 0),
74
+ "wo_a": ("wo_a", 0),
75
+ "wo_b": ("wo_b", 1),
76
+ "head": ("head", 0),
77
+ "attn_sink": ("attn_sink", 0),
78
+ "weights_proj": ("weights_proj", 0),
79
+ }
80
+
81
+
82
+ def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
83
+ """
84
+ Converts and saves model checkpoint files into a specified format.
85
+
86
+ Args:
87
+ hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
88
+ save_path (str): Path to the directory where the converted checkpoint files will be saved.
89
+ n_experts (int): Total number of experts in the model.
90
+ mp (int): Model parallelism factor.
+ expert_dtype (str): Routed-expert weight format; "fp8" unpacks the packed FP4 experts to e4m3fn, any other value keeps them as packed FP4.
91
+
92
+ Returns:
93
+ None
94
+ """
95
+ torch.set_num_threads(8)
96
+ n_local_experts = n_experts // mp
97
+ state_dicts = [{} for _ in range(mp)]
98
+
99
+ for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
100
+ with safe_open(file_path, framework="pt", device="cpu") as f:
101
+ for name in f.keys():
102
+ param: torch.Tensor = f.get_tensor(name)
103
+ if name.startswith("model."):
104
+ name = name[len("model."):]
105
+ if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
106
+ continue
107
+ name = name.replace("self_attn", "attn")
108
+ name = name.replace("mlp", "ffn")
109
+ name = name.replace("weight_scale_inv", "scale")
110
+ name = name.replace("e_score_correction_bias", "bias")
111
+ if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
112
+ key = name.split(".")[-1]
113
+ else:
114
+ key = name.split(".")[-2]
115
+ if key in mapping:
116
+ new_key, dim = mapping[key]
117
+ else:
118
+ new_key, dim = key, None
119
+ name = name.replace(key, new_key)
120
+ for i in range(mp):
121
+ new_param = param
122
+ if "experts" in name and "shared_experts" not in name:
123
+ idx = int(name.split(".")[-3])
124
+ if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
125
+ continue
126
+ elif dim is not None:
127
+ assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
128
+ shard_size = param.size(dim) // mp
129
+ new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
130
+ state_dicts[i][name] = new_param
131
+
132
+ os.makedirs(save_path, exist_ok=True)
133
+
134
+ for i in trange(mp):
135
+ names = list(state_dicts[i].keys())
136
+ for name in names:
137
+ if name.endswith("wo_a.weight"):
138
+ weight = state_dicts[i][name]
139
+ scale = state_dicts[i].pop(name.replace("weight", "scale"))
140
+ weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
141
+ state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
142
+ elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
143
+ if expert_dtype == "fp8":
144
+ scale_name = name.replace("weight", "scale")
145
+ weight = state_dicts[i].pop(name)
146
+ scale = state_dicts[i].pop(scale_name)
147
+ state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
148
+ else:
149
+ state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
150
+ save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
151
+
152
+ for file in ["tokenizer.json", "tokenizer_config.json"]:
153
+ old_file_path = os.path.join(hf_ckpt_path, file)
154
+ new_file_path = os.path.join(save_path, file)
155
+ if os.path.exists(old_file_path):
156
+ shutil.copyfile(old_file_path, new_file_path)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = ArgumentParser()
161
+ parser.add_argument("--hf-ckpt-path", type=str, required=True)
162
+ parser.add_argument("--save-path", type=str, required=True)
163
+ parser.add_argument("--n-experts", type=int, required=True)
164
+ parser.add_argument("--model-parallel", type=int, required=True)
165
+ parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
166
+ args = parser.parse_args()
167
+ assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
168
+ main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
inference/convert_w4a16.py ADDED
@@ -0,0 +1,246 @@
1
+ """Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4
2
+ into the MP-sharded local format consumed by `model.py`/`generate.py`.
3
+
4
+ Packing convention (auto-round → auto_gptq):
5
+ - qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0
6
+ - qzeros : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1
7
+ - scales : fp16 [in_features // group_size, out_features]
8
+
9
+ Sharding rules per linear:
10
+ - ColumnParallel (shard output dim, original `dim=0` in `mapping`):
11
+ qweight along dim 1; qzeros along dim 1 (must be divisible by 8 first, then by world_size);
12
+ scales along dim 1.
13
+ - RowParallel (shard input dim, original `dim=1` in `mapping`):
14
+ qweight along dim 0 (must be divisible by 8 first, then by world_size);
15
+ qzeros along dim 0 (must be divisible by group_size first, then by world_size);
16
+ scales along dim 0.
17
+
18
+ Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias,
19
+ gate.tid2eid, etc.) follow the same rules as the original `convert.py`.
20
+ """
21
+
22
+ import os
23
+ import shutil
24
+ from argparse import ArgumentParser
25
+ from glob import glob
26
+ from tqdm import tqdm, trange
27
+
28
+ import torch
29
+ from safetensors.torch import safe_open, save_file
30
+
31
+
32
+ GROUP_SIZE = 128
33
+
34
+ # Same name remapping as the original convert.py
35
+ mapping = {
36
+ "embed_tokens": ("embed", 0),
37
+ "input_layernorm": ("attn_norm", None),
38
+ "post_attention_layernorm": ("ffn_norm", None),
39
+ "q_proj": ("wq", 0),
40
+ "q_a_proj": ("wq_a", None),
41
+ "q_a_layernorm": ("q_norm", None),
42
+ "q_b_proj": ("wq_b", 0),
43
+ "kv_a_proj_with_mqa": ("wkv_a", None),
44
+ "kv_a_layernorm": ("kv_norm", None),
45
+ "kv_b_proj": ("wkv_b", 0),
46
+ "o_proj": ("wo", 1),
47
+ "gate_proj": ("w1", 0),
48
+ "down_proj": ("w2", 1),
49
+ "up_proj": ("w3", 0),
50
+ "lm_head": ("head", 0),
51
+
52
+ # Already-translated names (used by the inference checkpoints we already have)
53
+ "embed": ("embed", 0),
54
+ "wq_a": ("wq_a", None),
55
+ "wq_b": ("wq_b", 0),
56
+ "wkv": ("wkv", None),
57
+ "wo_a": ("wo_a", 0),
58
+ "wo_b": ("wo_b", 1),
59
+ "w1": ("w1", 0),
60
+ "w2": ("w2", 1),
61
+ "w3": ("w3", 0),
62
+ "head": ("head", 0),
63
+ "weights_proj": ("weights_proj", 0),
64
+ # special non-weight keys
65
+ "attn_sink": ("attn_sink", 0),
66
+ "ape": ("ape", None),
67
+ # NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a
68
+ # plain nn.Parameter that is replicated on every rank.
69
+ }
70
+
71
+
72
+ # Suffixes that mark the three pieces of a packed W4A16 linear.
73
+ QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales")
74
+
75
+
76
+ def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
77
+ shard_dim: int, mp: int):
78
+ """Yield (qweight_i, qzeros_i, scales_i) for i in range(mp).
79
+
80
+ shard_dim is the *logical* dim of the dequantised weight: 0 == output (column parallel),
81
+ 1 == input (row parallel)."""
82
+ out = qweight.size(1)
83
+ in_packed = qweight.size(0) # in_features // 8
84
+ n_groups = scales.size(0) # in_features // group_size
85
+
86
+ if shard_dim == 0: # ColumnParallel: shard along OUTPUT
87
+ assert out % mp == 0, f"out={out} not divisible by mp={mp}"
88
+ # qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0
89
+ assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)"
90
+ sh_out = out // mp
91
+ sh_qz_cols = qzeros.size(1) // mp # == out / 8 / mp
92
+ for i in range(mp):
93
+ yield (
94
+ qweight.narrow(1, i * sh_out, sh_out).contiguous(),
95
+ qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(),
96
+ scales.narrow(1, i * sh_out, sh_out).contiguous(),
97
+ )
98
+ elif shard_dim == 1: # RowParallel: shard along INPUT
99
+ # qweight packs 8 inputs per int32 in dim 0, scales/qzeros are per-group on dim 0
100
+ assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}"
101
+ assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}"
102
+ sh_in_packed = in_packed // mp
103
+ sh_groups = n_groups // mp
104
+ for i in range(mp):
105
+ yield (
106
+ qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(),
107
+ qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(),
108
+ scales.narrow(0, i * sh_groups, sh_groups).contiguous(),
109
+ )
110
+ else:
111
+ # Replicate
112
+ for _ in range(mp):
113
+ yield qweight, qzeros, scales
114
+
115
+
116
+ def get_layer_key(name: str):
117
+ """Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping."""
118
+ parts = name.split(".")
119
+ if name.endswith(QUANT_SUFFIXES):
120
+ return parts[-2] # ...x.qweight -> x
121
+ if name.endswith(".bias") and "gate" in name:
122
+ return "gate" # ffn.gate.bias
123
+ if name.endswith(".tid2eid"):
124
+ return "gate"
125
+ if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale",
126
+ "hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale",
127
+ "hc_head_fn", "hc_head_base", "hc_head_scale",
128
+ "attn_sink", "ape")):
129
+ return parts[-1]
130
+ return parts[-2]
131
+
132
+
133
+ def main(hf_ckpt_path, save_path, n_experts, mp):
134
+ torch.set_num_threads(8)
135
+ n_local_experts = n_experts // mp
136
+ state_dicts = [{} for _ in range(mp)]
137
+
138
+ # Group all fragments belonging to the same logical linear so we can shard
139
+ # qweight/qzeros/scales together.
140
+ pending: dict[str, dict[str, torch.Tensor]] = {}
141
+
142
+ def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim):
143
+ """Distribute a quantised linear (3 tensors) across `mp` shards."""
144
+ qweight = parts["qweight"]
145
+ qzeros = parts["qzeros"]
146
+ scales = parts["scales"].to(torch.bfloat16) # store bf16 instead of fp16
147
+ # Expert-local pruning: only the rank that owns this expert keeps the tensors.
148
+ if "experts" in base_name and "shared_experts" not in base_name:
149
+ idx = int(base_name.split(".experts.")[1].split(".")[0])
150
+ owner = idx // n_local_experts
151
+ state_dicts[owner][base_name + ".qweight"] = qweight
152
+ state_dicts[owner][base_name + ".qzeros"] = qzeros
153
+ state_dicts[owner][base_name + ".scales"] = scales
154
+ return
155
+ if shard_dim is None:
156
+ # Replicate across all ranks
157
+ for i in range(mp):
158
+ state_dicts[i][base_name + ".qweight"] = qweight
159
+ state_dicts[i][base_name + ".qzeros"] = qzeros
160
+ state_dicts[i][base_name + ".scales"] = scales
161
+ else:
162
+ for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)):
163
+ state_dicts[i][base_name + ".qweight"] = qw
164
+ state_dicts[i][base_name + ".qzeros"] = qz
165
+ state_dicts[i][base_name + ".scales"] = sc
166
+
167
+ files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
168
+ for file_path in tqdm(files, desc="files"):
169
+ with safe_open(file_path, framework="pt", device="cpu") as f:
170
+ for orig_name in f.keys():
171
+ # ----- name remapping (mirrors original convert.py) -----
172
+ name = orig_name
173
+ if name.startswith("model."):
174
+ name = name[len("model."):]
175
+ if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
176
+ continue
177
+ name = name.replace("self_attn", "attn")
178
+ name = name.replace("mlp", "ffn")
179
+ name = name.replace("weight_scale_inv", "scale")
180
+ name = name.replace("e_score_correction_bias", "bias")
181
+
182
+ key = get_layer_key(name)
183
+ if key in mapping:
184
+ new_key, dim = mapping[key]
185
+ name = name.replace(key, new_key)
186
+ else:
187
+ dim = None
188
+
189
+ tensor = f.get_tensor(orig_name)
190
+
191
+ # ----- handle the three-piece quantised linear -----
192
+ # `shared_experts` are plain (non-parallel) Linears in the model;
193
+ # never shard them even though `w1/w2/w3` are in the mapping.
194
+ if "shared_experts" in name:
195
+ dim = None
196
+
197
+ if orig_name.endswith(QUANT_SUFFIXES):
198
+ base = name.rsplit(".", 1)[0]
199
+ suf = name.rsplit(".", 1)[1] # qweight|qzeros|scales
200
+ pending.setdefault(base, {"_dim": dim})[suf] = tensor
201
+ pending[base]["_dim"] = dim
202
+ parts = pending[base]
203
+ if all(s in parts for s in ("qweight", "qzeros", "scales")):
204
+ emit_linear(base, parts, parts["_dim"])
205
+ del pending[base]
206
+ continue
207
+
208
+ # ----- non-quantised tensor -----
209
+ if "experts" in name and "shared_experts" not in name:
210
+ idx = int(name.split(".experts.")[1].split(".")[0])
211
+ owner = idx // n_local_experts
212
+ state_dicts[owner][name] = tensor
213
+ continue
214
+
215
+ if dim is None:
216
+ for i in range(mp):
217
+ state_dicts[i][name] = tensor
218
+ else:
219
+ assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}"
220
+ sh = tensor.size(dim) // mp
221
+ for i in range(mp):
222
+ state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous()
223
+
224
+ if pending:
225
+ raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}")
226
+
227
+ os.makedirs(save_path, exist_ok=True)
228
+ for i in trange(mp, desc="write shards"):
229
+ save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
230
+
231
+ for fn in ["tokenizer.json", "tokenizer_config.json"]:
232
+ src = os.path.join(hf_ckpt_path, fn)
233
+ dst = os.path.join(save_path, fn)
234
+ if os.path.exists(src):
235
+ shutil.copyfile(src, dst)
236
+
237
+
238
+ if __name__ == "__main__":
239
+ p = ArgumentParser()
240
+ p.add_argument("--hf-ckpt-path", required=True)
241
+ p.add_argument("--save-path", required=True)
242
+ p.add_argument("--n-experts", type=int, required=True)
243
+ p.add_argument("--model-parallel", type=int, required=True)
244
+ a = p.parse_args()
245
+ assert a.n_experts % a.model_parallel == 0
246
+ main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)
inference/generate.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ # Prevent gptqmodel from setting CUDA_DEVICE_ORDER=PCI_BUS_ID (breaks multi-GPU on some systems)
3
+ os.environ.setdefault("CUDA_DEVICE_ORDER", "FASTEST_FIRST")
4
+ import json
5
+ import sys
6
+ from argparse import ArgumentParser
7
+ from typing import List
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ from transformers import AutoTokenizer
12
+ from safetensors.torch import load_model
13
+
14
+ from model import Transformer, ModelArgs
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+ encoding_dir = os.path.join(current_dir, '../encoding')
17
+ sys.path.insert(0, os.path.abspath(encoding_dir))
18
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
19
+
20
+
21
+ def sample(logits, temperature: float = 1.0):
22
+ """Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
23
+ since it avoids the GPU-to-CPU sync in torch.multinomial."""
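+ # argmax_i p_i / E_i with i.i.d. E_i ~ Exp(1) picks index i with probability p_i:
+ # the "exponential race" form of the Gumbel-max trick.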
24
+ logits = logits / max(temperature, 1e-5)
25
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
26
+ return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
27
+
28
+
29
+ @torch.inference_mode()
30
+ def generate(
31
+ model: Transformer,
32
+ prompt_tokens: List[List[int]],
33
+ max_new_tokens: int,
34
+ eos_id: int,
35
+ temperature: float = 1.0
36
+ ) -> List[List[int]]:
37
+ """Batch generation with left-padded prompts.
38
+
39
+ The first forward pass processes [min_prompt_len:] tokens (prefill phase).
40
+ Subsequent passes generate one token at a time (decode phase). For positions
41
+ still within a prompt, the ground-truth token overrides the model's prediction.
42
+ """
43
+ prompt_lens = [len(t) for t in prompt_tokens]
44
+ assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
45
+ total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
46
+ tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
47
+ for i, t in enumerate(prompt_tokens):
48
+ tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
49
+ prev_pos = 0
50
+ finished = torch.tensor([False] * len(prompt_tokens))
51
+ prompt_mask = tokens != -1
52
+ for cur_pos in range(min(prompt_lens), total_len):
53
+ logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
54
+ if temperature > 0:
55
+ next_token = sample(logits, temperature)
56
+ else:
57
+ next_token = logits.argmax(dim=-1)
58
+ next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
59
+ tokens[:, cur_pos] = next_token
60
+ finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
61
+ prev_pos = cur_pos
62
+ if finished.all():
63
+ break
64
+ completion_tokens = []
65
+ for i, toks in enumerate(tokens.tolist()):
66
+ toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
67
+ if eos_id in toks:
68
+ toks = toks[:toks.index(eos_id)]
69
+ toks.append(eos_id)
70
+ completion_tokens.append(toks)
71
+ return completion_tokens
72
+
73
+
74
+ def main(
75
+ ckpt_path: str,
76
+ config: str,
77
+ input_file: str = "",
78
+ interactive: bool = True,
79
+ max_new_tokens: int = 100,
80
+ temperature: float = 1.0,
81
+ ) -> None:
82
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
83
+ rank = int(os.getenv("RANK", "0"))
84
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
85
+ if world_size > 1:
86
+ dist.init_process_group("nccl")
87
+ global print
88
+ if rank != 0:
89
+ print = lambda *_, **__: None
90
+ torch.cuda.set_device(local_rank)
91
+ torch.cuda.memory._set_allocator_settings("expandable_segments:True")
92
+ torch.set_default_dtype(torch.bfloat16)
93
+ torch.set_num_threads(8)
94
+ torch.manual_seed(33377335)
95
+ with open(config) as f:
96
+ args = ModelArgs(**json.load(f))
97
+ if interactive:
98
+ args.max_batch_size = 1
99
+ print(args)
100
+ with torch.device("cuda"):
101
+ model = Transformer(args)
102
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
103
+ print("load model")
104
+ load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
105
+ if args.dtype == "w4a16":
106
+ model.init_woq_layers()
107
+ torch.set_default_device("cuda")
108
+ print("I'm DeepSeek 👋")
109
+
110
+ if interactive:
111
+ messages = []
112
+ while True:
113
+ if world_size == 1:
114
+ prompt = input(">>> ")
115
+ elif rank == 0:
116
+ prompt = input(">>> ")
117
+ objects = [prompt]
118
+ dist.broadcast_object_list(objects, 0)
119
+ else:
120
+ objects = [None]
121
+ dist.broadcast_object_list(objects, 0)
122
+ prompt = objects[0]
123
+ if prompt == "/exit":
124
+ break
125
+ elif prompt == "/clear":
126
+ messages.clear()
127
+ continue
128
+ messages.append({"role": "user", "content": prompt})
129
+ prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
130
+ completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
131
+ completion = tokenizer.decode(completion_tokens[0])
132
+ print(completion)
133
+ messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
134
+ else:
135
+ with open(input_file) as f:
136
+ prompts = f.read().split("\n\n")
137
+ prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
138
+ completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
139
+ completions = tokenizer.batch_decode(completion_tokens)
140
+ for prompt, completion in zip(prompts, completions):
141
+ print("Prompt:", prompt)
142
+ print("Completion:", completion)
143
+ print()
144
+
145
+ if world_size > 1:
146
+ dist.destroy_process_group()
147
+
148
+
149
+ if __name__ == "__main__":
150
+ parser = ArgumentParser()
151
+ parser.add_argument("--ckpt-path", type=str, required=True)
152
+ parser.add_argument("--config", type=str, required=True)
153
+ parser.add_argument("--input-file", type=str, default="")
154
+ parser.add_argument("--interactive", action="store_true")
155
+ parser.add_argument("--max-new-tokens", type=int, default=300)
156
+ parser.add_argument("--temperature", type=float, default=0.6)
157
+ args = parser.parse_args()
158
+ assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
159
+ main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
inference/kernel.py ADDED
@@ -0,0 +1,536 @@
1
+ import torch
2
+ import tilelang
3
+ import tilelang.language as T
4
+ from typing import Tuple, Optional
5
+
6
+
7
+ tilelang.set_log_level("WARNING")
8
+
9
+ pass_configs = {
10
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
11
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
12
+ }
13
+
14
+ FP8 = "float8_e4m3"
15
+ FP4 = "float4_e2m1fn"
16
+ FE8M0 = "float8_e8m0fnu"
17
+ BF16 = "bfloat16"
18
+ FP32 = "float32"
19
+ INT32 = "int32"
20
+
21
+
22
+ def fast_log2_ceil(x):
23
+ """Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
24
+ bits_x = T.reinterpret("uint32", x)
25
+ exp_x = (bits_x >> 23) & 0xFF
26
+ man_bits = bits_x & ((1 << 23) - 1)
27
+ return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
28
+
29
+
30
+ def fast_pow2(x):
31
+ """Compute 2^x for integer x via IEEE 754 bit manipulation."""
32
+ bits_x = (x + 127) << 23
33
+ return T.reinterpret("float32", bits_x)
34
+
35
+
36
+ def fast_round_scale(amax, fp8_max_inv):
37
+ return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
38
+
39
+
40
+ @tilelang.jit(pass_configs=pass_configs)
41
+ def act_quant_kernel(
42
+ N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
43
+ round_scale=False, inplace=False
44
+ ):
45
+ """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
46
+ M = T.symbolic("M")
47
+ fp8_min = -448.0
48
+ fp8_max = 448.0
49
+ fp8_max_inv = 1 / fp8_max
50
+ num_stages = 0 if round_scale or inplace else 2
51
+ blk_m = 32
52
+ group_size = block_size
53
+ # Internal computation in FP32; scale_dtype controls output storage format.
54
+ compute_dtype = FP32
55
+ out_dtype = in_dtype if inplace else out_dtype
56
+
57
+ @T.prim_func
58
+ def act_quant_kernel_(
59
+ X: T.Tensor[(M, N), in_dtype],
60
+ Y: T.Tensor[(M, N), out_dtype],
61
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
62
+ ):
63
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
64
+ pid_m,
65
+ pid_n,
66
+ ):
67
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
68
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
69
+ amax_local = T.alloc_fragment((blk_m,), compute_dtype)
70
+ s_local = T.alloc_fragment((blk_m,), compute_dtype)
71
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
72
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
73
+
74
+ for _ in T.Pipelined(1, num_stages=num_stages):
75
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
76
+ T.copy(x_shared, x_local)
77
+ T.reduce_absmax(x_local, amax_local, dim=1)
78
+ for i in T.Parallel(blk_m):
79
+ amax_local[i] = T.max(amax_local[i], 1e-4)
80
+ if round_scale:
81
+ s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
82
+ else:
83
+ s_local[i] = amax_local[i] * fp8_max_inv
84
+ if inplace:
85
+ for i, j in T.Parallel(blk_m, group_size):
86
+ y_local[i, j] = T.Cast(
87
+ out_dtype,
88
+ T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
89
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
90
+ ))) * s_local[i],
91
+ )
92
+ else:
93
+ for i, j in T.Parallel(blk_m, group_size):
94
+ y_local[i, j] = T.clamp(
95
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
96
+ )
97
+ for i in T.Parallel(blk_m):
98
+ S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
99
+ T.copy(y_local, y_shared)
100
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
101
+
102
+ return act_quant_kernel_
103
+
104
+
105
+ def act_quant(
106
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
107
+ scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
108
+ ) -> torch.Tensor:
109
+ """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
110
+ When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
111
+ N = x.size(-1)
112
+ assert N % block_size == 0
113
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
114
+ z = x.contiguous()
115
+ y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
116
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
117
+ kernel = act_quant_kernel(
118
+ N, block_size, scale_dtype=tl_dtype,
119
+ round_scale=scale_fmt is not None, inplace=inplace,
120
+ )
121
+ kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
122
+ if inplace:
123
+ x.copy_(y)
124
+ return x
125
+ return y, s
126
+
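+ # Usage sketch (assumes a CUDA device and a bf16 input whose last dim is a
+ # multiple of block_size):
+ #   y, s = act_quant(torch.randn(4, 256, dtype=torch.bfloat16, device="cuda"))
+ #   x_hat = y.float().view(4, 2, 128) * s.float().view(4, 2, 1)  # dequantise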
127
+
128
+ @tilelang.jit(pass_configs=pass_configs)
129
+ def fp4_quant_kernel(
130
+ N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
131
+ ):
132
+ """Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
133
+ M = T.symbolic("M")
134
+ fp4_max = 6.0
135
+ fp4_max_inv = 1.0 / fp4_max
136
+ blk_m = 32
137
+ group_size = block_size
138
+ compute_dtype = FP32
139
+ out_dtype = in_dtype if inplace else FP4
140
+
141
+ @T.prim_func
142
+ def fp4_quant_kernel_(
143
+ X: T.Tensor[(M, N), in_dtype],
144
+ Y: T.Tensor[(M, N), out_dtype],
145
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
146
+ ):
147
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
148
+ pid_m,
149
+ pid_n,
150
+ ):
151
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
152
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
153
+ amax_local = T.alloc_fragment((blk_m,), compute_dtype)
154
+ s_local = T.alloc_fragment((blk_m,), compute_dtype)
155
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
156
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
157
+
158
+ for _ in T.Pipelined(1, num_stages=2):
159
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
160
+ T.copy(x_shared, x_local)
161
+ T.reduce_absmax(x_local, amax_local, dim=1)
162
+ for i in T.Parallel(blk_m):
163
+ amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
164
+ s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
165
+ if inplace:
166
+ for i, j in T.Parallel(blk_m, group_size):
167
+ y_local[i, j] = T.Cast(
168
+ out_dtype,
169
+ T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
170
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
171
+ ))) * s_local[i],
172
+ )
173
+ else:
174
+ for i, j in T.Parallel(blk_m, group_size):
175
+ y_local[i, j] = T.clamp(
176
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
177
+ )
178
+ for i in T.Parallel(blk_m):
179
+ S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
180
+ T.copy(y_local, y_shared)
181
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
182
+
183
+ return fp4_quant_kernel_
184
+
185
+
186
+ def fp4_act_quant(
187
+ x: torch.Tensor, block_size: int = 32, inplace: bool = False,
188
+ ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
189
+ """Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
190
+ N = x.size(-1)
191
+ assert N % block_size == 0
192
+ z = x.contiguous()
193
+ y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
194
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
195
+ kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
196
+ kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
197
+ if inplace:
198
+ x.copy_(y)
199
+ return x
200
+ return y, s
201
+
202
+
203
+ @tilelang.jit(pass_configs=pass_configs)
204
+ def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
205
+ assert out_dtype in [BF16, FP32]
206
+
207
+ M = T.symbolic("M")
208
+ group_size = 128
209
+ block_M = 32
210
+ block_N = 128
211
+ block_K = 128
212
+
213
+ @T.prim_func
214
+ def fp8_gemm_kernel_(
215
+ A: T.Tensor[(M, K), FP8],
216
+ B: T.Tensor[(N, K), FP8],
217
+ C: T.Tensor[(M, N), out_dtype],
218
+ scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
219
+ scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
220
+ ):
221
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
222
+ bx,
223
+ by,
224
+ ):
225
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
226
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
227
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
228
+ Scale_C_shared = T.alloc_shared((block_M), FP32)
229
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
230
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
231
+
232
+ # Improve L2 Cache
233
+ T.use_swizzle(panel_size=10)
234
+ T.clear(C_local)
235
+ T.clear(C_local_accum)
236
+
237
+ K_iters = T.ceildiv(K, block_K)
238
+ for k in T.Pipelined(K_iters, num_stages=4):
239
+ T.copy(A[by * block_M, k * block_K], A_shared)
240
+ T.copy(B[bx * block_N, k * block_K], B_shared)
241
+ # Cast scales to FP32 for computation; scales_b has one value per block_N group
242
+ Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
243
+ for i in T.Parallel(block_M):
244
+ Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
245
+
246
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
247
+ # Separate accumulator for scale-corrected results (2x accumulation precision)
248
+ for i, j in T.Parallel(block_M, block_N):
249
+ C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
250
+ T.clear(C_local)
251
+ T.copy(C_local_accum, C_shared)
252
+ T.copy(C_shared, C[by * block_M, bx * block_N])
253
+
254
+ return fp8_gemm_kernel_
255
+
256
+
257
+ def fp8_gemm(
258
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
259
+ scale_dtype: torch.dtype = torch.float32,
260
+ ) -> torch.Tensor:
261
+ """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
262
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
263
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
264
+ "Scaling factor tensors must be contiguous"
265
+ )
266
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
267
+ K = a.size(-1)
268
+ M = a.numel() // K
269
+ N = b.size(0)
270
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
271
+ kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
272
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
273
+ return c
274
+
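+ # Usage sketch (weights quantised offline into 128x128 blocks):
+ #   a, a_s = act_quant(x)          # per-row 1x128 activation scales along K
+ #   c = fp8_gemm(a, a_s, w, w_s)   # w: [N, K] fp8, w_s: [N/128, K/128]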
275
+
276
+ @tilelang.jit(pass_configs=pass_configs)
277
+ def sparse_attn_kernel(h: int, d: int, scale=None):
278
+ """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
279
+ For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
280
+ with numerically stable running max/sum, and includes a learnable attn_sink bias."""
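+ # The sink acts as one extra logit that only enlarges the softmax denominator:
+ #   out = (sum_j exp(s_j - m) * kv_j) / (exp(sink - m) + sum_j exp(s_j - m))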
281
+ b = T.symbolic("b")
282
+ m = T.symbolic("m")
283
+ n = T.symbolic("n")
284
+ topk = T.symbolic("topk")
285
+ if scale is None:
286
+ scale = (1.0 / d) ** 0.5
287
+
288
+ num_stages = 2
289
+ threads = 256
290
+ block = 64
291
+ num_blocks = tilelang.cdiv(topk, block)
292
+
293
+ @T.prim_func
294
+ def sparse_attn_kernel_(
295
+ q: T.Tensor[(b, m, h, d), BF16],
296
+ kv: T.Tensor[(b, n, d), BF16],
297
+ o: T.Tensor[(b, m, h, d), BF16],
298
+ attn_sink: T.Tensor[(h,), FP32],
299
+ topk_idxs: T.Tensor[(b, m, topk), INT32],
300
+ ):
301
+ with T.Kernel(m, b, threads=threads) as (bx, by):
302
+ q_shared = T.alloc_shared((h, d), BF16)
303
+ kv_shared = T.alloc_shared((block, d), BF16)
304
+ o_shared = T.alloc_shared((h, d), BF16)
305
+ acc_s_cast = T.alloc_shared((h, block), BF16)
306
+
307
+ idxs = T.alloc_fragment(block, INT32)
308
+ acc_s = T.alloc_fragment((h, block), FP32)
309
+ acc_o = T.alloc_fragment((h, d), FP32)
310
+ scores_max = T.alloc_fragment(h, FP32)
311
+ scores_max_prev = T.alloc_fragment(h, FP32)
312
+ scores_scale = T.alloc_fragment(h, FP32)
313
+ scores_sum = T.alloc_fragment(h, FP32)
314
+ sum_exp = T.alloc_fragment(h, FP32)
315
+
316
+ T.clear(acc_o)
317
+ T.clear(sum_exp)
318
+ T.fill(scores_max, -T.infinity(FP32))
319
+ T.copy(q[by, bx, :, :], q_shared)
320
+
321
+ for t in T.Pipelined(num_blocks, num_stages=num_stages):
322
+ for i in T.Parallel(block):
323
+ idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
324
+ for i, j in T.Parallel(block, d):
325
+ kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
326
+ for i, j in T.Parallel(h, block):
327
+ acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
328
+ T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
329
+ for i, j in T.Parallel(h, block):
330
+ acc_s[i, j] *= scale
331
+ T.copy(scores_max, scores_max_prev)
332
+ T.reduce_max(acc_s, scores_max, dim=1, clear=False)
333
+ for i in T.Parallel(h):
334
+ scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
335
+ for i, j in T.Parallel(h, block):
336
+ acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
337
+ T.reduce_sum(acc_s, scores_sum, dim=1)
338
+ for i in T.Parallel(h):
339
+ sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
340
+ T.copy(acc_s, acc_s_cast)
341
+ for i, j in T.Parallel(h, d):
342
+ acc_o[i, j] *= scores_scale[i]
343
+ T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
344
+
345
+ for i in T.Parallel(h):
346
+ sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
347
+ for i, j in T.Parallel(h, d):
348
+ acc_o[i, j] /= sum_exp[i]
349
+ T.copy(acc_o, o_shared)
350
+ T.copy(o_shared, o[by, bx, :, :])
351
+
352
+ return sparse_attn_kernel_
353
+
354
+
355
+ def sparse_attn(
356
+ q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
357
+ ) -> torch.Tensor:
358
+ b, s, h, d = q.size()
359
+ # Pad heads to 16 for kernel efficiency (stripped after)
360
+ if h < 16:
361
+ q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
362
+ attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
363
+ o = torch.empty_like(q)
364
+ kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
365
+ kernel(q, kv, o, attn_sink, topk_idxs)
366
+ if h < 16:
367
+ o = o.narrow(2, 0, h).contiguous()
368
+ return o
369
+
370
+
371
+ @tilelang.jit(pass_configs=pass_configs)
372
+ def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
373
+ n = T.symbolic("n")
374
+ mix_hc = (2 + hc) * hc
375
+ threads = 64
376
+
377
+ @T.prim_func
378
+ def hc_split_sinkhorn_kernel_(
379
+ mixes: T.Tensor[(n, mix_hc), FP32],
380
+ hc_scale: T.Tensor[(3,), FP32],
381
+ hc_base: T.Tensor[(mix_hc,), FP32],
382
+ pre: T.Tensor[(n, hc), FP32],
383
+ post: T.Tensor[(n, hc), FP32],
384
+ comb: T.Tensor[(n, hc, hc), FP32],
385
+ ):
386
+ with T.Kernel(n, threads=threads) as i:
387
+ mixes_shared = T.alloc_shared(mix_hc, FP32)
388
+ comb_frag = T.alloc_fragment((hc, hc), FP32)
389
+ T.copy(mixes[i, :], mixes_shared)
390
+
391
+ for j in T.Parallel(hc):
392
+ pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
393
+ for j in T.Parallel(hc):
394
+ post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
395
+ for j, k in T.Parallel(hc, hc):
396
+ comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
397
+
398
+ row_sum = T.alloc_fragment(hc, FP32)
399
+ col_sum = T.alloc_fragment(hc, FP32)
400
+
401
+ # comb = comb.softmax(-1) + eps
402
+ row_max = T.alloc_fragment(hc, FP32)
403
+ T.reduce_max(comb_frag, row_max, dim=1)
404
+ for j, k in T.Parallel(hc, hc):
405
+ comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
406
+ T.reduce_sum(comb_frag, row_sum, dim=1)
407
+ for j, k in T.Parallel(hc, hc):
408
+ comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
409
+
410
+ # comb = comb / (comb.sum(-2) + eps)
411
+ T.reduce_sum(comb_frag, col_sum, dim=0)
412
+ for j, k in T.Parallel(hc, hc):
413
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
414
+
415
+ for _ in T.serial(sinkhorn_iters - 1):
416
+ # comb = comb / (comb.sum(-1) + eps)
417
+ T.reduce_sum(comb_frag, row_sum, dim=1)
418
+ for j, k in T.Parallel(hc, hc):
419
+ comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
420
+ # comb = comb / (comb.sum(-2) + eps)
421
+ T.reduce_sum(comb_frag, col_sum, dim=0)
422
+ for j, k in T.Parallel(hc, hc):
423
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
424
+
425
+ T.copy(comb_frag, comb[i, :, :])
426
+
427
+ return hc_split_sinkhorn_kernel_
428
+
429
+
430
+ def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
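+ """Split the hc mixes into (pre, post, comb); `comb` is driven towards a
+ doubly-stochastic matrix by `sinkhorn_iters` rounds of alternating
+ row/column normalisation (Sinkhorn-Knopp) inside the kernel above."""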
431
+ b, s, _ = mixes.size()
432
+ pre = mixes.new_empty(b, s, hc_mult)
433
+ post = mixes.new_empty(b, s, hc_mult)
434
+ comb = mixes.new_empty(b, s, hc_mult, hc_mult)
435
+ kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
436
+ kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
437
+ pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
438
+ return pre, post, comb
439
+
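+ # Reference semantics in plain PyTorch (a sketch, not the kernel itself):
+ #   comb = comb_logits.softmax(-1) + eps
+ #   comb = comb / (comb.sum(-2, keepdim=True) + eps)
+ #   for _ in range(sinkhorn_iters - 1):
+ #       comb = comb / (comb.sum(-1, keepdim=True) + eps)
+ #       comb = comb / (comb.sum(-2, keepdim=True) + eps)
+ # Alternating row/column normalization pushes comb toward a doubly stochastic
+ # matrix, matching the in-kernel comments above.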
440
+
441
+ @tilelang.jit(pass_configs=pass_configs)
442
+ def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
443
+ """FP8 act x FP4 weight GEMM kernel.
444
+
445
+ C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T
446
+
447
+ Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
448
+ Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale
449
+
450
+ B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
451
+ The FP4 values are packed along the K (last) dimension.
452
+
453
+ Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
454
+ cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
455
+ Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
456
+ """
457
+ M = T.symbolic("M")
458
+ act_group_size = 128
459
+ weight_group_size = 32
460
+ block_M = 32
461
+ block_N = 128
462
+ block_K = 32 # matches weight_group_size for simple scale handling
463
+ n_sub = act_group_size // block_K # 4 sub-blocks per act scale group
464
+
465
+ @T.prim_func
466
+ def fp4_gemm_kernel_(
467
+ A: T.Tensor[(M, K), FP8],
468
+ B: T.Tensor[(N, K), FP4],
469
+ C: T.Tensor[(M, N), out_dtype],
470
+ scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
471
+ scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
472
+ ):
473
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
474
+ bx,
475
+ by,
476
+ ):
477
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
478
+ B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
479
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
480
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
481
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
482
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
483
+ scale_a_frag = T.alloc_fragment((block_M,), FP32)
484
+ scale_b_frag = T.alloc_fragment((block_N,), FP32)
485
+
486
+ T.use_swizzle(panel_size=10)
487
+ T.clear(C_local)
488
+ T.clear(C_local_accum)
489
+
490
+ K_iters = T.ceildiv(K, block_K)
491
+ for k in T.Pipelined(K_iters, num_stages=2):
492
+ T.copy(A[by * block_M, k * block_K], A_shared)
493
+ T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
494
+ # FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
495
+ for i, j in T.Parallel(block_N, block_K):
496
+ B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))
497
+
498
+ # Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
499
+ for i in T.Parallel(block_N):
500
+ scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])
501
+
502
+ # Act scale: per 128 on K, indexed by k // 4
503
+ for i in T.Parallel(block_M):
504
+ scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])
505
+
506
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
507
+
508
+ for i, j in T.Parallel(block_M, block_N):
509
+ C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
510
+ T.clear(C_local)
511
+
512
+ T.copy(C_local_accum, C_shared)
513
+ T.copy(C_shared, C[by * block_M, bx * block_N])
514
+
515
+ return fp4_gemm_kernel_
516
+
517
+
518
+ def fp4_gemm(
519
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
520
+ scale_dtype: torch.dtype = torch.float32,
521
+ ) -> torch.Tensor:
522
+ """C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
523
+ A has per-128 act scale; B has per-32 E8M0 weight scale.
524
+ B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
525
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
526
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
527
+ "Scaling factor tensors must be contiguous"
528
+ )
529
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
530
+ K = a.size(-1)
531
+ M = a.numel() // K
532
+ N = b.size(0)
533
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
534
+ kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
535
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
536
+ return c
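+
+ # Usage sketch (hypothetical sizes; assumes a torch build with FP8/FP4 dtypes):
+ #   a   [M, K]       torch.float8_e4m3fn activations
+ #   a_s [M, K//128]  per-128 activation scales
+ #   b   [N, K//2]    torch.float4_e2m1fn_x2 packed weights (logical [N, K] FP4)
+ #   b_s [N, K//32]   per-32 E8M0 weight scales
+ #   c = fp4_gemm(a, a_s, b, b_s)  # -> [M, N] in torch.get_default_dtype()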
inference/model.py ADDED
@@ -0,0 +1,992 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Optional, Literal
4
+ from functools import lru_cache
5
+ from contextlib import contextmanager
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+
12
+ from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
13
+
14
+
15
+ world_size = 1
16
+ rank = 0
17
+ block_size = 128
18
+ fp4_block_size = 32
19
+ w4a16_group_size = 128
20
+ default_dtype = torch.bfloat16
21
+ scale_fmt = None
22
+ scale_dtype = torch.float32
23
+ w4a16_mode = False # set in Transformer.__init__ when args.dtype == "w4a16"
24
+
25
+
26
+ def dequantize_w4a16(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
27
+ group_size: int = 128) -> torch.Tensor:
28
+ """Auto-round / auto_gptq W4A16 packing -> BF16 weight [out, in].
29
+
30
+ qweight: int32 [in/8, out], LSB-first 4-bit packed along input dim
31
+ qzeros : int32 [in/g, out/8], LSB-first 4-bit packed along output dim
32
+ scales : bf16 [in/g, out]
33
+ """
34
+ in_packed, out_features = qweight.shape
35
+ in_features = in_packed * 8
36
+ n_groups = scales.shape[0]
37
+ device = qweight.device
38
+ shifts = torch.arange(0, 32, 4, device=device, dtype=torch.int32)
39
+ w = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF # [in/8, 8, out]
40
+ w = w.reshape(in_features, out_features).to(torch.float32)
41
+ z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, 8)) & 0xF # [in/g, out/8, 8]
42
+ z = z.reshape(n_groups, out_features).to(torch.float32) + 1.0 # GPTQ stores zero - 1
43
+ s = scales.to(torch.float32)
44
+ w = w.view(n_groups, group_size, out_features)
45
+ deq = (w - z.unsqueeze(1)) * s.unsqueeze(1)
46
+ deq = deq.view(in_features, out_features)
47
+ return deq.t().contiguous().to(torch.bfloat16)
48
+
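+ # Worked shape example (hypothetical sizes): in_features=256, out_features=64,
+ # group_size=128 -> qweight [32, 64] int32, qzeros [2, 8] int32,
+ # scales [2, 64] bf16; returned weight is [64, 256] bf16 ([out, in]).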
49
+
50
+ @contextmanager
51
+ def set_dtype(dtype):
52
+ """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
53
+ prev = torch.get_default_dtype()
54
+ torch.set_default_dtype(dtype)
55
+ try:
56
+ yield
57
+ finally:
58
+ torch.set_default_dtype(prev)
59
+
60
+ @dataclass
61
+ class ModelArgs:
62
+ """Model hyperparameters. Field names match the config JSON keys."""
63
+ max_batch_size: int = 4
64
+ max_seq_len: int = 4096
65
+ dtype: Literal["bf16", "fp8", "w4a16"] = "fp8"
66
+ scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
67
+ expert_dtype: Literal[None, "fp4"] = None
68
+ scale_dtype: Literal["fp32", "fp8"] = "fp8"
69
+ vocab_size: int = 129280
70
+ dim: int = 4096
71
+ moe_inter_dim: int = 4096
72
+ n_layers: int = 7
73
+ n_hash_layers: int = 0
74
+ n_mtp_layers: int = 1
75
+ n_heads: int = 64
76
+ # moe
77
+ n_routed_experts: int = 8
78
+ n_shared_experts: int = 1
79
+ n_activated_experts: int = 2
80
+ score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
81
+ route_scale: float = 1.
82
+ swiglu_limit: float = 0.
83
+ # mqa
84
+ q_lora_rank: int = 1024
85
+ head_dim: int = 512
86
+ rope_head_dim: int = 64
87
+ norm_eps: float = 1e-6
88
+ o_groups: int = 8
89
+ o_lora_rank: int = 1024
90
+ window_size: int = 128
91
+ compress_ratios: Tuple[int, ...] = (0, 0, 4, 128, 4, 128, 4, 0)
92
+ # yarn
93
+ compress_rope_theta: float = 40000.0
94
+ original_seq_len: int = 0
95
+ rope_theta: float = 10000.0
96
+ rope_factor: float = 40
97
+ beta_fast: int = 32
98
+ beta_slow: int = 1
99
+ # index
100
+ index_n_heads: int = 64
101
+ index_head_dim: int = 128
102
+ index_topk: int = 512
103
+ # hc
104
+ hc_mult: int = 4
105
+ hc_sinkhorn_iters: int = 20
106
+ hc_eps: float = 1e-6
107
+
108
+
109
+ class ParallelEmbedding(nn.Module):
110
+ """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
111
+ Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
112
+ def __init__(self, vocab_size: int, dim: int):
113
+ super().__init__()
114
+ self.vocab_size = vocab_size
115
+ self.dim = dim
116
+ assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
117
+ self.part_vocab_size = (vocab_size // world_size)
118
+ self.vocab_start_idx = rank * self.part_vocab_size
119
+ self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
120
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ if world_size > 1:
124
+ mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
125
+ x = x - self.vocab_start_idx
126
+ x[mask] = 0
127
+ y = F.embedding(x, self.weight)
128
+ if world_size > 1:
129
+ y[mask] = 0
130
+ dist.all_reduce(y)
131
+ return y
132
+
133
+
134
+ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
135
+ """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
136
+ For quantized weights, x is first quantized to FP8 via act_quant."""
137
+ assert bias is None
138
+
139
+ if weight.dtype == torch.float4_e2m1fn_x2:
140
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
141
+ return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
142
+ elif weight.dtype == torch.float8_e4m3fn:
143
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
144
+ return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
145
+ else:
146
+ return F.linear(x, weight)
147
+
148
+
149
+ class Linear(nn.Module):
150
+ """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
151
+
152
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
153
+ super().__init__()
154
+ self.in_features = in_features
155
+ self.out_features = out_features
156
+ # In a W4A16 build every Linear becomes W4A16 regardless of the dtype the
157
+ # original FP8/FP4 model wanted. The non-quant special cases (RMSNorm,
158
+ # embed, attn_sink, etc.) are NOT instances of `Linear`, so they are
159
+ # untouched.
160
+ if w4a16_mode:
161
+ dtype = "w4a16"
162
+ else:
163
+ dtype = dtype or default_dtype
164
+ self.is_w4a16 = (dtype == "w4a16")
165
+ if self.is_w4a16:
166
+ assert in_features % 8 == 0 and in_features % w4a16_group_size == 0
167
+ assert out_features % 8 == 0
168
+ self.group_size = w4a16_group_size
169
+ self.qweight = nn.Parameter(
170
+ torch.empty(in_features // 8, out_features, dtype=torch.int32),
171
+ requires_grad=False,
172
+ )
173
+ self.qzeros = nn.Parameter(
174
+ torch.empty(in_features // self.group_size, out_features // 8, dtype=torch.int32),
175
+ requires_grad=False,
176
+ )
177
+ self.scales = nn.Parameter(
178
+ torch.empty(in_features // self.group_size, out_features, dtype=torch.bfloat16),
179
+ requires_grad=False,
180
+ )
181
+ self.register_parameter("weight", None)
182
+ self.register_parameter("scale", None)
183
+ elif dtype == torch.float4_e2m1fn_x2:
184
+ # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
185
+ # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
186
+ self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
187
+ scale_out_features = out_features
188
+ scale_in_features = in_features // fp4_block_size
189
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
190
+ elif dtype == torch.float8_e4m3fn:
191
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
192
+ scale_out_features = (out_features + block_size - 1) // block_size
193
+ scale_in_features = (in_features + block_size - 1) // block_size
194
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
195
+ else:
196
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
197
+ self.register_parameter("scale", None)
198
+ if bias:
199
+ self.bias = nn.Parameter(torch.empty(out_features))
200
+ else:
201
+ self.register_parameter("bias", None)
202
+
203
+ def init_woq(self, QuantLinear):
204
+ """Create a QuantLinear from loaded GPTQ parameters."""
205
+ if not self.is_w4a16:
206
+ return
207
+ # Marlin requires out_features % 64 == 0; fall back to manual dequant
208
+ if self.out_features % 64 != 0:
209
+ self._woq = None
210
+ return
211
+ dev = self.qweight.device
212
+ layer = QuantLinear(
213
+ bits=4, group_size=self.group_size,
214
+ in_features=self.in_features, out_features=self.out_features,
215
+ bias=False, desc_act=False, sym=True, register_buffers=True,
216
+ )
217
+ layer = layer.to(dev)
218
+ layer.qweight.copy_(self.qweight.data)
219
+ layer.qzeros.copy_(self.qzeros.data)
220
+ layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
221
+ layer.g_idx.copy_(torch.arange(self.in_features, dtype=torch.int32, device=dev) // self.group_size)
222
+ layer.post_init()
223
+ self._woq = layer
224
+ # Free original parameters to save memory
225
+ self.qweight = None
226
+ self.qzeros = None
227
+ self.scales = None
228
+
229
+ def get_weight(self) -> torch.Tensor:
230
+ """Return the dequantised BF16 weight [out, in]. For non-W4A16 modes
231
+ returns ``self.weight`` unchanged. Used only for wo_a einsum path."""
232
+ if self.is_w4a16:
233
+ if hasattr(self, '_woq') and self._woq is not None:
234
+ return dequantize_w4a16(self._woq.qweight, self._woq.qzeros, self._woq.scales, self.group_size)
235
+ return dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
236
+ return self.weight
237
+
238
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
239
+ if self.is_w4a16:
240
+ if hasattr(self, '_woq') and self._woq is not None:
241
+ y = self._woq(x.to(torch.bfloat16))
242
+ else:
243
+ w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
244
+ y = F.linear(x.to(w.dtype), w)
245
+ if self.bias is not None:
246
+ y = y + self.bias
247
+ return y.type_as(x)
248
+ return linear(x, self.weight, self.bias)
249
+
250
+
251
+ class ColumnParallelLinear(Linear):
252
+ """Shards output dim across TP ranks. No all-reduce needed on output."""
253
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
254
+ assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
255
+ self.part_out_features = out_features // world_size
256
+ super().__init__(in_features, self.part_out_features, bias, dtype)
257
+
258
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
259
+ if self.is_w4a16:
260
+ return Linear.forward(self, x)
261
+ return linear(x, self.weight, self.bias)
262
+
263
+
264
+ class RowParallelLinear(Linear):
265
+ """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
266
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
267
+ assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
268
+ self.part_in_features = in_features // world_size
269
+ super().__init__(self.part_in_features, out_features, bias, dtype)
270
+
271
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
272
+ if self.is_w4a16:
273
+ if hasattr(self, '_woq') and self._woq is not None:
274
+ y = self._woq(x.to(torch.bfloat16))
275
+ else:
276
+ w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
277
+ y = F.linear(x.to(w.dtype), w)
278
+ else:
279
+ y = linear(x, self.weight, None)
280
+ if world_size > 1:
281
+ y = y.float()
282
+ dist.all_reduce(y)
283
+ if self.bias is not None:
284
+ y += self.bias
285
+ return y.type_as(x)
286
+
287
+
288
+ class RMSNorm(nn.Module):
289
+ def __init__(self, dim: int, eps: float = 1e-6):
290
+ super().__init__()
291
+ self.dim = dim
292
+ self.eps = eps
293
+ # rmsnorm in the checkpoint is stored in bf16, while the parameter here is kept in fp32 for convenience.
294
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
295
+
296
+ def forward(self, x: torch.Tensor):
297
+ dtype = x.dtype
298
+ x = x.float()
299
+ var = x.square().mean(-1, keepdim=True)
300
+ x = x * torch.rsqrt(var + self.eps)
301
+ return (self.weight * x).to(dtype)
302
+
303
+
304
+ @lru_cache(2)
305
+ def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
306
+ """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
307
+ When original_seq_len > 0, applies frequency interpolation with a smooth
308
+ linear ramp between beta_fast and beta_slow correction ranges."""
309
+
310
+ def find_correction_dim(num_rotations, dim, base, max_seq_len):
311
+ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
312
+
313
+ def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
314
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
315
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
316
+ return max(low, 0), min(high, dim-1)
317
+
318
+ def linear_ramp_factor(min, max, dim):
319
+ if min == max:
320
+ max += 0.001
321
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
322
+ ramp_func = torch.clamp(linear_func, 0, 1)
323
+ return ramp_func
324
+
325
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
326
+ if original_seq_len > 0:
327
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
328
+ smooth = 1 - linear_ramp_factor(low, high, dim // 2)
329
+ freqs = freqs / factor * (1 - smooth) + freqs * smooth
330
+
331
+ t = torch.arange(seqlen)
332
+ freqs = torch.outer(t, freqs)
333
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
334
+ return freqs_cis
335
+
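+ # Sketch: for the i-th frequency pair, freqs[i] = base ** (-2i / dim). With
+ # YaRN enabled (original_seq_len > 0), low-frequency dims are interpolated by
+ # 1/factor while high-frequency dims keep their original rotation rate,
+ # blended by the linear ramp between the beta_fast/beta_slow correction dims.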
336
+
337
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
338
+ """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
339
+ y = x
340
+ x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
341
+ if inverse:
342
+ freqs_cis = freqs_cis.conj()
343
+ if x.ndim == 3:
344
+ freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
345
+ else:
346
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
347
+ x = torch.view_as_real(x * freqs_cis).flatten(-2)
348
+ y.copy_(x)
349
+ return y
350
+
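+ # Round-trip sketch: since |freqs_cis| == 1, rotating then de-rotating is the
+ # identity up to fp error:
+ #   x2 = apply_rotary_emb(apply_rotary_emb(x, f), f, inverse=True)  # x2 ≈ x
+ # Note both calls mutate x in place (via y.copy_) and return it.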
351
+
352
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
353
+ """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
354
+ assert x.dtype == torch.bfloat16
355
+ from fast_hadamard_transform import hadamard_transform
356
+ return hadamard_transform(x, scale=x.size(-1) ** -0.5)
357
+
358
+
359
+ @lru_cache(1)
360
+ def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
361
+ if start_pos >= window_size - 1:
362
+ start_pos %= window_size
363
+ matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
364
+ elif start_pos > 0:
365
+ matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
366
+ else:
367
+ base = torch.arange(seqlen).unsqueeze(1)
368
+ matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
369
+ matrix = torch.where(matrix > base, -1, matrix)
370
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
371
+
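+ # Prefill example (hypothetical sizes): window_size=4, seqlen=3, start_pos=0 ->
+ #   row 0: [0, -1, -1]   (token 0 attends only to itself)
+ #   row 1: [0,  1, -1]
+ #   row 2: [0,  1,  2]
+ # Each row holds per-query window indices into the KV cache; -1 marks masked slots.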
372
+
373
+ @lru_cache(2)
374
+ def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
375
+ if start_pos > 0:
376
+ matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
377
+ else:
378
+ matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
379
+ mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
380
+ matrix = torch.where(mask, -1, matrix + offset)
381
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
382
+
383
+
384
+ class Compressor(nn.Module):
385
+ """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
386
+ When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
387
+
388
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
389
+ super().__init__()
390
+ self.dim = args.dim
391
+ self.head_dim = head_dim
392
+ self.rope_head_dim = args.rope_head_dim
393
+ self.nope_head_dim = head_dim - args.rope_head_dim
394
+ self.compress_ratio = compress_ratio
395
+ self.overlap = compress_ratio == 4
396
+ self.rotate = rotate
397
+ coff = 1 + self.overlap
398
+
399
+ self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
400
+ # wkv and wgate in the checkpoint are stored in bf16, while the parameters here are kept in fp32 for convenience.
401
+ # When overlap, the first half of dims is for overlapping compression, second half for normal.
402
+ self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
403
+ self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
404
+ self.norm = RMSNorm(self.head_dim, args.norm_eps)
405
+ self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
406
+ # State buffers for decode-phase incremental compression.
407
+ # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
408
+ self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
409
+ self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
410
+ self.freqs_cis: torch.Tensor = None
411
+
412
+ def overlap_transform(self, tensor: torch.Tensor, value=0):
413
+ # tensor: [b,s,r,2d]
414
+ b, s, _, _ = tensor.size()
415
+ ratio, d = self.compress_ratio, self.head_dim
416
+ new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
417
+ new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
418
+ new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
419
+ return new_tensor
420
+
421
+ def forward(self, x: torch.Tensor, start_pos: int):
422
+ assert self.kv_cache is not None
423
+ bsz, seqlen, _ = x.size()
424
+ ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
425
+ dtype = x.dtype
426
+ # compression needs fp32 precision
427
+ x = x.float()
428
+ kv = self.wkv(x)
429
+ score = self.wgate(x)
430
+ if start_pos == 0:
431
+ should_compress = seqlen >= ratio
432
+ remainder = seqlen % ratio
433
+ cutoff = seqlen - remainder
434
+ offset = ratio if overlap else 0
435
+ if overlap and cutoff >= ratio:
436
+ self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
437
+ self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
438
+ if remainder > 0:
439
+ kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
440
+ self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
441
+ score = score[:, :cutoff]
442
+ kv = kv.unflatten(1, (-1, ratio))
443
+ score = score.unflatten(1, (-1, ratio)) + self.ape
444
+ if overlap:
445
+ kv = self.overlap_transform(kv, 0)
446
+ score = self.overlap_transform(score, float("-inf"))
447
+ kv = (kv * score.softmax(dim=2)).sum(dim=2)
448
+ else:
449
+ should_compress = (start_pos + 1) % self.compress_ratio == 0
450
+ score += self.ape[start_pos % ratio]
451
+ if overlap:
452
+ self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
453
+ self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
454
+ if should_compress:
455
+ kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
456
+ score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
457
+ kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
458
+ self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
459
+ self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
460
+ else:
461
+ self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
462
+ self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
463
+ if should_compress:
464
+ kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
465
+ if not should_compress:
466
+ return
467
+ kv = self.norm(kv.to(dtype))
468
+ if start_pos == 0:
469
+ freqs_cis = self.freqs_cis[:cutoff:ratio]
470
+ else:
471
+ freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
472
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
473
+ if self.rotate:
474
+ kv = rotate_activation(kv)
475
+ fp4_act_quant(kv, fp4_block_size, True)
476
+ else:
477
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
478
+ if start_pos == 0:
479
+ self.kv_cache[:bsz, :seqlen // ratio] = kv
480
+ else:
481
+ self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
482
+ return kv
483
+
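+ # Compression sketch: every `ratio` tokens are pooled into one KV slot via a
+ # learned softmax gate, (kv * softmax(score + ape)).sum over the window; with
+ # overlap (ratio == 4) each slot also blends the previous window, so decode
+ # keeps a 2*ratio-token state and emits one compressed slot every `ratio` steps.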
484
+
485
+ class Indexer(torch.nn.Module):
486
+ """Selects top-k compressed KV positions for sparse attention via learned scoring.
487
+ Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
488
+
489
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4):
490
+ super().__init__()
491
+ self.dim = args.dim
492
+ self.n_heads = args.index_n_heads
493
+ self.n_local_heads = args.index_n_heads // world_size
494
+ self.head_dim = args.index_head_dim
495
+ self.rope_head_dim = args.rope_head_dim
496
+ self.index_topk = args.index_topk
497
+ self.q_lora_rank = args.q_lora_rank
498
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
499
+ self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
500
+ self.softmax_scale = self.head_dim ** -0.5
501
+ self.compress_ratio = compress_ratio
502
+
503
+ self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
504
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
505
+ self.freqs_cis = None
506
+
507
+ def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
508
+ bsz, seqlen, _ = x.size()
509
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
510
+ ratio = self.compress_ratio
511
+ rd = self.rope_head_dim
512
+ end_pos = start_pos + seqlen
513
+ if self.compressor.kv_cache is None:
514
+ self.compressor.kv_cache = self.kv_cache
515
+ self.compressor.freqs_cis = self.freqs_cis
516
+ q = self.wq_b(qr)
517
+ q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
518
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
519
+ q = rotate_activation(q)
520
+ # use fp4 simulation for q and kv in indexer
521
+ fp4_act_quant(q, fp4_block_size, True)
522
+ self.compressor(x, start_pos)
523
+ weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
524
+ # We performed QAT here; kv could also use the fp8 format, though the current implementation uses bf16
525
+ index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
526
+ index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
527
+ if world_size > 1:
528
+ dist.all_reduce(index_score)
529
+ if start_pos == 0:
530
+ mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
531
+ index_score += torch.where(mask, float("-inf"), 0)
532
+ topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
533
+ if start_pos == 0:
534
+ mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
535
+ topk_idxs = torch.where(mask, -1, topk_idxs + offset)
536
+ else:
537
+ topk_idxs += offset
538
+ return topk_idxs
539
+
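+ # Indexer sketch: index_score[b, s, t] = sum_h relu(q_h · kv_t) * w_h, a
+ # ReLU-gated dot product per head mixed by weights_proj; the topk over t picks
+ # which compressed slots each query may attend to, offset into the KV cache.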
540
+
541
+ class Attention(nn.Module):
542
+ """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
543
+ Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
544
+ def __init__(self, layer_id: int, args: ModelArgs):
545
+ super().__init__()
546
+ self.layer_id = layer_id
547
+ self.dim = args.dim
548
+ self.n_heads = args.n_heads
549
+ self.n_local_heads = args.n_heads // world_size
550
+ self.q_lora_rank = args.q_lora_rank
551
+ self.o_lora_rank = args.o_lora_rank
552
+ self.head_dim = args.head_dim
553
+ self.rope_head_dim = args.rope_head_dim
554
+ self.nope_head_dim = args.head_dim - args.rope_head_dim
555
+ self.n_groups = args.o_groups
556
+ self.n_local_groups = self.n_groups // world_size
557
+ self.window_size = args.window_size
558
+ self.compress_ratio = args.compress_ratios[layer_id]
559
+ self.eps = args.norm_eps
560
+
561
+ self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
562
+ self.wq_a = Linear(self.dim, self.q_lora_rank)
563
+ self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
564
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
565
+ self.wkv = Linear(self.dim, self.head_dim)
566
+ self.kv_norm = RMSNorm(self.head_dim, self.eps)
567
+ self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
568
+ self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
569
+ self.softmax_scale = self.head_dim ** -0.5
570
+
571
+ if self.compress_ratio:
572
+ self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
573
+ if self.compress_ratio == 4:
574
+ self.indexer = Indexer(args, self.compress_ratio)
575
+ else:
576
+ self.indexer = None
577
+
578
+ kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
579
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
580
+ if self.compress_ratio:
581
+ original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
582
+ else:
583
+ # disable YaRN and use base rope_theta in pure sliding-window attention
584
+ original_seq_len, rope_theta = 0, args.rope_theta
585
+ freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
586
+ rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
587
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
588
+
589
+ def forward(self, x: torch.Tensor, start_pos: int):
590
+ bsz, seqlen, _ = x.size()
591
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
592
+ win = self.window_size
593
+ ratio = self.compress_ratio
594
+ rd = self.rope_head_dim
595
+ if self.compress_ratio and self.compressor.kv_cache is None:
596
+ self.compressor.kv_cache = self.kv_cache[:, win:]
597
+ self.compressor.freqs_cis = self.freqs_cis
598
+ if self.indexer is not None:
599
+ self.indexer.freqs_cis = self.freqs_cis
600
+ # q
601
+ qr = q = self.q_norm(self.wq_a(x))
602
+ q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
603
+ q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
604
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
605
+
606
+ # win kv & topk_idxs
607
+ kv = self.wkv(x)
608
+ kv = self.kv_norm(kv)
609
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
610
+ # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
611
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
612
+ topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
613
+ if self.compress_ratio:
614
+ offset = kv.size(1) if start_pos == 0 else win
615
+ if self.indexer is not None:
616
+ compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
617
+ else:
618
+ compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
619
+ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
620
+ topk_idxs = topk_idxs.int()
621
+
622
+ # compress kv & attn
623
+ if start_pos == 0:
624
+ if seqlen <= win:
625
+ self.kv_cache[:bsz, :seqlen] = kv
626
+ else:
627
+ cutoff = seqlen % win
628
+ self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
629
+ if self.compress_ratio:
630
+ if (kv_compress := self.compressor(x, start_pos)) is not None:
631
+ kv = torch.cat([kv, kv_compress], dim=1)
632
+ # We performed QAT here; kv could also use the fp8 format, though the current implementation uses bf16
633
+ o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
634
+ else:
635
+ self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
636
+ if self.compress_ratio:
637
+ self.compressor(x, start_pos)
638
+ o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
639
+ apply_rotary_emb(o[..., -rd:], freqs_cis, True)
640
+
641
+ # o: apply wo_a per-group projection then wo_b
642
+ # Flatten groups into the feature dim, call wo_a as a normal linear, then reshape back.
643
+ # Equivalent to the per-group einsum when wo_a weight is block-diagonal across groups
644
+ # (always true here since n_local_groups = n_groups/world_size = 1 for 8-GPU deploy).
645
+ o = o.view(bsz, seqlen, self.n_local_groups, -1)
646
+ o = self.wo_a(o.flatten(2)).view(bsz, seqlen, self.n_local_groups, self.o_lora_rank)
647
+ x = self.wo_b(o.flatten(2))
648
+ return x
649
+
650
+
651
+ class Gate(nn.Module):
652
+ """MoE gating: computes expert routing scores and selects top-k experts.
653
+ Supports hash-based routing (first n_hash_layers) where expert indices are
654
+ predetermined per token ID, and score-based routing (remaining layers)."""
655
+ def __init__(self, layer_id: int, args: ModelArgs):
656
+ super().__init__()
657
+ self.dim = args.dim
658
+ self.topk = args.n_activated_experts
659
+ self.score_func = args.score_func
660
+ self.route_scale = args.route_scale
661
+ self.hash = layer_id < args.n_hash_layers
662
+ self.is_w4a16 = w4a16_mode
663
+ if self.is_w4a16:
664
+ in_f, out_f = args.dim, args.n_routed_experts
665
+ assert in_f % w4a16_group_size == 0 and out_f % 8 == 0
666
+ self.group_size = w4a16_group_size
667
+ self.qweight = nn.Parameter(
668
+ torch.empty(in_f // 8, out_f, dtype=torch.int32), requires_grad=False)
669
+ self.qzeros = nn.Parameter(
670
+ torch.empty(in_f // self.group_size, out_f // 8, dtype=torch.int32), requires_grad=False)
671
+ self.scales = nn.Parameter(
672
+ torch.empty(in_f // self.group_size, out_f, dtype=torch.bfloat16), requires_grad=False)
673
+ self.register_parameter("weight", None)
674
+ else:
675
+ self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
676
+ if self.hash:
677
+ self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
678
+ self.bias = None
679
+ else:
680
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
681
+
682
+ def init_woq(self, QuantLinear):
683
+ if not self.is_w4a16:
684
+ return
685
+ dev = self.qweight.device
686
+ in_f, out_f = self.dim, self.qweight.shape[1]
687
+ if out_f % 64 != 0:
688
+ self._woq = None
689
+ return
690
+ layer = QuantLinear(
691
+ bits=4, group_size=self.group_size,
692
+ in_features=in_f, out_features=out_f,
693
+ bias=False, desc_act=False, sym=True, register_buffers=True,
694
+ )
695
+ layer = layer.to(dev)
696
+ layer.qweight.copy_(self.qweight.data)
697
+ layer.qzeros.copy_(self.qzeros.data)
698
+ layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
699
+ layer.g_idx.copy_(torch.arange(in_f, dtype=torch.int32, device=dev) // self.group_size)
700
+ layer.post_init()
701
+ self._woq = layer
702
+ self.qweight = None
703
+ self.qzeros = None
704
+ self.scales = None
705
+
706
+ def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
707
+ if self.is_w4a16:
708
+ if hasattr(self, '_woq') and self._woq is not None:
709
+ scores = self._woq(x.to(torch.bfloat16)).float()
710
+ else:
711
+ w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
712
+ scores = F.linear(x.to(w.dtype), w).float()
713
+ else:
714
+ scores = linear(x.float(), self.weight.float())
715
+ if self.score_func == "softmax":
716
+ scores = scores.softmax(dim=-1)
717
+ elif self.score_func == "sigmoid":
718
+ scores = scores.sigmoid()
719
+ else:
720
+ scores = F.softplus(scores).sqrt()
721
+ original_scores = scores
722
+ # Bias shifts scores for expert selection (topk) but does not affect routing weights.
723
+ if self.bias is not None:
724
+ scores = scores + self.bias
725
+ if self.hash:
726
+ indices = self.tid2eid[input_ids]
727
+ else:
728
+ indices = scores.topk(self.topk, dim=-1)[1]
729
+ weights = original_scores.gather(1, indices)
730
+ if self.score_func != "softmax":
731
+ weights /= weights.sum(dim=-1, keepdim=True)
732
+ weights *= self.route_scale
733
+ return weights, indices
734
+
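+ # Routing sketch: with score_func="sqrtsoftplus", scores = sqrt(softplus(logits)).
+ # The bias only reorders the topk selection; the gathered routing weights come
+ # from the unbiased scores and are renormalized to sum to 1 before route_scale.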
735
+
736
+ class Expert(nn.Module):
737
+ """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
738
+ def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
739
+ super().__init__()
740
+ self.w1 = Linear(dim, inter_dim, dtype=dtype)
741
+ self.w2 = Linear(inter_dim, dim, dtype=dtype)
742
+ self.w3 = Linear(dim, inter_dim, dtype=dtype)
743
+ self.swiglu_limit = swiglu_limit
744
+
745
+ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
746
+ dtype = x.dtype
747
+ gate = self.w1(x).float()
748
+ up = self.w3(x).float()
749
+ if self.swiglu_limit > 0:
750
+ up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
751
+ gate = torch.clamp(gate, max=self.swiglu_limit)
752
+ x = F.silu(gate) * up
753
+ if weights is not None:
754
+ x = weights * x
755
+ return self.w2(x.to(dtype))
756
+
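+ # SwiGLU sketch with limit L = swiglu_limit (no clamping when L == 0):
+ #   y = w2( silu(min(w1(x), L)) * clamp(w3(x), -L, L) )
+ # computed in fp32 for numerical stability, then cast back to the input dtype.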
757
+
758
+ class MoE(nn.Module):
759
+ """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
760
+ Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
761
+ def __init__(self, layer_id: int, args: ModelArgs):
762
+ super().__init__()
763
+ self.layer_id = layer_id
764
+ self.dim = args.dim
765
+ assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
766
+ self.n_routed_experts = args.n_routed_experts
767
+ self.n_local_experts = args.n_routed_experts // world_size
768
+ self.n_activated_experts = args.n_activated_experts
769
+ self.experts_start_idx = rank * self.n_local_experts
770
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
771
+ self.gate = Gate(layer_id, args)
772
+ expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
773
+ self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
774
+ for i in range(self.n_routed_experts)])
775
+ assert args.n_shared_experts == 1
776
+ # shared expert: default dtype (no fp4), same swiglu_limit as routed experts
777
+ self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)
778
+
779
+ def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
780
+ shape = x.size()
781
+ x = x.view(-1, self.dim)
782
+ weights, indices = self.gate(x, input_ids.flatten())
783
+ y = torch.zeros_like(x, dtype=torch.float32)
784
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
785
+ for i in range(self.experts_start_idx, self.experts_end_idx):
786
+ if counts[i] == 0:
787
+ continue
788
+ expert = self.experts[i]
789
+ idx, top = torch.where(indices == i)
790
+ y[idx] += expert(x[idx], weights[idx, top, None])
791
+ if world_size > 1:
792
+ dist.all_reduce(y)
793
+ y += self.shared_experts(x)
794
+ return y.type_as(x).view(shape)
795
+
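+ # Dispatch sketch: torch.where(indices == i) yields the (token, slot) pairs
+ # routed to expert i; each expert output is scaled by its routing weight and
+ # scattered back additively, then all_reduce sums the partial results across
+ # ranks (each rank only owns n_routed_experts // world_size experts).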
796
+
797
+ class Block(nn.Module):
798
+ """Transformer block with Hyper-Connections (HC) mixing.
799
+ Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
800
+ hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
801
+ hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
802
+ def __init__(self, layer_id: int, args: ModelArgs):
803
+ super().__init__()
804
+ self.layer_id = layer_id
805
+ self.norm_eps = args.norm_eps
806
+ self.attn = Attention(layer_id, args)
807
+ self.ffn = MoE(layer_id, args)
808
+ self.attn_norm = RMSNorm(args.dim, self.norm_eps)
809
+ self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
810
+ self.hc_mult = hc_mult = args.hc_mult
811
+ self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
812
+ self.hc_eps = args.hc_eps
813
+ mix_hc = (2 + hc_mult) * hc_mult
814
+ hc_dim = hc_mult * args.dim
815
+ with set_dtype(torch.float32):
816
+ self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
817
+ self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
818
+ self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
819
+ self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
820
+ self.hc_attn_scale = nn.Parameter(torch.empty(3))
821
+ self.hc_ffn_scale = nn.Parameter(torch.empty(3))
822
+
823
+ def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
824
+ # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,d]
825
+ shape, dtype = x.size(), x.dtype
826
+ x = x.flatten(2).float()
827
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
828
+ mixes = F.linear(x, hc_fn) * rsqrt
829
+ pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
830
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
831
+ return y.to(dtype), post, comb
832
+
833
+ def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
834
+ # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
835
+ y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
836
+ return y.type_as(x)
837
+
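+ # HC shape sketch (hc = hc_mult): hc_pre reduces x [b,s,hc,d] -> [b,s,d] with
+ # pre-weights [b,s,hc]; hc_post expands the sublayer output back to [b,s,hc,d]:
+ #   y[..., k, :] = post[..., k] * x + sum_j comb[..., j, k] * residual[..., j, :]
+ # where comb is the (approximately doubly stochastic) Sinkhorn mixing matrix.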
838
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
839
+ residual = x
840
+ x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
841
+ x = self.attn_norm(x)
842
+ x = self.attn(x, start_pos)
843
+ x = self.hc_post(x, residual, post, comb)
844
+
845
+ residual = x
846
+ x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
847
+ x = self.ffn_norm(x)
848
+ x = self.ffn(x, input_ids)
849
+ x = self.hc_post(x, residual, post, comb)
850
+ return x
851
+
852
+
853
+ class ParallelHead(nn.Module):
854
+
855
+ def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
856
+ super().__init__()
857
+ self.vocab_size = vocab_size
858
+ self.dim = dim
859
+ self.norm_eps = norm_eps
860
+ self.hc_eps = hc_eps
861
+ self.part_vocab_size = (vocab_size // world_size)
862
+ # lm_head is always stored as bf16 (even in W4A16 checkpoints); use fp32 for logit precision
863
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
864
+
865
+ def get_logits(self, x):
866
+ return F.linear(x[:, -1].float(), self.weight)
867
+
868
+ def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
869
+ # x: [b,s,hc,d]
870
+ x = self.hc_head(x, hc_fn, hc_scale, hc_base)
871
+ logits = self.get_logits(norm(x))
872
+ if world_size > 1:
873
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
874
+ dist.all_gather(all_logits, logits)
875
+ logits = torch.cat(all_logits, dim=-1)
876
+ return logits
877
+
878
+ def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
879
+ shape, dtype = x.size(), x.dtype
880
+ x = x.flatten(2).float()
881
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
882
+ mixes = F.linear(x, hc_fn) * rsqrt
883
+ pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
884
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
885
+ return y.to(dtype)
886
+
887
+
888
+ class MTPBlock(Block):
889
+
890
+ def __init__(self, layer_id: int, args: ModelArgs):
891
+ super().__init__(layer_id, args)
892
+ self.e_proj = Linear(args.dim, args.dim)
893
+ self.h_proj = Linear(args.dim, args.dim)
894
+ self.enorm = RMSNorm(args.dim, args.norm_eps)
895
+ self.hnorm = RMSNorm(args.dim, args.norm_eps)
896
+ self.norm = RMSNorm(args.dim, args.norm_eps)
897
+ self.hc_mult = hc_mult = args.hc_mult
898
+ hc_dim = hc_mult * args.dim
899
+ with set_dtype(torch.float32):
900
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
901
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
902
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
903
+ self.embed: ParallelEmbedding = None
904
+ self.head: ParallelHead = None
905
+
906
+ @torch.inference_mode()
907
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
908
+ # x: [b,s,hc,d]
909
+ assert self.embed is not None and self.head is not None
910
+ e = self.embed(input_ids)
911
+ e = self.enorm(e)
912
+ x = self.hnorm(x)
913
+ x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
914
+ x = super().forward(x, start_pos, input_ids)
915
+ logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
916
+ return logits
917
+
918
+
919
+ class Transformer(nn.Module):
920
+ """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
921
+ Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
922
+ def __init__(self, args: ModelArgs):
923
+ global world_size, rank, default_dtype, scale_fmt, scale_dtype, w4a16_mode
924
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
925
+ rank = dist.get_rank() if dist.is_initialized() else 0
926
+ w4a16_mode = (args.dtype == "w4a16")
927
+ if w4a16_mode:
928
+ default_dtype = torch.bfloat16
929
+ scale_fmt = None
930
+ scale_dtype = torch.float32
931
+ else:
932
+ default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
933
+ scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
934
+ scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
935
+ super().__init__()
936
+ self.max_seq_len = args.max_seq_len
937
+ self.norm_eps = args.norm_eps
938
+ self.hc_eps = args.hc_eps
939
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
940
+ self.layers = torch.nn.ModuleList()
941
+ for layer_id in range(args.n_layers):
942
+ self.layers.append(Block(layer_id, args))
943
+ self.norm = RMSNorm(args.dim, self.norm_eps)
944
+ self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
945
+ self.mtp = torch.nn.ModuleList()
946
+ for layer_id in range(args.n_mtp_layers):
947
+ self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
948
+ self.mtp[-1].embed = self.embed
949
+ self.mtp[-1].head = self.head
950
+ self.hc_mult = hc_mult = args.hc_mult
951
+ hc_dim = hc_mult * args.dim
952
+ with set_dtype(torch.float32):
953
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
954
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
955
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
956
+
957
+ def init_woq_layers(self):
958
+ """After load_model(), convert all W4A16 parameters into QuantLinear layers."""
959
+ # from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear as QuantLinear
960
+ from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear as QuantLinear
961
+ for module in self.modules():
962
+ if hasattr(module, 'init_woq') and module is not self:
963
+ module.init_woq(QuantLinear)
964
+ torch.cuda.empty_cache()
965
+
966
+ @torch.inference_mode()
967
+ def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
968
+ h = self.embed(input_ids)
969
+ # Expand to hc_mult copies for Hyper-Connections
970
+ h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
971
+ for layer in self.layers:
972
+ h = layer(h, start_pos, input_ids)
973
+ logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
974
+ return logits
975
+
976
+
977
+ if __name__ == "__main__":
978
+ torch.set_default_dtype(torch.bfloat16)
979
+ torch.set_default_device("cuda")
980
+ torch.manual_seed(0)
981
+ args = ModelArgs(n_hash_layers=0)
982
+ x = torch.randint(0, args.vocab_size, (2, 128))
983
+ model = Transformer(args)
984
+
985
+ print(model(x).size())
986
+ for i in range(128, 150):
987
+ print(i, model(x[:, 0:1], i).size())
988
+
989
+ h = torch.randn(2, 128, args.hc_mult, args.dim)
990
+ mtp = model.mtp[0]
991
+ print(mtp(h, 0, x).size())
992
+ print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
inference/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ torch>=2.10.0
2
+ transformers>=5.0.0
3
+ safetensors>=0.7.0
4
+ fast_hadamard_transform
5
+ tilelang==0.1.8
6
+ gptqmodel==6.0.3
model-00001-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ef64d991b80c86f24bd78d1bac9d452d95bc78e6b1d8feb6a182dae0240c7e5
3
+ size 1853358176
model-00002-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7d7f2862248d66164e5bf9007a9da3b0f1d6c0b9d5143c43dbc2d9ad4a7ff12
3
+ size 13390865040
model-00003-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:881002f38770cf6e0be5c24801d11c595674c3c4d7ff3485c45bf06d848e90a7
3
+ size 13390865040
model-00004-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c1c824e837c5c78d2467c70222e5b690c91a9bc772f6662ad48b39db06d116d
3
+ size 13403120848
model-00005-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c7cb7212f3374b69f29e1198b572392084135bdd0bda38282f8bd9a0c8ed46b
3
+ size 13384661096
model-00006-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:683444102f6eeae98ccaa00d1aada1e44253a80541a17d457382bd070947ac59
3
+ size 13396916904
model-00007-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:863e5634abb25b3cf83625ff90f3c5783b5619435ae184112f05c2d65c407a84
3
+ size 13384661096
model-00008-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a86efce4fc51c41911f78480863ba467d97f8a65ce03365d3f0a002b892674ad
3
+ size 13396916904
model-00009-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cdf0891b172d21168999bd0adb520b820a078351ee15a77d7b07867908347cd
3
+ size 13384661096
model-00010-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2615fd85e6a5913201d05b1cefe8e664dc910b88531ff1ae46da21ed41a9558e
3
+ size 13396916904
model-00011-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4922eb51e18579581a83bce28ab880b8dd5390cd4d03dd0c24e4b2c4f6fd3dc6
3
+ size 13384661096
model-00012-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1b289316453fe719094a30e5cea8e577787dc7560512f872db669962addac88
3
+ size 13396920416
model-00013-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53c5b20d54738bdcfc827f165a6543848b6ccf384ef66bf9402ea7244ef67c52
3
+ size 13384664600
model-00014-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7423ce801964f2af7f4056531432fcc1a140a7cd33a2a5b530fb30e49609e618
3
+ size 13396920416
model-00015-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d13483f05edd67b110bcf0cefe7a84f9bfe94ce47e654af4e60238955fa50989
3
+ size 13384664600
model-00016-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75704e031c9be328e8133ee2c6603a430bcfd61979f5a561964bf29d0134104b
3
+ size 13396920416
model-00017-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7f0c6876bade6ea29e2662ebae000ce3cfde96887c4ca66a90db77ee530db8b
3
+ size 13384664600
model-00018-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4abe0458401fc5e990e3699dbff2b98e05355f66bcc5d974ece79e7e14672765
3
+ size 13396920416
model-00019-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29238f11c48bbb56de1c513e9806edcaf63a142b191e36668b64d93df29b5238
3
+ size 13384664600
model-00020-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bd7d5d0e0337dc0af3b68287c237c6ac6c922596a64c88320fa8c62ddc32ceb
3
+ size 13396920416
model-00021-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:572a2a0e0910421cd7f4519da30c604cddf4a1aa61fe50e51788ac8f3e0875d7
3
+ size 13384664600
model-00022-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:077bce599585a64769465c6ef3614d4717184ce7f776162ee63884cc38e5ded3
3
+ size 13396920416
model-00023-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5ea118977b28e91802ff0df2cfd262c42eb2b439d6bd019bb2f830ea9ee7de2
3
+ size 13384664600