Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- DeepSeek_V4.pdf +3 -0
- LICENSE +21 -0
- README.md +54 -0
- assets/dsv4_performance.png +3 -0
- config.json +140 -0
- encoding/README.md +156 -0
- encoding/encoding_dsv4.py +744 -0
- encoding/test_encoding_dsv4.py +89 -0
- encoding/tests/test_input_1.json +81 -0
- encoding/tests/test_input_2.json +24 -0
- encoding/tests/test_input_3.json +159 -0
- encoding/tests/test_input_4.json +28 -0
- encoding/tests/test_output_1.txt +36 -0
- encoding/tests/test_output_2.txt +1 -0
- encoding/tests/test_output_3.txt +38 -0
- encoding/tests/test_output_4.txt +29 -0
- generation_config.json +9 -0
- inference/README.md +25 -0
- inference/config.json +35 -0
- inference/config_w4a16.json +34 -0
- inference/convert.py +168 -0
- inference/convert_w4a16.py +246 -0
- inference/generate.py +159 -0
- inference/kernel.py +536 -0
- inference/model.py +992 -0
- inference/requirements.txt +6 -0
- model-00001-of-00064.safetensors +3 -0
- model-00002-of-00064.safetensors +3 -0
- model-00003-of-00064.safetensors +3 -0
- model-00004-of-00064.safetensors +3 -0
- model-00005-of-00064.safetensors +3 -0
- model-00006-of-00064.safetensors +3 -0
- model-00007-of-00064.safetensors +3 -0
- model-00008-of-00064.safetensors +3 -0
- model-00009-of-00064.safetensors +3 -0
- model-00010-of-00064.safetensors +3 -0
- model-00011-of-00064.safetensors +3 -0
- model-00012-of-00064.safetensors +3 -0
- model-00013-of-00064.safetensors +3 -0
- model-00014-of-00064.safetensors +3 -0
- model-00015-of-00064.safetensors +3 -0
- model-00016-of-00064.safetensors +3 -0
- model-00017-of-00064.safetensors +3 -0
- model-00018-of-00064.safetensors +3 -0
- model-00019-of-00064.safetensors +3 -0
- model-00020-of-00064.safetensors +3 -0
- model-00021-of-00064.safetensors +3 -0
- model-00022-of-00064.safetensors +3 -0
- model-00023-of-00064.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
DeepSeek_V4.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa4a3490e2dcc03c9da61b04a8be471795e9966ebbbf292a3899fa62683a330e
|
| 3 |
+
size 4479901
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 DeepSeek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: transformers
|
| 4 |
+
base_model:
|
| 5 |
+
- deepseek-ai/DeepSeek-V4-Pro
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
This model is an int4 model with group_size 128 of [deepseek-ai/DeepSeek-V4-Pro](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro) generated by [intel/auto-round](https://github.com/intel/auto-round) with RTN mode. Please follow the license of the original model.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
## How to Run Locally
|
| 14 |
+
|
| 15 |
+
**vLLM and Sglang is not supported currently: https://huggingface.co/Intel/DeepSeek-V4-Flash-W4A16-AutoRound/discussions/1**
|
| 16 |
+
|
| 17 |
+
Please refer to the [inference](inference/README.md) folder for detailed instructions on running DeepSeek-V4 locally, including model weight conversion and interactive chat demos.
|
| 18 |
+
|
| 19 |
+
For local deployment, we recommend setting the sampling parameters to `temperature = 1.0, top_p = 1.0`. For the Think Max reasoning mode, we recommend setting the context window to at least **384K** tokens.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Generate the Model
|
| 23 |
+
|
| 24 |
+
This pr is required: [Support model_free WOQ quantization](https://github.com/intel/auto-round/pull/1699)
|
| 25 |
+
|
| 26 |
+
~~~bash
|
| 27 |
+
auto-round deepseek-ai/DeepSeek-V4-Pro --model_free --output_dir "./DeepSeek-V4-Pro-W4A16"
|
| 28 |
+
~~~
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## Ethical Considerations and Limitations
|
| 33 |
+
|
| 34 |
+
The model can produce factually incorrect output, and should not be relied on to produce factually accurate information. Because of the limitations of the pretrained model and the finetuning datasets, it is possible that this model could generate lewd, biased or otherwise offensive outputs.
|
| 35 |
+
|
| 36 |
+
Therefore, before deploying any applications of the model, developers should perform safety testing.
|
| 37 |
+
|
| 38 |
+
## Caveats and Recommendations
|
| 39 |
+
|
| 40 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.
|
| 41 |
+
|
| 42 |
+
Here are a couple of useful links to learn more about Intel's AI software:
|
| 43 |
+
|
| 44 |
+
- [Intel Neural Compressor](https://github.com/intel/neural-compressor)
|
| 45 |
+
|
| 46 |
+
## Disclaimer
|
| 47 |
+
|
| 48 |
+
The license on this model does not constitute legal advice. We are not responsible for the actions of third parties who use this model. Please consult an attorney before using this model for commercial purposes.
|
| 49 |
+
|
| 50 |
+
## Cite
|
| 51 |
+
|
| 52 |
+
@article{cheng2023optimize, title={Optimize weight rounding via signed gradient descent for the quantization of llms}, author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi}, journal={arXiv preprint arXiv:2309.05516}, year={2023} }
|
| 53 |
+
|
| 54 |
+
[arxiv](https://arxiv.org/abs/2309.05516) [github](https://github.com/intel/auto-round)
|
assets/dsv4_performance.png
ADDED
|
Git LFS Details
|
config.json
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"DeepseekV4ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"eos_token_id": 1,
|
| 9 |
+
"hc_eps": 1e-06,
|
| 10 |
+
"hc_mult": 4,
|
| 11 |
+
"hc_sinkhorn_iters": 20,
|
| 12 |
+
"head_dim": 512,
|
| 13 |
+
"hidden_act": "silu",
|
| 14 |
+
"hidden_size": 7168,
|
| 15 |
+
"index_head_dim": 128,
|
| 16 |
+
"index_n_heads": 64,
|
| 17 |
+
"index_topk": 1024,
|
| 18 |
+
"initializer_range": 0.02,
|
| 19 |
+
"max_position_embeddings": 1048576,
|
| 20 |
+
"model_type": "deepseek_v4",
|
| 21 |
+
"moe_intermediate_size": 3072,
|
| 22 |
+
"n_routed_experts": 384,
|
| 23 |
+
"n_shared_experts": 1,
|
| 24 |
+
"norm_topk_prob": true,
|
| 25 |
+
"num_attention_heads": 128,
|
| 26 |
+
"num_experts_per_tok": 6,
|
| 27 |
+
"num_hidden_layers": 61,
|
| 28 |
+
"num_hash_layers": 3,
|
| 29 |
+
"num_key_value_heads": 1,
|
| 30 |
+
"num_nextn_predict_layers": 1,
|
| 31 |
+
"o_groups": 16,
|
| 32 |
+
"o_lora_rank": 1024,
|
| 33 |
+
"q_lora_rank": 1536,
|
| 34 |
+
"qk_rope_head_dim": 64,
|
| 35 |
+
"quantization_config": {
|
| 36 |
+
"quant_method": "auto-round",
|
| 37 |
+
"packing_format": "auto_round:auto_gptq",
|
| 38 |
+
"bits": 4,
|
| 39 |
+
"group_size": 128,
|
| 40 |
+
"sym": true,
|
| 41 |
+
"data_type": "int",
|
| 42 |
+
"iters": 0,
|
| 43 |
+
"model_free": true,
|
| 44 |
+
"autoround_version": "0.13.0",
|
| 45 |
+
"extra_config": {
|
| 46 |
+
"embed": {
|
| 47 |
+
"bits": 16,
|
| 48 |
+
"data_type": "float"
|
| 49 |
+
},
|
| 50 |
+
"head": {
|
| 51 |
+
"bits": 16,
|
| 52 |
+
"data_type": "float"
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"rms_norm_eps": 1e-06,
|
| 57 |
+
"rope_scaling": {
|
| 58 |
+
"beta_fast": 32,
|
| 59 |
+
"beta_slow": 1,
|
| 60 |
+
"factor": 16,
|
| 61 |
+
"original_max_position_embeddings": 65536,
|
| 62 |
+
"type": "yarn"
|
| 63 |
+
},
|
| 64 |
+
"rope_theta": 10000,
|
| 65 |
+
"routed_scaling_factor": 2.5,
|
| 66 |
+
"scoring_func": "sqrtsoftplus",
|
| 67 |
+
"sliding_window": 128,
|
| 68 |
+
"swiglu_limit": 10.0,
|
| 69 |
+
"tie_word_embeddings": false,
|
| 70 |
+
"topk_method": "noaux_tc",
|
| 71 |
+
"torch_dtype": "bfloat16",
|
| 72 |
+
"transformers_version": "4.57.1",
|
| 73 |
+
"use_cache": true,
|
| 74 |
+
"vocab_size": 129280,
|
| 75 |
+
"compress_rope_theta": 160000,
|
| 76 |
+
"compress_ratios": [
|
| 77 |
+
128,
|
| 78 |
+
128,
|
| 79 |
+
4,
|
| 80 |
+
128,
|
| 81 |
+
4,
|
| 82 |
+
128,
|
| 83 |
+
4,
|
| 84 |
+
128,
|
| 85 |
+
4,
|
| 86 |
+
128,
|
| 87 |
+
4,
|
| 88 |
+
128,
|
| 89 |
+
4,
|
| 90 |
+
128,
|
| 91 |
+
4,
|
| 92 |
+
128,
|
| 93 |
+
4,
|
| 94 |
+
128,
|
| 95 |
+
4,
|
| 96 |
+
128,
|
| 97 |
+
4,
|
| 98 |
+
128,
|
| 99 |
+
4,
|
| 100 |
+
128,
|
| 101 |
+
4,
|
| 102 |
+
128,
|
| 103 |
+
4,
|
| 104 |
+
128,
|
| 105 |
+
4,
|
| 106 |
+
128,
|
| 107 |
+
4,
|
| 108 |
+
128,
|
| 109 |
+
4,
|
| 110 |
+
128,
|
| 111 |
+
4,
|
| 112 |
+
128,
|
| 113 |
+
4,
|
| 114 |
+
128,
|
| 115 |
+
4,
|
| 116 |
+
128,
|
| 117 |
+
4,
|
| 118 |
+
128,
|
| 119 |
+
4,
|
| 120 |
+
128,
|
| 121 |
+
4,
|
| 122 |
+
128,
|
| 123 |
+
4,
|
| 124 |
+
128,
|
| 125 |
+
4,
|
| 126 |
+
128,
|
| 127 |
+
4,
|
| 128 |
+
128,
|
| 129 |
+
4,
|
| 130 |
+
128,
|
| 131 |
+
4,
|
| 132 |
+
128,
|
| 133 |
+
4,
|
| 134 |
+
128,
|
| 135 |
+
4,
|
| 136 |
+
128,
|
| 137 |
+
4,
|
| 138 |
+
0
|
| 139 |
+
]
|
| 140 |
+
}
|
encoding/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepSeek-V4 Encoding
|
| 2 |
+
|
| 3 |
+
This document describes the prompt encoding format used by DeepSeek-V4 series models. The encoding handles multi-turn conversations, tool calling, extended thinking (reasoning), and quick instruction tasks.
|
| 4 |
+
|
| 5 |
+
A self-contained reference implementation is provided in `encoding_dsv4.py`.
|
| 6 |
+
|
| 7 |
+
## Quick Start
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
| 11 |
+
|
| 12 |
+
# Encode a conversation
|
| 13 |
+
messages = [
|
| 14 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 15 |
+
{"role": "user", "content": "What is 2+2?"},
|
| 16 |
+
]
|
| 17 |
+
prompt = encode_messages(messages, thinking_mode="thinking")
|
| 18 |
+
# => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|><think>"
|
| 19 |
+
|
| 20 |
+
# Parse model output back to structured message
|
| 21 |
+
completion = "Simple arithmetic.</think>2 + 2 = 4.<|end▁of▁sentence|>"
|
| 22 |
+
parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
|
| 23 |
+
# => {"role": "assistant", "reasoning_content": "Simple arithmetic.", "content": "2 + 2 = 4.", "tool_calls": []}
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
> **Note:** The `parse_message_from_completion_text` function is designed to handle well-formatted model output only. It does not attempt to correct or recover from malformed output that the model might occasionally generate. For production use, additional error handling is recommended.
|
| 27 |
+
|
| 28 |
+
## Message Format
|
| 29 |
+
|
| 30 |
+
### Special Tokens
|
| 31 |
+
|
| 32 |
+
| Token | Purpose |
|
| 33 |
+
|-------|---------|
|
| 34 |
+
| `<|begin▁of▁sentence|>` | Beginning of sequence (BOS) |
|
| 35 |
+
| `<|end▁of▁sentence|>` | End of assistant turn (EOS) |
|
| 36 |
+
| `<|User|>` | User turn prefix |
|
| 37 |
+
| `<|Assistant|>` | Assistant turn prefix |
|
| 38 |
+
| `<|latest_reminder|>` | Latest reminder (date, locale, etc.) |
|
| 39 |
+
| `<think>` / `</think>` | Reasoning block delimiters |
|
| 40 |
+
| `|DSML|` | DSML markup token |
|
| 41 |
+
|
| 42 |
+
### Roles
|
| 43 |
+
|
| 44 |
+
The encoding supports the following message roles: `system`, `user`, `assistant`, `tool`, `latest_reminder`, and `developer`.
|
| 45 |
+
|
| 46 |
+
> **Note on the `developer` role:** The `developer` role is used exclusively in the internal search agent pipeline. It is not needed for general-purpose chat or tool-calling tasks, and the official API does not accept messages with this role.
|
| 47 |
+
|
| 48 |
+
### Basic Chat
|
| 49 |
+
|
| 50 |
+
A simple multi-turn conversation is encoded as:
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
<|begin▁of▁sentence|>{system_prompt}
|
| 54 |
+
<|User|>{user_message}<|Assistant|></think>{response}<|end▁of▁sentence|>
|
| 55 |
+
<|User|>{user_message_2}<|Assistant|></think>{response_2}<|end▁of▁sentence|>
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
- The BOS token is prepended at the very beginning of the conversation.
|
| 59 |
+
- In **chat mode** (`thinking_mode="chat"`), `</think>` is placed right after `<|Assistant|>` to immediately close the thinking block, so the model generates content directly.
|
| 60 |
+
|
| 61 |
+
### Interleaved Thinking Mode
|
| 62 |
+
|
| 63 |
+
In **thinking mode** (`thinking_mode="thinking"`), the model produces explicit reasoning inside `<think>...</think>` blocks before responding.
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
<|begin▁of▁sentence|>{system_prompt}
|
| 67 |
+
<|User|>{message}<|Assistant|><think>{reasoning}</think>{response}<|end▁of▁sentence|>
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
The `drop_thinking` parameter (default `True`) controls whether reasoning from earlier turns is preserved:
|
| 71 |
+
|
| 72 |
+
- **Without tools**: `drop_thinking` takes effect. Reasoning content from assistant turns **before** the last user message is stripped. Only the final assistant turn retains its `<think>...</think>` block.
|
| 73 |
+
- **With tools** (on system or developer message): `drop_thinking` is automatically disabled. All turns retain their reasoning, because tool-calling conversations require full context for the model to track multi-step reasoning across tool calls.
|
| 74 |
+
|
| 75 |
+
### Tool Calling (DSML Format)
|
| 76 |
+
|
| 77 |
+
Tools are defined on the `system` or `developer` message via the `tools` field (OpenAI-compatible format). When tools are present, the following schema block is injected into the system/user prompt:
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
## Tools
|
| 81 |
+
|
| 82 |
+
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
|
| 83 |
+
|
| 84 |
+
<|DSML|tool_calls>
|
| 85 |
+
<|DSML|invoke name="$TOOL_NAME">
|
| 86 |
+
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
|
| 87 |
+
...
|
| 88 |
+
</|DSML|invoke>
|
| 89 |
+
<|DSML|invoke name="$TOOL_NAME2">
|
| 90 |
+
...
|
| 91 |
+
</|DSML|invoke>
|
| 92 |
+
</|DSML|tool_calls>
|
| 93 |
+
|
| 94 |
+
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
| 95 |
+
|
| 96 |
+
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
|
| 97 |
+
|
| 98 |
+
Otherwise, output directly after </think> with tool calls or final response.
|
| 99 |
+
|
| 100 |
+
### Available Tool Schemas
|
| 101 |
+
|
| 102 |
+
{tool_definitions_json}
|
| 103 |
+
|
| 104 |
+
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
An actual tool call in the assistant turn looks like:
|
| 108 |
+
|
| 109 |
+
```xml
|
| 110 |
+
<|DSML|tool_calls>
|
| 111 |
+
<|DSML|invoke name="function_name">
|
| 112 |
+
<|DSML|parameter name="param" string="true">string_value</|DSML|parameter>
|
| 113 |
+
<|DSML|parameter name="count" string="false">5</|DSML|parameter>
|
| 114 |
+
</|DSML|invoke>
|
| 115 |
+
</|DSML|tool_calls><|end▁of▁sentence|>
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
- `string="true"`: the parameter value is a raw string.
|
| 119 |
+
- `string="false"`: the parameter value is JSON (number, boolean, array, object).
|
| 120 |
+
|
| 121 |
+
Tool execution results are wrapped in `<tool_result>` tags within user messages:
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
<|User|><tool_result>{result_json}</tool_result><|Assistant|><think>...
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
When multiple tool results are present, they are sorted by the order of the corresponding `tool_calls` in the preceding assistant message.
|
| 128 |
+
|
| 129 |
+
### Reasoning Effort
|
| 130 |
+
|
| 131 |
+
When `reasoning_effort="max"` is set, a special prefix is prepended at the very beginning of the prompt (before the system message) to instruct the model to maximize its reasoning depth:
|
| 132 |
+
|
| 133 |
+
```
|
| 134 |
+
Reasoning Effort: Absolute maximum with no shortcuts permitted.
|
| 135 |
+
You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.
|
| 136 |
+
Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### Quick Instruction Special Tokens
|
| 140 |
+
|
| 141 |
+
Quick instruction tokens are used for auxiliary classification and generation tasks. They are appended to messages via the `"task"` field to trigger specialized model behavior for a single-token or short-form output.
|
| 142 |
+
|
| 143 |
+
| Special Token | Description | Format |
|
| 144 |
+
|:---|:---|:---|
|
| 145 |
+
| `<|action|>` | Determines whether the user prompt requires a web search or can be answered directly. | `...<|User|>{prompt}<|Assistant|><think><|action|>` |
|
| 146 |
+
| `<|title|>` | Generates a concise conversation title after the first assistant response. | `...<|Assistant|>{response}<|end▁of▁sentence|><|title|>` |
|
| 147 |
+
| `<|query|>` | Generates search queries for the user prompt. | `...<|User|>{prompt}<|query|>` |
|
| 148 |
+
| `<|authority|>` | Classifies the user prompt's demand for source authoritativeness. | `...<|User|>{prompt}<|authority|>` |
|
| 149 |
+
| `<|domain|>` | Identifies the domain of the user prompt. | `...<|User|>{prompt}<|domain|>` |
|
| 150 |
+
| `<|extracted_url|>` `<|read_url|>` | Determines whether each URL in the user prompt should be fetched and read. | `...<|User|>{prompt}<|extracted_url|>{url}<|read_url|>` |
|
| 151 |
+
|
| 152 |
+
Usage in message format:
|
| 153 |
+
|
| 154 |
+
- **`action`** on a user message: the `<|action|>` token is placed after the assistant prefix and thinking token, triggering a routing decision (e.g., "Search" or "Answer").
|
| 155 |
+
- **Other tasks** (`query`, `authority`, `domain`, `read_url`) on a user message: the task token is appended directly after the user content.
|
| 156 |
+
- **`title`** on an assistant message: the `<|title|>` token is appended after the assistant's EOS. The next assistant message provides the generated title.
|
encoding/encoding_dsv4.py
ADDED
|
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeepSeek-V4 Encoding
|
| 3 |
+
|
| 4 |
+
A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
|
| 5 |
+
with tool calling, thinking mode, and quick instruction task support.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, List, Union, Optional, Tuple
|
| 9 |
+
import copy
|
| 10 |
+
import json
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
# ============================================================
|
| 14 |
+
# Special Tokens
|
| 15 |
+
# ============================================================
|
| 16 |
+
|
| 17 |
+
bos_token: str = "<|begin▁of▁sentence|>"
|
| 18 |
+
eos_token: str = "<|end▁of▁sentence|>"
|
| 19 |
+
thinking_start_token: str = "<think>"
|
| 20 |
+
thinking_end_token: str = "</think>"
|
| 21 |
+
dsml_token: str = "|DSML|"
|
| 22 |
+
|
| 23 |
+
USER_SP_TOKEN = "<|User|>"
|
| 24 |
+
ASSISTANT_SP_TOKEN = "<|Assistant|>"
|
| 25 |
+
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
|
| 26 |
+
|
| 27 |
+
# Task special tokens for internal classification tasks
|
| 28 |
+
DS_TASK_SP_TOKENS = {
|
| 29 |
+
"action": "<|action|>",
|
| 30 |
+
"query": "<|query|>",
|
| 31 |
+
"authority": "<|authority|>",
|
| 32 |
+
"domain": "<|domain|>",
|
| 33 |
+
"title": "<|title|>",
|
| 34 |
+
"read_url": "<|read_url|>",
|
| 35 |
+
}
|
| 36 |
+
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
|
| 37 |
+
|
| 38 |
+
# ============================================================
|
| 39 |
+
# Templates
|
| 40 |
+
# ============================================================
|
| 41 |
+
|
| 42 |
+
system_msg_template: str = "{content}"
|
| 43 |
+
user_msg_template: str = "{content}"
|
| 44 |
+
latest_reminder_msg_template: str = "{content}"
|
| 45 |
+
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
|
| 46 |
+
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
|
| 47 |
+
thinking_template: str = "{reasoning_content}"
|
| 48 |
+
|
| 49 |
+
response_format_template: str = (
|
| 50 |
+
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
|
| 51 |
+
)
|
| 52 |
+
tool_call_template: str = (
|
| 53 |
+
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
|
| 54 |
+
)
|
| 55 |
+
tool_calls_template = (
|
| 56 |
+
"<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
|
| 57 |
+
)
|
| 58 |
+
tool_calls_block_name: str = "tool_calls"
|
| 59 |
+
|
| 60 |
+
tool_output_template: str = (
|
| 61 |
+
"<tool_result>{content}</tool_result>"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
REASONING_EFFORT_MAX = (
|
| 65 |
+
"Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
|
| 66 |
+
"You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
|
| 67 |
+
"Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
TOOLS_TEMPLATE = """## Tools
|
| 71 |
+
|
| 72 |
+
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
|
| 73 |
+
|
| 74 |
+
<{dsml_token}tool_calls>
|
| 75 |
+
<{dsml_token}invoke name="$TOOL_NAME">
|
| 76 |
+
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
|
| 77 |
+
...
|
| 78 |
+
</{dsml_token}invoke>
|
| 79 |
+
<{dsml_token}invoke name="$TOOL_NAME2">
|
| 80 |
+
...
|
| 81 |
+
</{dsml_token}invoke>
|
| 82 |
+
</{dsml_token}tool_calls>
|
| 83 |
+
|
| 84 |
+
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
| 85 |
+
|
| 86 |
+
If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
|
| 87 |
+
|
| 88 |
+
Otherwise, output directly after {thinking_end_token} with tool calls or final response.
|
| 89 |
+
|
| 90 |
+
### Available Tool Schemas
|
| 91 |
+
|
| 92 |
+
{tool_schemas}
|
| 93 |
+
|
| 94 |
+
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
# ============================================================
|
| 98 |
+
# Utility Functions
|
| 99 |
+
# ============================================================
|
| 100 |
+
|
| 101 |
+
def to_json(value: Any) -> str:
|
| 102 |
+
"""Serialize a value to JSON string."""
|
| 103 |
+
try:
|
| 104 |
+
return json.dumps(value, ensure_ascii=False)
|
| 105 |
+
except:
|
| 106 |
+
return json.dumps(value, ensure_ascii=True)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def tools_from_openai_format(tools):
|
| 110 |
+
"""Extract function definitions from OpenAI-format tool list."""
|
| 111 |
+
return [tool["function"] for tool in tools]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def tool_calls_from_openai_format(tool_calls):
|
| 115 |
+
"""Convert OpenAI-format tool calls to internal format."""
|
| 116 |
+
return [
|
| 117 |
+
{
|
| 118 |
+
"name": tool_call["function"]["name"],
|
| 119 |
+
"arguments": tool_call["function"]["arguments"],
|
| 120 |
+
}
|
| 121 |
+
for tool_call in tool_calls
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def tool_calls_to_openai_format(tool_calls):
|
| 126 |
+
"""Convert internal tool calls to OpenAI format."""
|
| 127 |
+
return [
|
| 128 |
+
{
|
| 129 |
+
"type": "function",
|
| 130 |
+
"function": {
|
| 131 |
+
"name": tool_call["name"],
|
| 132 |
+
"arguments": tool_call["arguments"],
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
for tool_call in tool_calls
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
|
| 140 |
+
"""
|
| 141 |
+
Encode tool call arguments into DSML parameter format.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
tool_call: Dict with "name" and "arguments" (JSON string) keys.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
DSML-formatted parameter string.
|
| 148 |
+
"""
|
| 149 |
+
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
|
| 150 |
+
P_dsml_strs = []
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
arguments = json.loads(tool_call["arguments"])
|
| 154 |
+
except Exception as err:
|
| 155 |
+
arguments = {"arguments": tool_call["arguments"]}
|
| 156 |
+
|
| 157 |
+
for k, v in arguments.items():
|
| 158 |
+
p_dsml_str = p_dsml_template.format(
|
| 159 |
+
dsml_token=dsml_token,
|
| 160 |
+
key=k,
|
| 161 |
+
is_str="true" if isinstance(v, str) else "false",
|
| 162 |
+
value=v if isinstance(v, str) else to_json(v),
|
| 163 |
+
)
|
| 164 |
+
P_dsml_strs.append(p_dsml_str)
|
| 165 |
+
|
| 166 |
+
return "\n".join(P_dsml_strs)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
|
| 170 |
+
"""
|
| 171 |
+
Decode DSML parameters back to a tool call dict.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
tool_name: Name of the tool.
|
| 175 |
+
tool_args: Dict mapping param_name -> (value, is_string_flag).
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Dict with "name" and "arguments" (JSON string) keys.
|
| 179 |
+
"""
|
| 180 |
+
def _decode_value(key: str, value: str, string: str):
|
| 181 |
+
if string == "true":
|
| 182 |
+
value = to_json(value)
|
| 183 |
+
return f"{to_json(key)}: {value}"
|
| 184 |
+
|
| 185 |
+
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
|
| 186 |
+
return dict(name=tool_name, arguments=tool_args_json)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
|
| 190 |
+
"""
|
| 191 |
+
Render tool schemas into the system prompt format.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
tools: List of tool schema dicts (each with name, description, parameters).
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
Formatted tools section string.
|
| 198 |
+
"""
|
| 199 |
+
tools_json = [to_json(t) for t in tools]
|
| 200 |
+
|
| 201 |
+
return TOOLS_TEMPLATE.format(
|
| 202 |
+
tool_schemas="\n".join(tools_json),
|
| 203 |
+
dsml_token=dsml_token,
|
| 204 |
+
thinking_start_token=thinking_start_token,
|
| 205 |
+
thinking_end_token=thinking_end_token,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
|
| 210 |
+
"""Find the index of the last user/developer message."""
|
| 211 |
+
last_user_index = -1
|
| 212 |
+
for idx in range(len(messages) - 1, -1, -1):
|
| 213 |
+
if messages[idx].get("role") in ["user", "developer"]:
|
| 214 |
+
last_user_index = idx
|
| 215 |
+
break
|
| 216 |
+
return last_user_index
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ============================================================
|
| 220 |
+
# Message Rendering
|
| 221 |
+
# ============================================================
|
| 222 |
+
|
| 223 |
+
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
|
| 224 |
+
"""
|
| 225 |
+
Render a single message at the given index into its encoded string form.
|
| 226 |
+
|
| 227 |
+
This is the core function that converts each message in the conversation
|
| 228 |
+
into the DeepSeek-V4 format.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
index: Index of the message to render.
|
| 232 |
+
messages: Full list of messages in the conversation.
|
| 233 |
+
thinking_mode: Either "chat" or "thinking".
|
| 234 |
+
drop_thinking: Whether to drop reasoning content from earlier turns.
|
| 235 |
+
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Encoded string for this message.
|
| 239 |
+
"""
|
| 240 |
+
assert 0 <= index < len(messages)
|
| 241 |
+
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
|
| 242 |
+
|
| 243 |
+
prompt = ""
|
| 244 |
+
msg = messages[index]
|
| 245 |
+
last_user_idx = find_last_user_index(messages)
|
| 246 |
+
|
| 247 |
+
role = msg.get("role")
|
| 248 |
+
content = msg.get("content")
|
| 249 |
+
tools = msg.get("tools")
|
| 250 |
+
response_format = msg.get("response_format")
|
| 251 |
+
tool_calls = msg.get("tool_calls")
|
| 252 |
+
reasoning_content = msg.get("reasoning_content")
|
| 253 |
+
wo_eos = msg.get("wo_eos", False)
|
| 254 |
+
|
| 255 |
+
if tools:
|
| 256 |
+
tools = tools_from_openai_format(tools)
|
| 257 |
+
if tool_calls:
|
| 258 |
+
tool_calls = tool_calls_from_openai_format(tool_calls)
|
| 259 |
+
|
| 260 |
+
# Reasoning effort prefix (only at index 0 in thinking mode with max effort)
|
| 261 |
+
assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
|
| 262 |
+
if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
|
| 263 |
+
prompt += REASONING_EFFORT_MAX
|
| 264 |
+
|
| 265 |
+
if role == "system":
|
| 266 |
+
prompt += system_msg_template.format(content=content or "")
|
| 267 |
+
if tools:
|
| 268 |
+
prompt += "\n\n" + render_tools(tools)
|
| 269 |
+
if response_format:
|
| 270 |
+
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
| 271 |
+
|
| 272 |
+
elif role == "developer":
|
| 273 |
+
assert content, f"Invalid message for role `{role}`: {msg}"
|
| 274 |
+
|
| 275 |
+
content_developer = USER_SP_TOKEN
|
| 276 |
+
content_developer += content
|
| 277 |
+
|
| 278 |
+
if tools:
|
| 279 |
+
content_developer += "\n\n" + render_tools(tools)
|
| 280 |
+
if response_format:
|
| 281 |
+
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
| 282 |
+
|
| 283 |
+
prompt += user_msg_template.format(content=content_developer)
|
| 284 |
+
|
| 285 |
+
elif role == "user":
|
| 286 |
+
prompt += USER_SP_TOKEN
|
| 287 |
+
|
| 288 |
+
# Handle content blocks (tool results mixed with text)
|
| 289 |
+
content_blocks = msg.get("content_blocks")
|
| 290 |
+
if content_blocks:
|
| 291 |
+
parts = []
|
| 292 |
+
for block in content_blocks:
|
| 293 |
+
block_type = block.get("type")
|
| 294 |
+
if block_type == "text":
|
| 295 |
+
parts.append(block.get("text", ""))
|
| 296 |
+
elif block_type == "tool_result":
|
| 297 |
+
tool_content = block.get("content", "")
|
| 298 |
+
if isinstance(tool_content, list):
|
| 299 |
+
text_parts = []
|
| 300 |
+
for b in tool_content:
|
| 301 |
+
if b.get("type") == "text":
|
| 302 |
+
text_parts.append(b.get("text", ""))
|
| 303 |
+
else:
|
| 304 |
+
text_parts.append(f"[Unsupported {b.get('type')}]")
|
| 305 |
+
tool_content = "\n\n".join(text_parts)
|
| 306 |
+
parts.append(tool_output_template.format(content=tool_content))
|
| 307 |
+
else:
|
| 308 |
+
parts.append(f"[Unsupported {block_type}]")
|
| 309 |
+
prompt += "\n\n".join(parts)
|
| 310 |
+
else:
|
| 311 |
+
prompt += content or ""
|
| 312 |
+
|
| 313 |
+
elif role == "latest_reminder":
|
| 314 |
+
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
|
| 315 |
+
|
| 316 |
+
elif role == "tool":
|
| 317 |
+
raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
|
| 318 |
+
|
| 319 |
+
elif role == "assistant":
|
| 320 |
+
thinking_part = ""
|
| 321 |
+
tc_content = ""
|
| 322 |
+
|
| 323 |
+
if tool_calls:
|
| 324 |
+
tc_list = [
|
| 325 |
+
tool_call_template.format(
|
| 326 |
+
dsml_token=dsml_token,
|
| 327 |
+
name=tc.get("name"),
|
| 328 |
+
arguments=encode_arguments_to_dsml(tc)
|
| 329 |
+
)
|
| 330 |
+
for tc in tool_calls
|
| 331 |
+
]
|
| 332 |
+
tc_content += '\n\n' + tool_calls_template.format(
|
| 333 |
+
dsml_token=dsml_token,
|
| 334 |
+
tool_calls="\n".join(tc_list),
|
| 335 |
+
tc_block_name=tool_calls_block_name,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
summary_content = content or ""
|
| 339 |
+
rc = reasoning_content or ""
|
| 340 |
+
|
| 341 |
+
# Check if previous message has a task - if so, this is a task output (no thinking)
|
| 342 |
+
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
|
| 343 |
+
|
| 344 |
+
if thinking_mode == "thinking" and not prev_has_task:
|
| 345 |
+
if not drop_thinking or index > last_user_idx:
|
| 346 |
+
thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
|
| 347 |
+
else:
|
| 348 |
+
thinking_part = ""
|
| 349 |
+
|
| 350 |
+
if wo_eos:
|
| 351 |
+
prompt += assistant_msg_wo_eos_template.format(
|
| 352 |
+
reasoning=thinking_part,
|
| 353 |
+
content=summary_content,
|
| 354 |
+
tool_calls=tc_content,
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
prompt += assistant_msg_template.format(
|
| 358 |
+
reasoning=thinking_part,
|
| 359 |
+
content=summary_content,
|
| 360 |
+
tool_calls=tc_content,
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
raise NotImplementedError(f"Unknown role: {role}")
|
| 364 |
+
|
| 365 |
+
# Append transition tokens based on what follows
|
| 366 |
+
if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
|
| 367 |
+
return prompt
|
| 368 |
+
|
| 369 |
+
task = messages[index].get("task")
|
| 370 |
+
if task is not None:
|
| 371 |
+
# Task special token for internal classification tasks
|
| 372 |
+
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
|
| 373 |
+
task_sp_token = DS_TASK_SP_TOKENS[task]
|
| 374 |
+
|
| 375 |
+
if task != "action":
|
| 376 |
+
# Non-action tasks: append task sp token directly after the message
|
| 377 |
+
prompt += task_sp_token
|
| 378 |
+
else:
|
| 379 |
+
# Action task: append Assistant + thinking token + action sp token
|
| 380 |
+
prompt += ASSISTANT_SP_TOKEN
|
| 381 |
+
prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
|
| 382 |
+
prompt += task_sp_token
|
| 383 |
+
|
| 384 |
+
elif messages[index].get("role") in ["user", "developer"]:
|
| 385 |
+
# Normal generation: append Assistant + thinking token
|
| 386 |
+
prompt += ASSISTANT_SP_TOKEN
|
| 387 |
+
if not drop_thinking and thinking_mode == "thinking":
|
| 388 |
+
prompt += thinking_start_token
|
| 389 |
+
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
|
| 390 |
+
prompt += thinking_start_token
|
| 391 |
+
else:
|
| 392 |
+
prompt += thinking_end_token
|
| 393 |
+
|
| 394 |
+
return prompt
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ============================================================
|
| 398 |
+
# Preprocessing
|
| 399 |
+
# ============================================================
|
| 400 |
+
|
| 401 |
+
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 402 |
+
"""
|
| 403 |
+
Merge tool messages into the preceding user message using content_blocks format.
|
| 404 |
+
|
| 405 |
+
DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
|
| 406 |
+
are encoded as <tool_result> blocks within user messages.
|
| 407 |
+
|
| 408 |
+
This function converts a standard OpenAI-format conversation (with separate
|
| 409 |
+
"tool" role messages) into V4 format where tool results are merged into
|
| 410 |
+
user messages.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
messages: List of message dicts in OpenAI format.
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
Processed message list with tool messages merged into user messages.
|
| 417 |
+
"""
|
| 418 |
+
merged: List[Dict[str, Any]] = []
|
| 419 |
+
|
| 420 |
+
for msg in messages:
|
| 421 |
+
msg = copy.deepcopy(msg)
|
| 422 |
+
role = msg.get("role")
|
| 423 |
+
|
| 424 |
+
if role == "tool":
|
| 425 |
+
# Convert tool message to a user message with tool_result block
|
| 426 |
+
tool_block = {
|
| 427 |
+
"type": "tool_result",
|
| 428 |
+
"tool_use_id": msg.get("tool_call_id", ""),
|
| 429 |
+
"content": msg.get("content", ""),
|
| 430 |
+
}
|
| 431 |
+
# Merge into previous message if it's already a user (merged tool)
|
| 432 |
+
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
|
| 433 |
+
merged[-1]["content_blocks"].append(tool_block)
|
| 434 |
+
else:
|
| 435 |
+
merged.append({
|
| 436 |
+
"role": "user",
|
| 437 |
+
"content_blocks": [tool_block],
|
| 438 |
+
})
|
| 439 |
+
elif role == "user":
|
| 440 |
+
text_block = {"type": "text", "text": msg.get("content", "")}
|
| 441 |
+
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
|
| 442 |
+
merged[-1]["content_blocks"].append(text_block)
|
| 443 |
+
else:
|
| 444 |
+
new_msg = {
|
| 445 |
+
"role": "user",
|
| 446 |
+
"content": msg.get("content", ""),
|
| 447 |
+
"content_blocks": [text_block],
|
| 448 |
+
}
|
| 449 |
+
# Preserve extra fields (task, wo_eos, mask, etc.)
|
| 450 |
+
for key in ("task", "wo_eos", "mask"):
|
| 451 |
+
if key in msg:
|
| 452 |
+
new_msg[key] = msg[key]
|
| 453 |
+
merged.append(new_msg)
|
| 454 |
+
else:
|
| 455 |
+
merged.append(msg)
|
| 456 |
+
|
| 457 |
+
return merged
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 461 |
+
"""
|
| 462 |
+
Sort tool_result blocks within user messages by the order of tool_calls
|
| 463 |
+
in the preceding assistant message.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
messages: Preprocessed message list (after merge_tool_messages).
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
Message list with sorted tool result blocks.
|
| 470 |
+
"""
|
| 471 |
+
last_tool_call_order: Dict[str, int] = {}
|
| 472 |
+
|
| 473 |
+
for msg in messages:
|
| 474 |
+
role = msg.get("role")
|
| 475 |
+
if role == "assistant" and msg.get("tool_calls"):
|
| 476 |
+
last_tool_call_order = {}
|
| 477 |
+
for idx, tc in enumerate(msg["tool_calls"]):
|
| 478 |
+
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
|
| 479 |
+
if tc_id:
|
| 480 |
+
last_tool_call_order[tc_id] = idx
|
| 481 |
+
|
| 482 |
+
elif role == "user" and msg.get("content_blocks"):
|
| 483 |
+
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
|
| 484 |
+
if len(tool_blocks) > 1 and last_tool_call_order:
|
| 485 |
+
sorted_blocks = sorted(
|
| 486 |
+
tool_blocks,
|
| 487 |
+
key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
|
| 488 |
+
)
|
| 489 |
+
sorted_idx = 0
|
| 490 |
+
new_blocks = []
|
| 491 |
+
for block in msg["content_blocks"]:
|
| 492 |
+
if block.get("type") == "tool_result":
|
| 493 |
+
new_blocks.append(sorted_blocks[sorted_idx])
|
| 494 |
+
sorted_idx += 1
|
| 495 |
+
else:
|
| 496 |
+
new_blocks.append(block)
|
| 497 |
+
msg["content_blocks"] = new_blocks
|
| 498 |
+
|
| 499 |
+
return messages
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# ============================================================
|
| 503 |
+
# Main Encoding Function
|
| 504 |
+
# ============================================================
|
| 505 |
+
|
| 506 |
+
def encode_messages(
|
| 507 |
+
messages: List[Dict[str, Any]],
|
| 508 |
+
thinking_mode: str,
|
| 509 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 510 |
+
drop_thinking: bool = True,
|
| 511 |
+
add_default_bos_token: bool = True,
|
| 512 |
+
reasoning_effort: Optional[str] = None,
|
| 513 |
+
) -> str:
|
| 514 |
+
"""
|
| 515 |
+
Encode a list of messages into the DeepSeek-V4 prompt format.
|
| 516 |
+
|
| 517 |
+
This is the main entry point for encoding conversations. It handles:
|
| 518 |
+
- BOS token insertion
|
| 519 |
+
- Thinking mode with optional reasoning content dropping
|
| 520 |
+
- Tool message merging into user messages
|
| 521 |
+
- Multi-turn conversation context
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
messages: List of message dicts to encode.
|
| 525 |
+
thinking_mode: Either "chat" or "thinking".
|
| 526 |
+
context: Optional preceding context messages (already encoded prefix).
|
| 527 |
+
drop_thinking: If True, drop reasoning_content from earlier assistant turns
|
| 528 |
+
(only keep reasoning for messages after the last user message).
|
| 529 |
+
add_default_bos_token: Whether to prepend BOS token at conversation start.
|
| 530 |
+
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
| 531 |
+
|
| 532 |
+
Returns:
|
| 533 |
+
The encoded prompt string.
|
| 534 |
+
"""
|
| 535 |
+
context = context if context else []
|
| 536 |
+
|
| 537 |
+
# Preprocess: merge tool messages and sort tool results
|
| 538 |
+
messages = merge_tool_messages(messages)
|
| 539 |
+
messages = sort_tool_results_by_call_order(context + messages)[len(context):]
|
| 540 |
+
if context:
|
| 541 |
+
context = merge_tool_messages(context)
|
| 542 |
+
context = sort_tool_results_by_call_order(context)
|
| 543 |
+
|
| 544 |
+
full_messages = context + messages
|
| 545 |
+
|
| 546 |
+
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
|
| 547 |
+
|
| 548 |
+
# Resolve drop_thinking: if any message has tools defined, don't drop thinking
|
| 549 |
+
effective_drop_thinking = drop_thinking
|
| 550 |
+
if any(m.get("tools") for m in full_messages):
|
| 551 |
+
effective_drop_thinking = False
|
| 552 |
+
|
| 553 |
+
if thinking_mode == "thinking" and effective_drop_thinking:
|
| 554 |
+
full_messages = _drop_thinking_messages(full_messages)
|
| 555 |
+
# After dropping, recalculate how many messages to render
|
| 556 |
+
# (context may have shrunk too)
|
| 557 |
+
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
|
| 558 |
+
context_len = len(full_messages) - num_to_render
|
| 559 |
+
else:
|
| 560 |
+
num_to_render = len(messages)
|
| 561 |
+
context_len = len(context)
|
| 562 |
+
|
| 563 |
+
for idx in range(num_to_render):
|
| 564 |
+
prompt += render_message(
|
| 565 |
+
idx + context_len,
|
| 566 |
+
full_messages,
|
| 567 |
+
thinking_mode=thinking_mode,
|
| 568 |
+
drop_thinking=effective_drop_thinking,
|
| 569 |
+
reasoning_effort=reasoning_effort,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
return prompt
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 576 |
+
"""
|
| 577 |
+
Drop reasoning_content and non-essential messages before the last user message.
|
| 578 |
+
|
| 579 |
+
Behavior:
|
| 580 |
+
- Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
|
| 581 |
+
- Messages at or after the last user index are always kept.
|
| 582 |
+
- Assistant messages before the last user get reasoning_content removed.
|
| 583 |
+
- Developer messages before the last user are dropped entirely.
|
| 584 |
+
"""
|
| 585 |
+
last_user_idx = find_last_user_index(messages)
|
| 586 |
+
result = []
|
| 587 |
+
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
|
| 588 |
+
|
| 589 |
+
for idx, msg in enumerate(messages):
|
| 590 |
+
role = msg.get("role")
|
| 591 |
+
if role in keep_roles or idx >= last_user_idx:
|
| 592 |
+
result.append(msg)
|
| 593 |
+
elif role == "assistant":
|
| 594 |
+
msg = copy.copy(msg)
|
| 595 |
+
msg.pop("reasoning_content", None)
|
| 596 |
+
result.append(msg)
|
| 597 |
+
# developer and other roles before last_user_idx are dropped
|
| 598 |
+
|
| 599 |
+
return result
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# ============================================================
|
| 603 |
+
# Parsing (Decoding model output)
|
| 604 |
+
# ============================================================
|
| 605 |
+
|
| 606 |
+
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
|
| 607 |
+
"""
|
| 608 |
+
Read text from index until one of the stop strings is found.
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
|
| 612 |
+
"""
|
| 613 |
+
min_pos = len(text)
|
| 614 |
+
matched_stop = None
|
| 615 |
+
|
| 616 |
+
for s in stop:
|
| 617 |
+
pos = text.find(s, index)
|
| 618 |
+
if pos != -1 and pos < min_pos:
|
| 619 |
+
min_pos = pos
|
| 620 |
+
matched_stop = s
|
| 621 |
+
|
| 622 |
+
if matched_stop:
|
| 623 |
+
content = text[index:min_pos]
|
| 624 |
+
return min_pos + len(matched_stop), content, matched_stop
|
| 625 |
+
else:
|
| 626 |
+
content = text[index:]
|
| 627 |
+
return len(text), content, None
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
|
| 631 |
+
"""
|
| 632 |
+
Parse DSML tool calls from text starting at the given index.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
index: Starting position in text.
|
| 636 |
+
text: The full text to parse.
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
|
| 640 |
+
Each tool call dict has "name" and "arguments" keys.
|
| 641 |
+
"""
|
| 642 |
+
tool_calls: List[Dict[str, Any]] = []
|
| 643 |
+
stop_token = None
|
| 644 |
+
tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
|
| 645 |
+
|
| 646 |
+
while index < len(text):
|
| 647 |
+
index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
|
| 648 |
+
if _ != ">\n":
|
| 649 |
+
raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
|
| 650 |
+
|
| 651 |
+
if stop_token == tool_calls_end_token:
|
| 652 |
+
break
|
| 653 |
+
|
| 654 |
+
if stop_token is None:
|
| 655 |
+
raise ValueError("Missing special token in tool calls")
|
| 656 |
+
|
| 657 |
+
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
| 658 |
+
|
| 659 |
+
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
|
| 660 |
+
if len(p_tool_name) != 1:
|
| 661 |
+
raise ValueError(f"Tool name format error: '{tool_name_content}'")
|
| 662 |
+
tool_name = p_tool_name[0]
|
| 663 |
+
|
| 664 |
+
tool_args: Dict[str, Tuple[str, str]] = {}
|
| 665 |
+
while stop_token == f"<{dsml_token}parameter":
|
| 666 |
+
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
|
| 667 |
+
|
| 668 |
+
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
|
| 669 |
+
if len(param_kv) != 1:
|
| 670 |
+
raise ValueError(f"Parameter format error: '{param_content}'")
|
| 671 |
+
param_name, string, param_value = param_kv[0]
|
| 672 |
+
|
| 673 |
+
if param_name in tool_args:
|
| 674 |
+
raise ValueError(f"Duplicate parameter name: '{param_name}'")
|
| 675 |
+
tool_args[param_name] = (param_value, string)
|
| 676 |
+
|
| 677 |
+
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
| 678 |
+
if content != ">\n":
|
| 679 |
+
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
|
| 680 |
+
|
| 681 |
+
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
|
| 682 |
+
tool_calls.append(tool_call)
|
| 683 |
+
|
| 684 |
+
return index, stop_token, tool_calls
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
|
| 688 |
+
"""
|
| 689 |
+
Parse a model completion text into a structured assistant message.
|
| 690 |
+
|
| 691 |
+
This function takes the raw text output from the model (a single assistant turn)
|
| 692 |
+
and extracts:
|
| 693 |
+
- reasoning_content (thinking block)
|
| 694 |
+
- content (summary/response)
|
| 695 |
+
- tool_calls (if any)
|
| 696 |
+
|
| 697 |
+
NOTE: This function is designed to parse only correctly formatted strings and
|
| 698 |
+
will raise ValueError for malformed output.
|
| 699 |
+
|
| 700 |
+
Args:
|
| 701 |
+
text: The raw completion text (including EOS token).
|
| 702 |
+
thinking_mode: Either "chat" or "thinking".
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
Dict with keys: "role", "content", "reasoning_content", "tool_calls".
|
| 706 |
+
tool_calls are in OpenAI format.
|
| 707 |
+
"""
|
| 708 |
+
summary_content, reasoning_content, tool_calls = "", "", []
|
| 709 |
+
index, stop_token = 0, None
|
| 710 |
+
tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
|
| 711 |
+
|
| 712 |
+
is_thinking = thinking_mode == "thinking"
|
| 713 |
+
is_tool_calling = False
|
| 714 |
+
|
| 715 |
+
if is_thinking:
|
| 716 |
+
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
|
| 717 |
+
reasoning_content = content_delta
|
| 718 |
+
assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
|
| 719 |
+
|
| 720 |
+
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
|
| 721 |
+
summary_content = content_delta
|
| 722 |
+
if stop_token == tool_calls_start_token:
|
| 723 |
+
is_tool_calling = True
|
| 724 |
+
else:
|
| 725 |
+
assert stop_token == eos_token, "Invalid format: missing EOS token"
|
| 726 |
+
|
| 727 |
+
if is_tool_calling:
|
| 728 |
+
index, stop_token, tool_calls = parse_tool_calls(index, text)
|
| 729 |
+
|
| 730 |
+
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
|
| 731 |
+
assert not tool_ends_text, "Unexpected content after tool calls"
|
| 732 |
+
|
| 733 |
+
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
|
| 734 |
+
|
| 735 |
+
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
|
| 736 |
+
assert sp_token not in summary_content and sp_token not in reasoning_content, \
|
| 737 |
+
f"Unexpected special token '{sp_token}' in content"
|
| 738 |
+
|
| 739 |
+
return {
|
| 740 |
+
"role": "assistant",
|
| 741 |
+
"content": summary_content,
|
| 742 |
+
"reasoning_content": reasoning_content,
|
| 743 |
+
"tool_calls": tool_calls_to_openai_format(tool_calls)
|
| 744 |
+
}
|
encoding/test_encoding_dsv4.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test suite for DeepSeek-V4 Encoding.
|
| 3 |
+
|
| 4 |
+
Run: python test_encoding_dsv4.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
| 11 |
+
|
| 12 |
+
TESTS_DIR = os.path.join(os.path.dirname(__file__), "tests")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_case_1():
|
| 16 |
+
"""Thinking mode with tool calls (multi-turn, tool results merged into user)."""
|
| 17 |
+
with open(os.path.join(TESTS_DIR, "test_input_1.json")) as f:
|
| 18 |
+
td = json.load(f)
|
| 19 |
+
messages = td["messages"]
|
| 20 |
+
messages[0]["tools"] = td["tools"]
|
| 21 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_1.txt")).read()
|
| 22 |
+
prompt = encode_messages(messages, thinking_mode="thinking")
|
| 23 |
+
assert prompt == gold
|
| 24 |
+
|
| 25 |
+
# Parse: assistant turn with tool call
|
| 26 |
+
marker = "<|Assistant|><think>"
|
| 27 |
+
first_start = prompt.find(marker) + len(marker)
|
| 28 |
+
first_end = prompt.find("<|User|>", first_start)
|
| 29 |
+
parsed_tc = parse_message_from_completion_text(prompt[first_start:first_end], thinking_mode="thinking")
|
| 30 |
+
assert parsed_tc["reasoning_content"] == "The user wants to know the weather in Beijing. I should use the get_weather tool."
|
| 31 |
+
assert parsed_tc["content"] == ""
|
| 32 |
+
assert len(parsed_tc["tool_calls"]) == 1
|
| 33 |
+
assert parsed_tc["tool_calls"][0]["function"]["name"] == "get_weather"
|
| 34 |
+
assert json.loads(parsed_tc["tool_calls"][0]["function"]["arguments"]) == {"location": "Beijing", "unit": "celsius"}
|
| 35 |
+
|
| 36 |
+
# Parse: final assistant turn with content
|
| 37 |
+
last_start = prompt.rfind(marker) + len(marker)
|
| 38 |
+
parsed_final = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
|
| 39 |
+
assert parsed_final["reasoning_content"] == "Got the weather data. Let me format a nice response."
|
| 40 |
+
assert "22°C" in parsed_final["content"]
|
| 41 |
+
assert parsed_final["tool_calls"] == []
|
| 42 |
+
|
| 43 |
+
print(" [PASS] case 1: thinking with tools (encode + parse)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_case_2():
|
| 47 |
+
"""Thinking mode without tools (drop_thinking removes earlier reasoning)."""
|
| 48 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_2.json")))
|
| 49 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_2.txt")).read()
|
| 50 |
+
prompt = encode_messages(messages, thinking_mode="thinking")
|
| 51 |
+
assert prompt == gold
|
| 52 |
+
|
| 53 |
+
# Parse: last assistant turn
|
| 54 |
+
marker = "<|Assistant|><think>"
|
| 55 |
+
last_start = prompt.rfind(marker) + len(marker)
|
| 56 |
+
parsed = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
|
| 57 |
+
assert parsed["reasoning_content"] == "The user asks about the capital of France. It is Paris."
|
| 58 |
+
assert parsed["content"] == "The capital of France is Paris."
|
| 59 |
+
assert parsed["tool_calls"] == []
|
| 60 |
+
|
| 61 |
+
# Verify drop_thinking: first assistant's reasoning should be absent
|
| 62 |
+
assert "The user said hello" not in prompt
|
| 63 |
+
|
| 64 |
+
print(" [PASS] case 2: thinking without tools (encode + parse)")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_case_3():
|
| 68 |
+
"""Interleaved thinking + search (developer with tools, latest_reminder)."""
|
| 69 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_3.json")))
|
| 70 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_3.txt")).read()
|
| 71 |
+
assert encode_messages(messages, thinking_mode="thinking") == gold
|
| 72 |
+
print(" [PASS] case 3: interleaved thinking + search")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_case_4():
|
| 76 |
+
"""Quick instruction task with latest_reminder (chat mode, action task)."""
|
| 77 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_4.json")))
|
| 78 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_4.txt")).read()
|
| 79 |
+
assert encode_messages(messages, thinking_mode="chat") == gold
|
| 80 |
+
print(" [PASS] case 4: quick instruction task")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
print("Running DeepSeek-V4 Encoding Tests...\n")
|
| 85 |
+
test_case_1()
|
| 86 |
+
test_case_2()
|
| 87 |
+
test_case_3()
|
| 88 |
+
test_case_4()
|
| 89 |
+
print("\nAll 4 tests passed!")
|
encoding/tests/test_input_1.json
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tools": [
|
| 3 |
+
{
|
| 4 |
+
"type": "function",
|
| 5 |
+
"function": {
|
| 6 |
+
"name": "get_weather",
|
| 7 |
+
"description": "Get the weather for a specific location",
|
| 8 |
+
"parameters": {
|
| 9 |
+
"type": "object",
|
| 10 |
+
"properties": {
|
| 11 |
+
"location": {
|
| 12 |
+
"type": "string",
|
| 13 |
+
"description": "The city name"
|
| 14 |
+
},
|
| 15 |
+
"unit": {
|
| 16 |
+
"type": "string",
|
| 17 |
+
"enum": ["celsius", "fahrenheit"],
|
| 18 |
+
"description": "Temperature unit"
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"required": ["location"]
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"type": "function",
|
| 27 |
+
"function": {
|
| 28 |
+
"name": "search",
|
| 29 |
+
"description": "Search the web for information",
|
| 30 |
+
"parameters": {
|
| 31 |
+
"type": "object",
|
| 32 |
+
"properties": {
|
| 33 |
+
"query": {
|
| 34 |
+
"type": "string",
|
| 35 |
+
"description": "Search query"
|
| 36 |
+
},
|
| 37 |
+
"num_results": {
|
| 38 |
+
"type": "integer",
|
| 39 |
+
"description": "Number of results to return"
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
"required": ["query"]
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"messages": [
|
| 48 |
+
{
|
| 49 |
+
"role": "system",
|
| 50 |
+
"content": "You are a helpful assistant."
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"role": "user",
|
| 54 |
+
"content": "What's the weather in Beijing?"
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"role": "assistant",
|
| 58 |
+
"reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
|
| 59 |
+
"tool_calls": [
|
| 60 |
+
{
|
| 61 |
+
"id": "call_001",
|
| 62 |
+
"type": "function",
|
| 63 |
+
"function": {
|
| 64 |
+
"name": "get_weather",
|
| 65 |
+
"arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"role": "tool",
|
| 72 |
+
"tool_call_id": "call_001",
|
| 73 |
+
"content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"role": "assistant",
|
| 77 |
+
"reasoning_content": "Got the weather data. Let me format a nice response.",
|
| 78 |
+
"content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
|
| 79 |
+
}
|
| 80 |
+
]
|
| 81 |
+
}
|
encoding/tests/test_input_2.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"role": "system",
|
| 4 |
+
"content": "You are a helpful assistant."
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"role": "user",
|
| 8 |
+
"content": "Hello"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"role": "assistant",
|
| 12 |
+
"reasoning_content": "The user said hello, I should greet back.",
|
| 13 |
+
"content": "Hi there! How can I help you?"
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"role": "user",
|
| 17 |
+
"content": "What is the capital of France?"
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"role": "assistant",
|
| 21 |
+
"reasoning_content": "The user asks about the capital of France. It is Paris.",
|
| 22 |
+
"content": "The capital of France is Paris."
|
| 23 |
+
}
|
| 24 |
+
]
|
encoding/tests/test_input_3.json
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"role": "system",
|
| 4 |
+
"content": "该助手为DeepSeek,由深度求索公司创造。"
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"role": "latest_reminder",
|
| 8 |
+
"content": "2026-02-21,星期六,广州,App,中文"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"role": "developer",
|
| 12 |
+
"content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
|
| 13 |
+
"tools": [
|
| 14 |
+
{
|
| 15 |
+
"type": "function",
|
| 16 |
+
"function": {
|
| 17 |
+
"name": "search",
|
| 18 |
+
"description": "Web search. Split multiple queries with '||'.",
|
| 19 |
+
"parameters": {
|
| 20 |
+
"type": "object",
|
| 21 |
+
"properties": {
|
| 22 |
+
"queries": {
|
| 23 |
+
"type": "string",
|
| 24 |
+
"description": "query1||query2"
|
| 25 |
+
}
|
| 26 |
+
},
|
| 27 |
+
"required": [
|
| 28 |
+
"queries"
|
| 29 |
+
],
|
| 30 |
+
"additionalProperties": false,
|
| 31 |
+
"$schema": "http://json-schema.org/draft-07/schema#"
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"type": "function",
|
| 37 |
+
"function": {
|
| 38 |
+
"name": "open",
|
| 39 |
+
"description": "Batch open IDs (format 【{id}†...】) or URLs.",
|
| 40 |
+
"parameters": {
|
| 41 |
+
"type": "object",
|
| 42 |
+
"properties": {
|
| 43 |
+
"open_list": {
|
| 44 |
+
"type": "array",
|
| 45 |
+
"items": {
|
| 46 |
+
"type": "object",
|
| 47 |
+
"properties": {
|
| 48 |
+
"id": {
|
| 49 |
+
"description": "ID or URL",
|
| 50 |
+
"anyOf": [
|
| 51 |
+
{
|
| 52 |
+
"type": "integer"
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"type": "string"
|
| 56 |
+
}
|
| 57 |
+
],
|
| 58 |
+
"default": -1
|
| 59 |
+
},
|
| 60 |
+
"cursor": {
|
| 61 |
+
"type": "integer",
|
| 62 |
+
"description": "",
|
| 63 |
+
"default": -1
|
| 64 |
+
},
|
| 65 |
+
"loc": {
|
| 66 |
+
"type": "integer",
|
| 67 |
+
"description": "Start line",
|
| 68 |
+
"default": -1
|
| 69 |
+
},
|
| 70 |
+
"num_lines": {
|
| 71 |
+
"type": "integer",
|
| 72 |
+
"description": "",
|
| 73 |
+
"default": -1
|
| 74 |
+
},
|
| 75 |
+
"view_source": {
|
| 76 |
+
"type": "boolean",
|
| 77 |
+
"description": "",
|
| 78 |
+
"default": false
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
"additionalProperties": false
|
| 82 |
+
},
|
| 83 |
+
"description": ""
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
"required": [
|
| 87 |
+
"open_list"
|
| 88 |
+
],
|
| 89 |
+
"additionalProperties": false,
|
| 90 |
+
"$schema": "http://json-schema.org/draft-07/schema#"
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"type": "function",
|
| 96 |
+
"function": {
|
| 97 |
+
"name": "find",
|
| 98 |
+
"description": "Find exact text pattern in pages.",
|
| 99 |
+
"parameters": {
|
| 100 |
+
"type": "object",
|
| 101 |
+
"properties": {
|
| 102 |
+
"find_list": {
|
| 103 |
+
"type": "array",
|
| 104 |
+
"items": {
|
| 105 |
+
"type": "object",
|
| 106 |
+
"properties": {
|
| 107 |
+
"pattern": {
|
| 108 |
+
"type": "string",
|
| 109 |
+
"description": ""
|
| 110 |
+
},
|
| 111 |
+
"cursor": {
|
| 112 |
+
"type": "integer",
|
| 113 |
+
"description": "",
|
| 114 |
+
"default": -1
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"required": [
|
| 118 |
+
"pattern"
|
| 119 |
+
],
|
| 120 |
+
"additionalProperties": false
|
| 121 |
+
},
|
| 122 |
+
"description": ""
|
| 123 |
+
}
|
| 124 |
+
},
|
| 125 |
+
"required": [
|
| 126 |
+
"find_list"
|
| 127 |
+
],
|
| 128 |
+
"additionalProperties": false,
|
| 129 |
+
"$schema": "http://json-schema.org/draft-07/schema#"
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"role": "assistant",
|
| 137 |
+
"content": "",
|
| 138 |
+
"reasoning_content": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
|
| 139 |
+
"tool_calls": [
|
| 140 |
+
{
|
| 141 |
+
"type": "function",
|
| 142 |
+
"function": {
|
| 143 |
+
"name": "search",
|
| 144 |
+
"arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"role": "tool",
|
| 151 |
+
"content": "[0]"
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"role": "assistant",
|
| 155 |
+
"content": "请及时就医。",
|
| 156 |
+
"reasoning_content": "现在开始组织回答。",
|
| 157 |
+
"tool_calls": []
|
| 158 |
+
}
|
| 159 |
+
]
|
encoding/tests/test_input_4.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"role": "system",
|
| 4 |
+
"content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"role": "latest_reminder",
|
| 8 |
+
"content": "2024-11-15,上海市,App,中文"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"role": "user",
|
| 12 |
+
"content": "热海大滚锅是世界著名温泉吗"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"role": "assistant",
|
| 16 |
+
"content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
|
| 17 |
+
"mask": 1
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"role": "user",
|
| 21 |
+
"content": "世界著名温泉有哪些",
|
| 22 |
+
"task": "action"
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"role": "assistant",
|
| 26 |
+
"content": "Search"
|
| 27 |
+
}
|
| 28 |
+
]
|
encoding/tests/test_output_1.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<|begin▁of▁sentence|>You are a helpful assistant.
|
| 2 |
+
|
| 3 |
+
## Tools
|
| 4 |
+
|
| 5 |
+
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
|
| 6 |
+
|
| 7 |
+
<|DSML|tool_calls>
|
| 8 |
+
<|DSML|invoke name="$TOOL_NAME">
|
| 9 |
+
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
|
| 10 |
+
...
|
| 11 |
+
</|DSML|invoke>
|
| 12 |
+
<|DSML|invoke name="$TOOL_NAME2">
|
| 13 |
+
...
|
| 14 |
+
</|DSML|invoke>
|
| 15 |
+
</|DSML|tool_calls>
|
| 16 |
+
|
| 17 |
+
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
| 18 |
+
|
| 19 |
+
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
|
| 20 |
+
|
| 21 |
+
Otherwise, output directly after </think> with tool calls or final response.
|
| 22 |
+
|
| 23 |
+
### Available Tool Schemas
|
| 24 |
+
|
| 25 |
+
{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
|
| 26 |
+
{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
|
| 27 |
+
|
| 28 |
+
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
| 29 |
+
<|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
|
| 30 |
+
|
| 31 |
+
<|DSML|tool_calls>
|
| 32 |
+
<|DSML|invoke name="get_weather">
|
| 33 |
+
<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
|
| 34 |
+
<|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
|
| 35 |
+
</|DSML|invoke>
|
| 36 |
+
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
|
encoding/tests/test_output_2.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
|
encoding/tests/test_output_3.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
|
| 2 |
+
|
| 3 |
+
CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
|
| 4 |
+
|
| 5 |
+
## Tools
|
| 6 |
+
|
| 7 |
+
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
|
| 8 |
+
|
| 9 |
+
<|DSML|tool_calls>
|
| 10 |
+
<|DSML|invoke name="$TOOL_NAME">
|
| 11 |
+
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
|
| 12 |
+
...
|
| 13 |
+
</|DSML|invoke>
|
| 14 |
+
<|DSML|invoke name="$TOOL_NAME2">
|
| 15 |
+
...
|
| 16 |
+
</|DSML|invoke>
|
| 17 |
+
</|DSML|tool_calls>
|
| 18 |
+
|
| 19 |
+
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
| 20 |
+
|
| 21 |
+
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
|
| 22 |
+
|
| 23 |
+
Otherwise, output directly after </think> with tool calls or final response.
|
| 24 |
+
|
| 25 |
+
### Available Tool Schemas
|
| 26 |
+
|
| 27 |
+
{"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
| 28 |
+
{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
| 29 |
+
{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
| 30 |
+
|
| 31 |
+
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
| 32 |
+
<|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
|
| 33 |
+
|
| 34 |
+
<|DSML|tool_calls>
|
| 35 |
+
<|DSML|invoke name="search">
|
| 36 |
+
<|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
|
| 37 |
+
</|DSML|invoke>
|
| 38 |
+
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
|
encoding/tests/test_output_4.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
|
| 2 |
+
今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
|
| 3 |
+
|
| 4 |
+
为了让你快速了解,我整理了一个简要的表格:
|
| 5 |
+
|
| 6 |
+
| 维度 | 热海大滚锅的地位与特点 |
|
| 7 |
+
| :--- | :--- |
|
| 8 |
+
| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
|
| 9 |
+
| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
|
| 10 |
+
| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
|
| 11 |
+
| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
|
| 12 |
+
|
| 13 |
+
### 💡 游玩攻略与温馨提示
|
| 14 |
+
|
| 15 |
+
如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
|
| 16 |
+
|
| 17 |
+
- **门票与开放时间**:
|
| 18 |
+
- **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
|
| 19 |
+
- **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
|
| 20 |
+
|
| 21 |
+
- **特色体验**:
|
| 22 |
+
- **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
|
| 23 |
+
- **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
|
| 24 |
+
|
| 25 |
+
- **注意事项**:
|
| 26 |
+
- **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
|
| 27 |
+
- **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
|
| 28 |
+
|
| 29 |
+
希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 1,
|
| 5 |
+
"do_sample": true,
|
| 6 |
+
"temperature": 1.0,
|
| 7 |
+
"top_p": 1.0,
|
| 8 |
+
"transformers_version": "4.46.3"
|
| 9 |
+
}
|
inference/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code for DeepSeek models
|
| 2 |
+
|
| 3 |
+
First convert huggingface model weight files to the format of this project.
|
| 4 |
+
```bash
|
| 5 |
+
export EXPERTS=384
|
| 6 |
+
export MP=8
|
| 7 |
+
export CONFIG=config_w4a16.json
|
| 8 |
+
python convert_w4a16.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
Then chat with DeepSeek model at will!
|
| 12 |
+
```bash
|
| 13 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Or batch inference from file.
|
| 17 |
+
```bash
|
| 18 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Or multi nodes inference.
|
| 22 |
+
```bash
|
| 23 |
+
torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
|
| 24 |
+
```
|
| 25 |
+
|
inference/config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 129280,
|
| 3 |
+
"dim": 7168,
|
| 4 |
+
"moe_inter_dim": 3072,
|
| 5 |
+
"n_layers": 61,
|
| 6 |
+
"n_hash_layers": 3,
|
| 7 |
+
"n_heads": 128,
|
| 8 |
+
"n_routed_experts": 384,
|
| 9 |
+
"n_shared_experts": 1,
|
| 10 |
+
"n_activated_experts": 6,
|
| 11 |
+
"score_func": "sqrtsoftplus",
|
| 12 |
+
"route_scale": 2.5,
|
| 13 |
+
"swiglu_limit": 10.0,
|
| 14 |
+
"q_lora_rank": 1536,
|
| 15 |
+
"head_dim": 512,
|
| 16 |
+
"rope_head_dim": 64,
|
| 17 |
+
"o_groups": 16,
|
| 18 |
+
"o_lora_rank": 1024,
|
| 19 |
+
"window_size": 128,
|
| 20 |
+
"original_seq_len": 65536,
|
| 21 |
+
"rope_theta": 10000,
|
| 22 |
+
"rope_factor": 16,
|
| 23 |
+
"beta_fast": 32,
|
| 24 |
+
"beta_slow": 1,
|
| 25 |
+
"index_n_heads": 64,
|
| 26 |
+
"index_head_dim": 128,
|
| 27 |
+
"index_topk": 1024,
|
| 28 |
+
"hc_mult": 4,
|
| 29 |
+
"hc_sinkhorn_iters": 20,
|
| 30 |
+
"dtype": "fp8",
|
| 31 |
+
"scale_fmt": "ue8m0",
|
| 32 |
+
"expert_dtype": "fp4",
|
| 33 |
+
"compress_rope_theta": 160000,
|
| 34 |
+
"compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
|
| 35 |
+
}
|
inference/config_w4a16.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
{
|
| 3 |
+
"vocab_size": 129280,
|
| 4 |
+
"dim": 7168,
|
| 5 |
+
"moe_inter_dim": 3072,
|
| 6 |
+
"n_layers": 61,
|
| 7 |
+
"n_hash_layers": 3,
|
| 8 |
+
"n_heads": 128,
|
| 9 |
+
"n_routed_experts": 384,
|
| 10 |
+
"n_shared_experts": 1,
|
| 11 |
+
"n_activated_experts": 6,
|
| 12 |
+
"score_func": "sqrtsoftplus",
|
| 13 |
+
"route_scale": 2.5,
|
| 14 |
+
"swiglu_limit": 10.0,
|
| 15 |
+
"q_lora_rank": 1536,
|
| 16 |
+
"head_dim": 512,
|
| 17 |
+
"rope_head_dim": 64,
|
| 18 |
+
"o_groups": 16,
|
| 19 |
+
"o_lora_rank": 1024,
|
| 20 |
+
"window_size": 128,
|
| 21 |
+
"original_seq_len": 65536,
|
| 22 |
+
"rope_theta": 10000,
|
| 23 |
+
"rope_factor": 16,
|
| 24 |
+
"beta_fast": 32,
|
| 25 |
+
"beta_slow": 1,
|
| 26 |
+
"index_n_heads": 64,
|
| 27 |
+
"index_head_dim": 128,
|
| 28 |
+
"index_topk": 1024,
|
| 29 |
+
"hc_mult": 4,
|
| 30 |
+
"hc_sinkhorn_iters": 20,
|
| 31 |
+
"dtype": "w4a16",
|
| 32 |
+
"compress_rope_theta": 160000,
|
| 33 |
+
"compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
|
| 34 |
+
}
|
inference/convert.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from glob import glob
|
| 5 |
+
from tqdm import tqdm, trange
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import safe_open, save_file
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
FP4_TABLE = torch.tensor([
|
| 12 |
+
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
| 13 |
+
0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
|
| 14 |
+
], dtype=torch.float32)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 18 |
+
"""
|
| 19 |
+
Casts a tensor from e2m1fn to e4m3fn losslessly.
|
| 20 |
+
"""
|
| 21 |
+
assert x.dtype == torch.int8
|
| 22 |
+
assert x.ndim == 2
|
| 23 |
+
out_dim, in_dim = x.size()
|
| 24 |
+
in_dim *= 2
|
| 25 |
+
fp8_block_size = 128
|
| 26 |
+
fp4_block_size = 32
|
| 27 |
+
assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
|
| 28 |
+
assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
|
| 29 |
+
|
| 30 |
+
x = x.view(torch.uint8)
|
| 31 |
+
low = x & 0x0F
|
| 32 |
+
high = (x >> 4) & 0x0F
|
| 33 |
+
x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
|
| 34 |
+
|
| 35 |
+
# max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
|
| 36 |
+
# 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
|
| 37 |
+
MAX_OFFSET_BITS = 6
|
| 38 |
+
|
| 39 |
+
bOut = out_dim // fp8_block_size
|
| 40 |
+
bIn = in_dim // fp8_block_size
|
| 41 |
+
# bOut, bIn, 128, 128
|
| 42 |
+
x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
|
| 43 |
+
# bOut, bIn, 128*4
|
| 44 |
+
scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
|
| 45 |
+
## bOut, bIn, 1
|
| 46 |
+
scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
|
| 47 |
+
# bOut, bIn, 128*4
|
| 48 |
+
offset = scale / scale_max_offset_bits
|
| 49 |
+
# bOut, bIn, 128, 128
|
| 50 |
+
offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
|
| 51 |
+
x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
|
| 52 |
+
return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
mapping = {
|
| 56 |
+
"embed_tokens": ("embed", 0),
|
| 57 |
+
"input_layernorm": ("attn_norm", None),
|
| 58 |
+
"post_attention_layernorm": ("ffn_norm", None),
|
| 59 |
+
"q_proj": ("wq", 0),
|
| 60 |
+
"q_a_proj": ("wq_a", None),
|
| 61 |
+
"q_a_layernorm": ("q_norm", None),
|
| 62 |
+
"q_b_proj": ("wq_b", 0),
|
| 63 |
+
"kv_a_proj_with_mqa": ("wkv_a", None),
|
| 64 |
+
"kv_a_layernorm": ("kv_norm", None),
|
| 65 |
+
"kv_b_proj": ("wkv_b", 0),
|
| 66 |
+
"o_proj": ("wo", 1),
|
| 67 |
+
"gate_proj": ("w1", 0),
|
| 68 |
+
"down_proj": ("w2", 1),
|
| 69 |
+
"up_proj": ("w3", 0),
|
| 70 |
+
"lm_head": ("head", 0),
|
| 71 |
+
|
| 72 |
+
"embed": ("embed", 0),
|
| 73 |
+
"wq_b": ("wq_b", 0),
|
| 74 |
+
"wo_a": ("wo_a", 0),
|
| 75 |
+
"wo_b": ("wo_b", 1),
|
| 76 |
+
"head": ("head", 0),
|
| 77 |
+
"attn_sink": ("attn_sink", 0),
|
| 78 |
+
"weights_proj": ("weights_proj", 0),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
|
| 83 |
+
"""
|
| 84 |
+
Converts and saves model checkpoint files into a specified format.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
|
| 88 |
+
save_path (str): Path to the directory where the converted checkpoint files will be saved.
|
| 89 |
+
n_experts (int): Total number of experts in the model.
|
| 90 |
+
mp (int): Model parallelism factor.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
None
|
| 94 |
+
"""
|
| 95 |
+
torch.set_num_threads(8)
|
| 96 |
+
n_local_experts = n_experts // mp
|
| 97 |
+
state_dicts = [{} for _ in range(mp)]
|
| 98 |
+
|
| 99 |
+
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
|
| 100 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 101 |
+
for name in f.keys():
|
| 102 |
+
param: torch.Tensor = f.get_tensor(name)
|
| 103 |
+
if name.startswith("model."):
|
| 104 |
+
name = name[len("model."):]
|
| 105 |
+
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
|
| 106 |
+
continue
|
| 107 |
+
name = name.replace("self_attn", "attn")
|
| 108 |
+
name = name.replace("mlp", "ffn")
|
| 109 |
+
name = name.replace("weight_scale_inv", "scale")
|
| 110 |
+
name = name.replace("e_score_correction_bias", "bias")
|
| 111 |
+
if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
|
| 112 |
+
key = name.split(".")[-1]
|
| 113 |
+
else:
|
| 114 |
+
key = name.split(".")[-2]
|
| 115 |
+
if key in mapping:
|
| 116 |
+
new_key, dim = mapping[key]
|
| 117 |
+
else:
|
| 118 |
+
new_key, dim = key, None
|
| 119 |
+
name = name.replace(key, new_key)
|
| 120 |
+
for i in range(mp):
|
| 121 |
+
new_param = param
|
| 122 |
+
if "experts" in name and "shared_experts" not in name:
|
| 123 |
+
idx = int(name.split(".")[-3])
|
| 124 |
+
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
|
| 125 |
+
continue
|
| 126 |
+
elif dim is not None:
|
| 127 |
+
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
|
| 128 |
+
shard_size = param.size(dim) // mp
|
| 129 |
+
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
|
| 130 |
+
state_dicts[i][name] = new_param
|
| 131 |
+
|
| 132 |
+
os.makedirs(save_path, exist_ok=True)
|
| 133 |
+
|
| 134 |
+
for i in trange(mp):
|
| 135 |
+
names = list(state_dicts[i].keys())
|
| 136 |
+
for name in names:
|
| 137 |
+
if name.endswith("wo_a.weight"):
|
| 138 |
+
weight = state_dicts[i][name]
|
| 139 |
+
scale = state_dicts[i].pop(name.replace("weight", "scale"))
|
| 140 |
+
weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
|
| 141 |
+
state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
|
| 142 |
+
elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
|
| 143 |
+
if expert_dtype == "fp8":
|
| 144 |
+
scale_name = name.replace("weight", "scale")
|
| 145 |
+
weight = state_dicts[i].pop(name)
|
| 146 |
+
scale = state_dicts[i].pop(scale_name)
|
| 147 |
+
state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
|
| 148 |
+
else:
|
| 149 |
+
state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
|
| 150 |
+
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
|
| 151 |
+
|
| 152 |
+
for file in ["tokenizer.json", "tokenizer_config.json"]:
|
| 153 |
+
old_file_path = os.path.join(hf_ckpt_path, file)
|
| 154 |
+
new_file_path = os.path.join(save_path, file)
|
| 155 |
+
if os.path.exists(old_file_path):
|
| 156 |
+
shutil.copyfile(old_file_path, new_file_path)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
parser = ArgumentParser()
|
| 161 |
+
parser.add_argument("--hf-ckpt-path", type=str, required=True)
|
| 162 |
+
parser.add_argument("--save-path", type=str, required=True)
|
| 163 |
+
parser.add_argument("--n-experts", type=int, required=True)
|
| 164 |
+
parser.add_argument("--model-parallel", type=int, required=True)
|
| 165 |
+
parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
|
| 166 |
+
args = parser.parse_args()
|
| 167 |
+
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
|
| 168 |
+
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
|
inference/convert_w4a16.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4
|
| 2 |
+
into the MP-sharded local format consumed by `model.py`/`generate.py`.
|
| 3 |
+
|
| 4 |
+
Packing convention (auto-round → auto_gptq):
|
| 5 |
+
- qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0
|
| 6 |
+
- qzeros : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1
|
| 7 |
+
- scales : fp16 [in_features // group_size, out_features]
|
| 8 |
+
|
| 9 |
+
Sharding rules per linear:
|
| 10 |
+
- ColumnParallel (shard output dim, original `dim=0` in `mapping`):
|
| 11 |
+
qweight along dim 1; qzeros along dim 1 (must be divisible by 8 first, then by world_size);
|
| 12 |
+
scales along dim 1.
|
| 13 |
+
- RowParallel (shard input dim, original `dim=1` in `mapping`):
|
| 14 |
+
qweight along dim 0 (must be divisible by 8 first, then by world_size);
|
| 15 |
+
qzeros along dim 0 (must be divisible by group_size first, then by world_size);
|
| 16 |
+
scales along dim 0.
|
| 17 |
+
|
| 18 |
+
Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias,
|
| 19 |
+
gate.tid2eid, etc.) follow the same rules as the original `convert.py`.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import shutil
|
| 24 |
+
from argparse import ArgumentParser
|
| 25 |
+
from glob import glob
|
| 26 |
+
from tqdm import tqdm, trange
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from safetensors.torch import safe_open, save_file
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
GROUP_SIZE = 128
|
| 33 |
+
|
| 34 |
+
# Same name remapping as the original convert.py
|
| 35 |
+
mapping = {
|
| 36 |
+
"embed_tokens": ("embed", 0),
|
| 37 |
+
"input_layernorm": ("attn_norm", None),
|
| 38 |
+
"post_attention_layernorm": ("ffn_norm", None),
|
| 39 |
+
"q_proj": ("wq", 0),
|
| 40 |
+
"q_a_proj": ("wq_a", None),
|
| 41 |
+
"q_a_layernorm": ("q_norm", None),
|
| 42 |
+
"q_b_proj": ("wq_b", 0),
|
| 43 |
+
"kv_a_proj_with_mqa": ("wkv_a", None),
|
| 44 |
+
"kv_a_layernorm": ("kv_norm", None),
|
| 45 |
+
"kv_b_proj": ("wkv_b", 0),
|
| 46 |
+
"o_proj": ("wo", 1),
|
| 47 |
+
"gate_proj": ("w1", 0),
|
| 48 |
+
"down_proj": ("w2", 1),
|
| 49 |
+
"up_proj": ("w3", 0),
|
| 50 |
+
"lm_head": ("head", 0),
|
| 51 |
+
|
| 52 |
+
# Already-translated names (used by the inference checkpoints we already have)
|
| 53 |
+
"embed": ("embed", 0),
|
| 54 |
+
"wq_a": ("wq_a", None),
|
| 55 |
+
"wq_b": ("wq_b", 0),
|
| 56 |
+
"wkv": ("wkv", None),
|
| 57 |
+
"wo_a": ("wo_a", 0),
|
| 58 |
+
"wo_b": ("wo_b", 1),
|
| 59 |
+
"w1": ("w1", 0),
|
| 60 |
+
"w2": ("w2", 1),
|
| 61 |
+
"w3": ("w3", 0),
|
| 62 |
+
"head": ("head", 0),
|
| 63 |
+
"weights_proj": ("weights_proj", 0),
|
| 64 |
+
# special non-weight keys
|
| 65 |
+
"attn_sink": ("attn_sink", 0),
|
| 66 |
+
"ape": ("ape", None),
|
| 67 |
+
# NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a
|
| 68 |
+
# plain nn.Parameter that is replicated on every rank.
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Suffixes that mark the three pieces of a packed W4A16 linear.
|
| 73 |
+
QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
|
| 77 |
+
shard_dim: int, mp: int):
|
| 78 |
+
"""Yield (qweight_i, qzeros_i, scales_i) for i in range(mp).
|
| 79 |
+
|
| 80 |
+
shard_dim is the *logical* dim of the dequantised weight: 0 == output (column parallel),
|
| 81 |
+
1 == input (row parallel)."""
|
| 82 |
+
out = qweight.size(1)
|
| 83 |
+
in_packed = qweight.size(0) # in_features // 8
|
| 84 |
+
n_groups = scales.size(0) # in_features // group_size
|
| 85 |
+
|
| 86 |
+
if shard_dim == 0: # ColumnParallel: shard along OUTPUT
|
| 87 |
+
assert out % mp == 0, f"out={out} not divisible by mp={mp}"
|
| 88 |
+
# qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0
|
| 89 |
+
assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)"
|
| 90 |
+
sh_out = out // mp
|
| 91 |
+
sh_qz_cols = qzeros.size(1) // mp # == out / 8 / mp
|
| 92 |
+
for i in range(mp):
|
| 93 |
+
yield (
|
| 94 |
+
qweight.narrow(1, i * sh_out, sh_out).contiguous(),
|
| 95 |
+
qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(),
|
| 96 |
+
scales.narrow(1, i * sh_out, sh_out).contiguous(),
|
| 97 |
+
)
|
| 98 |
+
elif shard_dim == 1: # RowParallel: shard along INPUT
|
| 99 |
+
# qweight packs 8 inputs per int32 in dim 0, scales/qzeros are per-group on dim 0
|
| 100 |
+
assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}"
|
| 101 |
+
assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}"
|
| 102 |
+
sh_in_packed = in_packed // mp
|
| 103 |
+
sh_groups = n_groups // mp
|
| 104 |
+
for i in range(mp):
|
| 105 |
+
yield (
|
| 106 |
+
qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(),
|
| 107 |
+
qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(),
|
| 108 |
+
scales.narrow(0, i * sh_groups, sh_groups).contiguous(),
|
| 109 |
+
)
|
| 110 |
+
else:
|
| 111 |
+
# Replicate
|
| 112 |
+
for _ in range(mp):
|
| 113 |
+
yield qweight, qzeros, scales
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_layer_key(name: str):
|
| 117 |
+
"""Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping."""
|
| 118 |
+
parts = name.split(".")
|
| 119 |
+
if name.endswith(QUANT_SUFFIXES):
|
| 120 |
+
return parts[-2] # ...x.qweight -> x
|
| 121 |
+
if name.endswith(".bias") and "gate" in name:
|
| 122 |
+
return "gate" # ffn.gate.bias
|
| 123 |
+
if name.endswith(".tid2eid"):
|
| 124 |
+
return "gate"
|
| 125 |
+
if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale",
|
| 126 |
+
"hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale",
|
| 127 |
+
"hc_head_fn", "hc_head_base", "hc_head_scale",
|
| 128 |
+
"attn_sink", "ape")):
|
| 129 |
+
return parts[-1]
|
| 130 |
+
return parts[-2]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def main(hf_ckpt_path, save_path, n_experts, mp):
|
| 134 |
+
torch.set_num_threads(8)
|
| 135 |
+
n_local_experts = n_experts // mp
|
| 136 |
+
state_dicts = [{} for _ in range(mp)]
|
| 137 |
+
|
| 138 |
+
# Group all fragments belonging to the same logical linear so we can shard
|
| 139 |
+
# qweight/qzeros/scales together.
|
| 140 |
+
pending: dict[str, dict[str, torch.Tensor]] = {}
|
| 141 |
+
|
| 142 |
+
def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim):
|
| 143 |
+
"""Distribute a quantised linear (3 tensors) across `mp` shards."""
|
| 144 |
+
qweight = parts["qweight"]
|
| 145 |
+
qzeros = parts["qzeros"]
|
| 146 |
+
scales = parts["scales"].to(torch.bfloat16) # store bf16 instead of fp16
|
| 147 |
+
# Expert-local pruning: only the rank that owns this expert keeps the tensors.
|
| 148 |
+
if "experts" in base_name and "shared_experts" not in base_name:
|
| 149 |
+
idx = int(base_name.split(".experts.")[1].split(".")[0])
|
| 150 |
+
owner = idx // n_local_experts
|
| 151 |
+
state_dicts[owner][base_name + ".qweight"] = qweight
|
| 152 |
+
state_dicts[owner][base_name + ".qzeros"] = qzeros
|
| 153 |
+
state_dicts[owner][base_name + ".scales"] = scales
|
| 154 |
+
return
|
| 155 |
+
if shard_dim is None:
|
| 156 |
+
# Replicate across all ranks
|
| 157 |
+
for i in range(mp):
|
| 158 |
+
state_dicts[i][base_name + ".qweight"] = qweight
|
| 159 |
+
state_dicts[i][base_name + ".qzeros"] = qzeros
|
| 160 |
+
state_dicts[i][base_name + ".scales"] = scales
|
| 161 |
+
else:
|
| 162 |
+
for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)):
|
| 163 |
+
state_dicts[i][base_name + ".qweight"] = qw
|
| 164 |
+
state_dicts[i][base_name + ".qzeros"] = qz
|
| 165 |
+
state_dicts[i][base_name + ".scales"] = sc
|
| 166 |
+
|
| 167 |
+
files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
|
| 168 |
+
for file_path in tqdm(files, desc="files"):
|
| 169 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 170 |
+
for orig_name in f.keys():
|
| 171 |
+
# ----- name remapping (mirrors original convert.py) -----
|
| 172 |
+
name = orig_name
|
| 173 |
+
if name.startswith("model."):
|
| 174 |
+
name = name[len("model."):]
|
| 175 |
+
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
|
| 176 |
+
continue
|
| 177 |
+
name = name.replace("self_attn", "attn")
|
| 178 |
+
name = name.replace("mlp", "ffn")
|
| 179 |
+
name = name.replace("weight_scale_inv", "scale")
|
| 180 |
+
name = name.replace("e_score_correction_bias", "bias")
|
| 181 |
+
|
| 182 |
+
key = get_layer_key(name)
|
| 183 |
+
if key in mapping:
|
| 184 |
+
new_key, dim = mapping[key]
|
| 185 |
+
name = name.replace(key, new_key)
|
| 186 |
+
else:
|
| 187 |
+
dim = None
|
| 188 |
+
|
| 189 |
+
tensor = f.get_tensor(orig_name)
|
| 190 |
+
|
| 191 |
+
# ----- handle the three-piece quantised linear -----
|
| 192 |
+
# `shared_experts` are plain (non-parallel) Linears in the model;
|
| 193 |
+
# never shard them even though `w1/w2/w3` are in the mapping.
|
| 194 |
+
if "shared_experts" in name:
|
| 195 |
+
dim = None
|
| 196 |
+
|
| 197 |
+
if orig_name.endswith(QUANT_SUFFIXES):
|
| 198 |
+
base = name.rsplit(".", 1)[0]
|
| 199 |
+
suf = name.rsplit(".", 1)[1] # qweight|qzeros|scales
|
| 200 |
+
pending.setdefault(base, {"_dim": dim})[suf] = tensor
|
| 201 |
+
pending[base]["_dim"] = dim
|
| 202 |
+
parts = pending[base]
|
| 203 |
+
if all(s in parts for s in ("qweight", "qzeros", "scales")):
|
| 204 |
+
emit_linear(base, parts, parts["_dim"])
|
| 205 |
+
del pending[base]
|
| 206 |
+
continue
|
| 207 |
+
|
| 208 |
+
# ----- non-quantised tensor -----
|
| 209 |
+
if "experts" in name and "shared_experts" not in name:
|
| 210 |
+
idx = int(name.split(".experts.")[1].split(".")[0])
|
| 211 |
+
owner = idx // n_local_experts
|
| 212 |
+
state_dicts[owner][name] = tensor
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
if dim is None:
|
| 216 |
+
for i in range(mp):
|
| 217 |
+
state_dicts[i][name] = tensor
|
| 218 |
+
else:
|
| 219 |
+
assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}"
|
| 220 |
+
sh = tensor.size(dim) // mp
|
| 221 |
+
for i in range(mp):
|
| 222 |
+
state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous()
|
| 223 |
+
|
| 224 |
+
if pending:
|
| 225 |
+
raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}")
|
| 226 |
+
|
| 227 |
+
os.makedirs(save_path, exist_ok=True)
|
| 228 |
+
for i in trange(mp, desc="write shards"):
|
| 229 |
+
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
|
| 230 |
+
|
| 231 |
+
for fn in ["tokenizer.json", "tokenizer_config.json"]:
|
| 232 |
+
src = os.path.join(hf_ckpt_path, fn)
|
| 233 |
+
dst = os.path.join(save_path, fn)
|
| 234 |
+
if os.path.exists(src):
|
| 235 |
+
shutil.copyfile(src, dst)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
|
| 239 |
+
p = ArgumentParser()
|
| 240 |
+
p.add_argument("--hf-ckpt-path", required=True)
|
| 241 |
+
p.add_argument("--save-path", required=True)
|
| 242 |
+
p.add_argument("--n-experts", type=int, required=True)
|
| 243 |
+
p.add_argument("--model-parallel", type=int, required=True)
|
| 244 |
+
a = p.parse_args()
|
| 245 |
+
assert a.n_experts % a.model_parallel == 0
|
| 246 |
+
main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)
|
inference/generate.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Prevent gptqmodel from setting CUDA_DEVICE_ORDER=PCI_BUS_ID (breaks multi-GPU on some systems)
|
| 3 |
+
os.environ.setdefault("CUDA_DEVICE_ORDER", "FASTEST_FIRST")
|
| 4 |
+
import json
|
| 5 |
+
import sys
|
| 6 |
+
from argparse import ArgumentParser
|
| 7 |
+
from typing import List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from safetensors.torch import load_model
|
| 13 |
+
|
| 14 |
+
from model import Transformer, ModelArgs
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
encoding_dir = os.path.join(current_dir, '../encoding')
|
| 17 |
+
sys.path.insert(0, os.path.abspath(encoding_dir))
|
| 18 |
+
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def sample(logits, temperature: float = 1.0):
|
| 22 |
+
"""Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
|
| 23 |
+
since it avoids the GPU-to-CPU sync in torch.multinomial."""
|
| 24 |
+
logits = logits / max(temperature, 1e-5)
|
| 25 |
+
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
| 26 |
+
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@torch.inference_mode()
|
| 30 |
+
def generate(
|
| 31 |
+
model: Transformer,
|
| 32 |
+
prompt_tokens: List[List[int]],
|
| 33 |
+
max_new_tokens: int,
|
| 34 |
+
eos_id: int,
|
| 35 |
+
temperature: float = 1.0
|
| 36 |
+
) -> List[List[int]]:
|
| 37 |
+
"""Batch generation with left-padded prompts.
|
| 38 |
+
|
| 39 |
+
The first forward pass processes [min_prompt_len:] tokens (prefill phase).
|
| 40 |
+
Subsequent passes generate one token at a time (decode phase). For positions
|
| 41 |
+
still within a prompt, the ground-truth token overrides the model's prediction.
|
| 42 |
+
"""
|
| 43 |
+
prompt_lens = [len(t) for t in prompt_tokens]
|
| 44 |
+
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
|
| 45 |
+
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
|
| 46 |
+
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
|
| 47 |
+
for i, t in enumerate(prompt_tokens):
|
| 48 |
+
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
|
| 49 |
+
prev_pos = 0
|
| 50 |
+
finished = torch.tensor([False] * len(prompt_tokens))
|
| 51 |
+
prompt_mask = tokens != -1
|
| 52 |
+
for cur_pos in range(min(prompt_lens), total_len):
|
| 53 |
+
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 54 |
+
if temperature > 0:
|
| 55 |
+
next_token = sample(logits, temperature)
|
| 56 |
+
else:
|
| 57 |
+
next_token = logits.argmax(dim=-1)
|
| 58 |
+
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
| 59 |
+
tokens[:, cur_pos] = next_token
|
| 60 |
+
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
|
| 61 |
+
prev_pos = cur_pos
|
| 62 |
+
if finished.all():
|
| 63 |
+
break
|
| 64 |
+
completion_tokens = []
|
| 65 |
+
for i, toks in enumerate(tokens.tolist()):
|
| 66 |
+
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
|
| 67 |
+
if eos_id in toks:
|
| 68 |
+
toks = toks[:toks.index(eos_id)]
|
| 69 |
+
toks.append(eos_id)
|
| 70 |
+
completion_tokens.append(toks)
|
| 71 |
+
return completion_tokens
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main(
|
| 75 |
+
ckpt_path: str,
|
| 76 |
+
config: str,
|
| 77 |
+
input_file: str = "",
|
| 78 |
+
interactive: bool = True,
|
| 79 |
+
max_new_tokens: int = 100,
|
| 80 |
+
temperature: float = 1.0,
|
| 81 |
+
) -> None:
|
| 82 |
+
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
| 83 |
+
rank = int(os.getenv("RANK", "0"))
|
| 84 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 85 |
+
if world_size > 1:
|
| 86 |
+
dist.init_process_group("nccl")
|
| 87 |
+
global print
|
| 88 |
+
if rank != 0:
|
| 89 |
+
print = lambda *_, **__: None
|
| 90 |
+
torch.cuda.set_device(local_rank)
|
| 91 |
+
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
|
| 92 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 93 |
+
torch.set_num_threads(8)
|
| 94 |
+
torch.manual_seed(33377335)
|
| 95 |
+
with open(config) as f:
|
| 96 |
+
args = ModelArgs(**json.load(f))
|
| 97 |
+
if interactive:
|
| 98 |
+
args.max_batch_size = 1
|
| 99 |
+
print(args)
|
| 100 |
+
with torch.device("cuda"):
|
| 101 |
+
model = Transformer(args)
|
| 102 |
+
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
|
| 103 |
+
print("load model")
|
| 104 |
+
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
|
| 105 |
+
if args.dtype == "w4a16":
|
| 106 |
+
model.init_woq_layers()
|
| 107 |
+
torch.set_default_device("cuda")
|
| 108 |
+
print("I'm DeepSeek 👋")
|
| 109 |
+
|
| 110 |
+
if interactive:
|
| 111 |
+
messages = []
|
| 112 |
+
while True:
|
| 113 |
+
if world_size == 1:
|
| 114 |
+
prompt = input(">>> ")
|
| 115 |
+
elif rank == 0:
|
| 116 |
+
prompt = input(">>> ")
|
| 117 |
+
objects = [prompt]
|
| 118 |
+
dist.broadcast_object_list(objects, 0)
|
| 119 |
+
else:
|
| 120 |
+
objects = [None]
|
| 121 |
+
dist.broadcast_object_list(objects, 0)
|
| 122 |
+
prompt = objects[0]
|
| 123 |
+
if prompt == "/exit":
|
| 124 |
+
break
|
| 125 |
+
elif prompt == "/clear":
|
| 126 |
+
messages.clear()
|
| 127 |
+
continue
|
| 128 |
+
messages.append({"role": "user", "content": prompt})
|
| 129 |
+
prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
|
| 130 |
+
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 131 |
+
completion = tokenizer.decode(completion_tokens[0])
|
| 132 |
+
print(completion)
|
| 133 |
+
messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
|
| 134 |
+
else:
|
| 135 |
+
with open(input_file) as f:
|
| 136 |
+
prompts = f.read().split("\n\n")
|
| 137 |
+
prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
|
| 138 |
+
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 139 |
+
completions = tokenizer.batch_decode(completion_tokens)
|
| 140 |
+
for prompt, completion in zip(prompts, completions):
|
| 141 |
+
print("Prompt:", prompt)
|
| 142 |
+
print("Completion:", completion)
|
| 143 |
+
print()
|
| 144 |
+
|
| 145 |
+
if world_size > 1:
|
| 146 |
+
dist.destroy_process_group()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
parser = ArgumentParser()
|
| 151 |
+
parser.add_argument("--ckpt-path", type=str, required=True)
|
| 152 |
+
parser.add_argument("--config", type=str, required=True)
|
| 153 |
+
parser.add_argument("--input-file", type=str, default="")
|
| 154 |
+
parser.add_argument("--interactive", action="store_true")
|
| 155 |
+
parser.add_argument("--max-new-tokens", type=int, default=300)
|
| 156 |
+
parser.add_argument("--temperature", type=float, default=0.6)
|
| 157 |
+
args = parser.parse_args()
|
| 158 |
+
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
|
| 159 |
+
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
|
inference/kernel.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tilelang
|
| 3 |
+
import tilelang.language as T
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
tilelang.set_log_level("WARNING")
|
| 8 |
+
|
| 9 |
+
pass_configs = {
|
| 10 |
+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
| 11 |
+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
FP8 = "float8_e4m3"
|
| 15 |
+
FP4 = "float4_e2m1fn"
|
| 16 |
+
FE8M0 = "float8_e8m0fnu"
|
| 17 |
+
BF16 = "bfloat16"
|
| 18 |
+
FP32 = "float32"
|
| 19 |
+
INT32 = "int32"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def fast_log2_ceil(x):
|
| 23 |
+
"""Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
|
| 24 |
+
bits_x = T.reinterpret("uint32", x)
|
| 25 |
+
exp_x = (bits_x >> 23) & 0xFF
|
| 26 |
+
man_bits = bits_x & ((1 << 23) - 1)
|
| 27 |
+
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def fast_pow2(x):
|
| 31 |
+
"""Compute 2^x for integer x via IEEE 754 bit manipulation."""
|
| 32 |
+
bits_x = (x + 127) << 23
|
| 33 |
+
return T.reinterpret("float32", bits_x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fast_round_scale(amax, fp8_max_inv):
|
| 37 |
+
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 41 |
+
def act_quant_kernel(
|
| 42 |
+
N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
|
| 43 |
+
round_scale=False, inplace=False
|
| 44 |
+
):
|
| 45 |
+
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
|
| 46 |
+
M = T.symbolic("M")
|
| 47 |
+
fp8_min = -448.0
|
| 48 |
+
fp8_max = 448.0
|
| 49 |
+
fp8_max_inv = 1 / fp8_max
|
| 50 |
+
num_stages = 0 if round_scale or inplace else 2
|
| 51 |
+
blk_m = 32
|
| 52 |
+
group_size = block_size
|
| 53 |
+
# Internal computation in FP32; scale_dtype controls output storage format.
|
| 54 |
+
compute_dtype = FP32
|
| 55 |
+
out_dtype = in_dtype if inplace else out_dtype
|
| 56 |
+
|
| 57 |
+
@T.prim_func
|
| 58 |
+
def act_quant_kernel_(
|
| 59 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 60 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 61 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 62 |
+
):
|
| 63 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 64 |
+
pid_m,
|
| 65 |
+
pid_n,
|
| 66 |
+
):
|
| 67 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 68 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 69 |
+
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 70 |
+
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 71 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 72 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 73 |
+
|
| 74 |
+
for _ in T.Pipelined(1, num_stages=num_stages):
|
| 75 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 76 |
+
T.copy(x_shared, x_local)
|
| 77 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 78 |
+
for i in T.Parallel(blk_m):
|
| 79 |
+
amax_local[i] = T.max(amax_local[i], 1e-4)
|
| 80 |
+
if round_scale:
|
| 81 |
+
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
| 82 |
+
else:
|
| 83 |
+
s_local[i] = amax_local[i] * fp8_max_inv
|
| 84 |
+
if inplace:
|
| 85 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 86 |
+
y_local[i, j] = T.Cast(
|
| 87 |
+
out_dtype,
|
| 88 |
+
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
|
| 89 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 90 |
+
))) * s_local[i],
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 94 |
+
y_local[i, j] = T.clamp(
|
| 95 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 96 |
+
)
|
| 97 |
+
for i in T.Parallel(blk_m):
|
| 98 |
+
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
| 99 |
+
T.copy(y_local, y_shared)
|
| 100 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 101 |
+
|
| 102 |
+
return act_quant_kernel_
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def act_quant(
|
| 106 |
+
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
|
| 107 |
+
scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
|
| 110 |
+
When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
|
| 111 |
+
N = x.size(-1)
|
| 112 |
+
assert N % block_size == 0
|
| 113 |
+
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
| 114 |
+
z = x.contiguous()
|
| 115 |
+
y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
|
| 116 |
+
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
|
| 117 |
+
kernel = act_quant_kernel(
|
| 118 |
+
N, block_size, scale_dtype=tl_dtype,
|
| 119 |
+
round_scale=scale_fmt is not None, inplace=inplace,
|
| 120 |
+
)
|
| 121 |
+
kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
| 122 |
+
if inplace:
|
| 123 |
+
x.copy_(y)
|
| 124 |
+
return x
|
| 125 |
+
return y, s
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 129 |
+
def fp4_quant_kernel(
|
| 130 |
+
N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
|
| 131 |
+
):
|
| 132 |
+
"""Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
|
| 133 |
+
M = T.symbolic("M")
|
| 134 |
+
fp4_max = 6.0
|
| 135 |
+
fp4_max_inv = 1.0 / fp4_max
|
| 136 |
+
blk_m = 32
|
| 137 |
+
group_size = block_size
|
| 138 |
+
compute_dtype = FP32
|
| 139 |
+
out_dtype = in_dtype if inplace else FP4
|
| 140 |
+
|
| 141 |
+
@T.prim_func
|
| 142 |
+
def fp4_quant_kernel_(
|
| 143 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 144 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 145 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 146 |
+
):
|
| 147 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 148 |
+
pid_m,
|
| 149 |
+
pid_n,
|
| 150 |
+
):
|
| 151 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 152 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 153 |
+
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 154 |
+
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 155 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 156 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 157 |
+
|
| 158 |
+
for _ in T.Pipelined(1, num_stages=2):
|
| 159 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 160 |
+
T.copy(x_shared, x_local)
|
| 161 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 162 |
+
for i in T.Parallel(blk_m):
|
| 163 |
+
amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
|
| 164 |
+
s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
|
| 165 |
+
if inplace:
|
| 166 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 167 |
+
y_local[i, j] = T.Cast(
|
| 168 |
+
out_dtype,
|
| 169 |
+
T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
|
| 170 |
+
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
| 171 |
+
))) * s_local[i],
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 175 |
+
y_local[i, j] = T.clamp(
|
| 176 |
+
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
| 177 |
+
)
|
| 178 |
+
for i in T.Parallel(blk_m):
|
| 179 |
+
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
| 180 |
+
T.copy(y_local, y_shared)
|
| 181 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 182 |
+
|
| 183 |
+
return fp4_quant_kernel_
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def fp4_act_quant(
|
| 187 |
+
x: torch.Tensor, block_size: int = 32, inplace: bool = False,
|
| 188 |
+
) -> torch.Tensor:
|
| 189 |
+
"""Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
|
| 190 |
+
N = x.size(-1)
|
| 191 |
+
assert N % block_size == 0
|
| 192 |
+
z = x.contiguous()
|
| 193 |
+
y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
|
| 194 |
+
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
|
| 195 |
+
kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
|
| 196 |
+
kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
|
| 197 |
+
if inplace:
|
| 198 |
+
x.copy_(y)
|
| 199 |
+
return x
|
| 200 |
+
return y, s
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 204 |
+
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
|
| 205 |
+
assert out_dtype in [BF16, FP32]
|
| 206 |
+
|
| 207 |
+
M = T.symbolic("M")
|
| 208 |
+
group_size = 128
|
| 209 |
+
block_M = 32
|
| 210 |
+
block_N = 128
|
| 211 |
+
block_K = 128
|
| 212 |
+
|
| 213 |
+
@T.prim_func
|
| 214 |
+
def fp8_gemm_kernel_(
|
| 215 |
+
A: T.Tensor[(M, K), FP8],
|
| 216 |
+
B: T.Tensor[(N, K), FP8],
|
| 217 |
+
C: T.Tensor[(M, N), out_dtype],
|
| 218 |
+
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
|
| 219 |
+
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
|
| 220 |
+
):
|
| 221 |
+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
| 222 |
+
bx,
|
| 223 |
+
by,
|
| 224 |
+
):
|
| 225 |
+
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
| 226 |
+
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
| 227 |
+
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
| 228 |
+
Scale_C_shared = T.alloc_shared((block_M), FP32)
|
| 229 |
+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 230 |
+
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 231 |
+
|
| 232 |
+
# Improve L2 Cache
|
| 233 |
+
T.use_swizzle(panel_size=10)
|
| 234 |
+
T.clear(C_local)
|
| 235 |
+
T.clear(C_local_accum)
|
| 236 |
+
|
| 237 |
+
K_iters = T.ceildiv(K, block_K)
|
| 238 |
+
for k in T.Pipelined(K_iters, num_stages=4):
|
| 239 |
+
T.copy(A[by * block_M, k * block_K], A_shared)
|
| 240 |
+
T.copy(B[bx * block_N, k * block_K], B_shared)
|
| 241 |
+
# Cast scales to FP32 for computation; scales_b has one value per block_N group
|
| 242 |
+
Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
|
| 243 |
+
for i in T.Parallel(block_M):
|
| 244 |
+
Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
|
| 245 |
+
|
| 246 |
+
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
| 247 |
+
# Separate accumulator for scale-corrected results (2x accumulation precision)
|
| 248 |
+
for i, j in T.Parallel(block_M, block_N):
|
| 249 |
+
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
| 250 |
+
T.clear(C_local)
|
| 251 |
+
T.copy(C_local_accum, C_shared)
|
| 252 |
+
T.copy(C_shared, C[by * block_M, bx * block_N])
|
| 253 |
+
|
| 254 |
+
return fp8_gemm_kernel_
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def fp8_gemm(
|
| 258 |
+
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
|
| 259 |
+
scale_dtype: torch.dtype = torch.float32,
|
| 260 |
+
) -> torch.Tensor:
|
| 261 |
+
"""C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
|
| 262 |
+
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
| 263 |
+
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
| 264 |
+
"Scaling factor tensors must be contiguous"
|
| 265 |
+
)
|
| 266 |
+
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
| 267 |
+
K = a.size(-1)
|
| 268 |
+
M = a.numel() // K
|
| 269 |
+
N = b.size(0)
|
| 270 |
+
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
| 271 |
+
kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
|
| 272 |
+
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
| 273 |
+
return c
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 277 |
+
def sparse_attn_kernel(h: int, d: int, scale=None):
|
| 278 |
+
"""Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
|
| 279 |
+
For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
|
| 280 |
+
with numerically stable running max/sum, and includes a learnable attn_sink bias."""
|
| 281 |
+
b = T.symbolic("b")
|
| 282 |
+
m = T.symbolic("m")
|
| 283 |
+
n = T.symbolic("n")
|
| 284 |
+
topk = T.symbolic("topk")
|
| 285 |
+
if scale is None:
|
| 286 |
+
scale = (1.0 / d) ** 0.5
|
| 287 |
+
|
| 288 |
+
num_stages = 2
|
| 289 |
+
threads = 256
|
| 290 |
+
block = 64
|
| 291 |
+
num_blocks = tilelang.cdiv(topk, block)
|
| 292 |
+
|
| 293 |
+
@T.prim_func
|
| 294 |
+
def sparse_attn_kernel_(
|
| 295 |
+
q: T.Tensor[(b, m, h, d), BF16],
|
| 296 |
+
kv: T.Tensor[(b, n, d), BF16],
|
| 297 |
+
o: T.Tensor[(b, m, h, d), BF16],
|
| 298 |
+
attn_sink: T.Tensor[(h,), FP32],
|
| 299 |
+
topk_idxs: T.Tensor[(b, m, topk), INT32],
|
| 300 |
+
):
|
| 301 |
+
with T.Kernel(m, b, threads=threads) as (bx, by):
|
| 302 |
+
q_shared = T.alloc_shared((h, d), BF16)
|
| 303 |
+
kv_shared = T.alloc_shared((block, d), BF16)
|
| 304 |
+
o_shared = T.alloc_shared((h, d), BF16)
|
| 305 |
+
acc_s_cast = T.alloc_shared((h, block), BF16)
|
| 306 |
+
|
| 307 |
+
idxs = T.alloc_fragment(block, INT32)
|
| 308 |
+
acc_s = T.alloc_fragment((h, block), FP32)
|
| 309 |
+
acc_o = T.alloc_fragment((h, d), FP32)
|
| 310 |
+
scores_max = T.alloc_fragment(h, FP32)
|
| 311 |
+
scores_max_prev = T.alloc_fragment(h, FP32)
|
| 312 |
+
scores_scale = T.alloc_fragment(h, FP32)
|
| 313 |
+
scores_sum = T.alloc_fragment(h, FP32)
|
| 314 |
+
sum_exp = T.alloc_fragment(h, FP32)
|
| 315 |
+
|
| 316 |
+
T.clear(acc_o)
|
| 317 |
+
T.clear(sum_exp)
|
| 318 |
+
T.fill(scores_max, -T.infinity(FP32))
|
| 319 |
+
T.copy(q[by, bx, :, :], q_shared)
|
| 320 |
+
|
| 321 |
+
for t in T.Pipelined(num_blocks, num_stages=num_stages):
|
| 322 |
+
for i in T.Parallel(block):
|
| 323 |
+
idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
|
| 324 |
+
for i, j in T.Parallel(block, d):
|
| 325 |
+
kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
|
| 326 |
+
for i, j in T.Parallel(h, block):
|
| 327 |
+
acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
|
| 328 |
+
T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
|
| 329 |
+
for i, j in T.Parallel(h, block):
|
| 330 |
+
acc_s[i, j] *= scale
|
| 331 |
+
T.copy(scores_max, scores_max_prev)
|
| 332 |
+
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
|
| 333 |
+
for i in T.Parallel(h):
|
| 334 |
+
scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
|
| 335 |
+
for i, j in T.Parallel(h, block):
|
| 336 |
+
acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
|
| 337 |
+
T.reduce_sum(acc_s, scores_sum, dim=1)
|
| 338 |
+
for i in T.Parallel(h):
|
| 339 |
+
sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
|
| 340 |
+
T.copy(acc_s, acc_s_cast)
|
| 341 |
+
for i, j in T.Parallel(h, d):
|
| 342 |
+
acc_o[i, j] *= scores_scale[i]
|
| 343 |
+
T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
|
| 344 |
+
|
| 345 |
+
for i in T.Parallel(h):
|
| 346 |
+
sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
|
| 347 |
+
for i, j in T.Parallel(h, d):
|
| 348 |
+
acc_o[i, j] /= sum_exp[i]
|
| 349 |
+
T.copy(acc_o, o_shared)
|
| 350 |
+
T.copy(o_shared, o[by, bx, :, :])
|
| 351 |
+
|
| 352 |
+
return sparse_attn_kernel_
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def sparse_attn(
|
| 356 |
+
q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
|
| 357 |
+
) -> torch.Tensor:
|
| 358 |
+
b, s, h, d = q.size()
|
| 359 |
+
# Pad heads to 16 for kernel efficiency (stripped after)
|
| 360 |
+
if h < 16:
|
| 361 |
+
q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
|
| 362 |
+
attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
|
| 363 |
+
o = torch.empty_like(q)
|
| 364 |
+
kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
|
| 365 |
+
kernel(q, kv, o, attn_sink, topk_idxs)
|
| 366 |
+
if h < 16:
|
| 367 |
+
o = o.narrow(2, 0, h).contiguous()
|
| 368 |
+
return o
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 372 |
+
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
|
| 373 |
+
n = T.symbolic("n")
|
| 374 |
+
mix_hc = (2 + hc) * hc
|
| 375 |
+
threads = 64
|
| 376 |
+
|
| 377 |
+
@T.prim_func
|
| 378 |
+
def hc_split_sinkhorn_kernel_(
|
| 379 |
+
mixes: T.Tensor[(n, mix_hc), FP32],
|
| 380 |
+
hc_scale: T.Tensor[(3,), FP32],
|
| 381 |
+
hc_base: T.Tensor[(mix_hc,), FP32],
|
| 382 |
+
pre: T.Tensor[(n, hc), FP32],
|
| 383 |
+
post: T.Tensor[(n, hc), FP32],
|
| 384 |
+
comb: T.Tensor[(n, hc, hc), FP32],
|
| 385 |
+
):
|
| 386 |
+
with T.Kernel(n, threads=threads) as i:
|
| 387 |
+
mixes_shared = T.alloc_shared(mix_hc, FP32)
|
| 388 |
+
comb_frag = T.alloc_fragment((hc, hc), FP32)
|
| 389 |
+
T.copy(mixes[i, :], mixes_shared)
|
| 390 |
+
|
| 391 |
+
for j in T.Parallel(hc):
|
| 392 |
+
pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
|
| 393 |
+
for j in T.Parallel(hc):
|
| 394 |
+
post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
|
| 395 |
+
for j, k in T.Parallel(hc, hc):
|
| 396 |
+
comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
|
| 397 |
+
|
| 398 |
+
row_sum = T.alloc_fragment(hc, FP32)
|
| 399 |
+
col_sum = T.alloc_fragment(hc, FP32)
|
| 400 |
+
|
| 401 |
+
# comb = comb.softmax(-1) + eps
|
| 402 |
+
row_max = T.alloc_fragment(hc, FP32)
|
| 403 |
+
T.reduce_max(comb_frag, row_max, dim=1)
|
| 404 |
+
for j, k in T.Parallel(hc, hc):
|
| 405 |
+
comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
|
| 406 |
+
T.reduce_sum(comb_frag, row_sum, dim=1)
|
| 407 |
+
for j, k in T.Parallel(hc, hc):
|
| 408 |
+
comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
|
| 409 |
+
|
| 410 |
+
# comb = comb / (comb.sum(-2) + eps)
|
| 411 |
+
T.reduce_sum(comb_frag, col_sum, dim=0)
|
| 412 |
+
for j, k in T.Parallel(hc, hc):
|
| 413 |
+
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
|
| 414 |
+
|
| 415 |
+
for _ in T.serial(sinkhorn_iters - 1):
|
| 416 |
+
# comb = comb / (comb.sum(-1) + eps)
|
| 417 |
+
T.reduce_sum(comb_frag, row_sum, dim=1)
|
| 418 |
+
for j, k in T.Parallel(hc, hc):
|
| 419 |
+
comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
|
| 420 |
+
# comb = comb / (comb.sum(-2) + eps)
|
| 421 |
+
T.reduce_sum(comb_frag, col_sum, dim=0)
|
| 422 |
+
for j, k in T.Parallel(hc, hc):
|
| 423 |
+
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
|
| 424 |
+
|
| 425 |
+
T.copy(comb_frag, comb[i, :, :])
|
| 426 |
+
|
| 427 |
+
return hc_split_sinkhorn_kernel_
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
|
| 431 |
+
b, s, _ = mixes.size()
|
| 432 |
+
pre = mixes.new_empty(b, s, hc_mult)
|
| 433 |
+
post = mixes.new_empty(b, s, hc_mult)
|
| 434 |
+
comb = mixes.new_empty(b, s, hc_mult, hc_mult)
|
| 435 |
+
kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
|
| 436 |
+
kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
|
| 437 |
+
pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
|
| 438 |
+
return pre, post, comb
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 442 |
+
def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
|
| 443 |
+
"""FP8 act x FP4 weight GEMM kernel.
|
| 444 |
+
|
| 445 |
+
C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T
|
| 446 |
+
|
| 447 |
+
Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
|
| 448 |
+
Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale
|
| 449 |
+
|
| 450 |
+
B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
|
| 451 |
+
The FP4 values are packed along the K (last) dimension.
|
| 452 |
+
|
| 453 |
+
Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
|
| 454 |
+
cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
|
| 455 |
+
Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
|
| 456 |
+
"""
|
| 457 |
+
M = T.symbolic("M")
|
| 458 |
+
act_group_size = 128
|
| 459 |
+
weight_group_size = 32
|
| 460 |
+
block_M = 32
|
| 461 |
+
block_N = 128
|
| 462 |
+
block_K = 32 # matches weight_group_size for simple scale handling
|
| 463 |
+
n_sub = act_group_size // block_K # 4 sub-blocks per act scale group
|
| 464 |
+
|
| 465 |
+
@T.prim_func
|
| 466 |
+
def fp4_gemm_kernel_(
|
| 467 |
+
A: T.Tensor[(M, K), FP8],
|
| 468 |
+
B: T.Tensor[(N, K), FP4],
|
| 469 |
+
C: T.Tensor[(M, N), out_dtype],
|
| 470 |
+
scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
|
| 471 |
+
scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
|
| 472 |
+
):
|
| 473 |
+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
| 474 |
+
bx,
|
| 475 |
+
by,
|
| 476 |
+
):
|
| 477 |
+
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
| 478 |
+
B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
|
| 479 |
+
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
| 480 |
+
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
| 481 |
+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 482 |
+
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 483 |
+
scale_a_frag = T.alloc_fragment((block_M,), FP32)
|
| 484 |
+
scale_b_frag = T.alloc_fragment((block_N,), FP32)
|
| 485 |
+
|
| 486 |
+
T.use_swizzle(panel_size=10)
|
| 487 |
+
T.clear(C_local)
|
| 488 |
+
T.clear(C_local_accum)
|
| 489 |
+
|
| 490 |
+
K_iters = T.ceildiv(K, block_K)
|
| 491 |
+
for k in T.Pipelined(K_iters, num_stages=2):
|
| 492 |
+
T.copy(A[by * block_M, k * block_K], A_shared)
|
| 493 |
+
T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
|
| 494 |
+
# FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
|
| 495 |
+
for i, j in T.Parallel(block_N, block_K):
|
| 496 |
+
B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))
|
| 497 |
+
|
| 498 |
+
# Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
|
| 499 |
+
for i in T.Parallel(block_N):
|
| 500 |
+
scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])
|
| 501 |
+
|
| 502 |
+
# Act scale: per 128 on K, indexed by k // 4
|
| 503 |
+
for i in T.Parallel(block_M):
|
| 504 |
+
scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])
|
| 505 |
+
|
| 506 |
+
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
| 507 |
+
|
| 508 |
+
for i, j in T.Parallel(block_M, block_N):
|
| 509 |
+
C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
|
| 510 |
+
T.clear(C_local)
|
| 511 |
+
|
| 512 |
+
T.copy(C_local_accum, C_shared)
|
| 513 |
+
T.copy(C_shared, C[by * block_M, bx * block_N])
|
| 514 |
+
|
| 515 |
+
return fp4_gemm_kernel_
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def fp4_gemm(
|
| 519 |
+
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
|
| 520 |
+
scale_dtype: torch.dtype = torch.float32,
|
| 521 |
+
) -> torch.Tensor:
|
| 522 |
+
"""C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
|
| 523 |
+
A has per-128 act scale; B has per-32 E8M0 weight scale.
|
| 524 |
+
B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
|
| 525 |
+
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
| 526 |
+
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
| 527 |
+
"Scaling factor tensors must be contiguous"
|
| 528 |
+
)
|
| 529 |
+
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
| 530 |
+
K = a.size(-1)
|
| 531 |
+
M = a.numel() // K
|
| 532 |
+
N = b.size(0)
|
| 533 |
+
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
| 534 |
+
kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
|
| 535 |
+
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
| 536 |
+
return c
|
inference/model.py
ADDED
|
@@ -0,0 +1,992 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Optional, Literal
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
world_size = 1
|
| 16 |
+
rank = 0
|
| 17 |
+
block_size = 128
|
| 18 |
+
fp4_block_size = 32
|
| 19 |
+
w4a16_group_size = 128
|
| 20 |
+
default_dtype = torch.bfloat16
|
| 21 |
+
scale_fmt = None
|
| 22 |
+
scale_dtype = torch.float32
|
| 23 |
+
w4a16_mode = False # set in Transformer.__init__ when args.dtype == "w4a16"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def dequantize_w4a16(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
|
| 27 |
+
group_size: int = 128) -> torch.Tensor:
|
| 28 |
+
"""Auto-round / auto_gptq W4A16 packing -> BF16 weight [out, in].
|
| 29 |
+
|
| 30 |
+
qweight: int32 [in/8, out], LSB-first 4-bit packed along input dim
|
| 31 |
+
qzeros : int32 [in/g, out/8], LSB-first 4-bit packed along output dim
|
| 32 |
+
scales : bf16 [in/g, out]
|
| 33 |
+
"""
|
| 34 |
+
in_packed, out_features = qweight.shape
|
| 35 |
+
in_features = in_packed * 8
|
| 36 |
+
n_groups = scales.shape[0]
|
| 37 |
+
device = qweight.device
|
| 38 |
+
shifts = torch.arange(0, 32, 4, device=device, dtype=torch.int32)
|
| 39 |
+
w = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF # [in/8, 8, out]
|
| 40 |
+
w = w.reshape(in_features, out_features).to(torch.float32)
|
| 41 |
+
z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, 8)) & 0xF # [in/g, out/8, 8]
|
| 42 |
+
z = z.reshape(n_groups, out_features).to(torch.float32) + 1.0 # GPTQ stores zero - 1
|
| 43 |
+
s = scales.to(torch.float32)
|
| 44 |
+
w = w.view(n_groups, group_size, out_features)
|
| 45 |
+
deq = (w - z.unsqueeze(1)) * s.unsqueeze(1)
|
| 46 |
+
deq = deq.view(in_features, out_features)
|
| 47 |
+
return deq.t().contiguous().to(torch.bfloat16)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@contextmanager
|
| 51 |
+
def set_dtype(dtype):
|
| 52 |
+
"""Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
|
| 53 |
+
prev = torch.get_default_dtype()
|
| 54 |
+
torch.set_default_dtype(dtype)
|
| 55 |
+
try:
|
| 56 |
+
yield
|
| 57 |
+
finally:
|
| 58 |
+
torch.set_default_dtype(prev)
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class ModelArgs:
|
| 62 |
+
"""Model hyperparameters. Field names match the config JSON keys."""
|
| 63 |
+
max_batch_size: int = 4
|
| 64 |
+
max_seq_len: int = 4096
|
| 65 |
+
dtype: Literal["bf16", "fp8", "w4a16"] = "fp8"
|
| 66 |
+
scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
|
| 67 |
+
expert_dtype: Literal[None, "fp4"] = None
|
| 68 |
+
scale_dtype: Literal["fp32", "fp8"] = "fp8"
|
| 69 |
+
vocab_size: int = 129280
|
| 70 |
+
dim: int = 4096
|
| 71 |
+
moe_inter_dim: int = 4096
|
| 72 |
+
n_layers: int = 7
|
| 73 |
+
n_hash_layers: int = 0
|
| 74 |
+
n_mtp_layers: int = 1
|
| 75 |
+
n_heads: int = 64
|
| 76 |
+
# moe
|
| 77 |
+
n_routed_experts: int = 8
|
| 78 |
+
n_shared_experts: int = 1
|
| 79 |
+
n_activated_experts: int = 2
|
| 80 |
+
score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
|
| 81 |
+
route_scale: float = 1.
|
| 82 |
+
swiglu_limit: float = 0.
|
| 83 |
+
# mqa
|
| 84 |
+
q_lora_rank: int = 1024
|
| 85 |
+
head_dim: int = 512
|
| 86 |
+
rope_head_dim: int = 64
|
| 87 |
+
norm_eps: float = 1e-6
|
| 88 |
+
o_groups: int = 8
|
| 89 |
+
o_lora_rank: int = 1024
|
| 90 |
+
window_size: int = 128
|
| 91 |
+
compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0)
|
| 92 |
+
# yarn
|
| 93 |
+
compress_rope_theta: float = 40000.0
|
| 94 |
+
original_seq_len: int = 0
|
| 95 |
+
rope_theta: float = 10000.0
|
| 96 |
+
rope_factor: float = 40
|
| 97 |
+
beta_fast: int = 32
|
| 98 |
+
beta_slow: int = 1
|
| 99 |
+
# index
|
| 100 |
+
index_n_heads: int = 64
|
| 101 |
+
index_head_dim: int = 128
|
| 102 |
+
index_topk: int = 512
|
| 103 |
+
# hc
|
| 104 |
+
hc_mult: int = 4
|
| 105 |
+
hc_sinkhorn_iters: int = 20
|
| 106 |
+
hc_eps: float = 1e-6
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ParallelEmbedding(nn.Module):
|
| 110 |
+
"""Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
|
| 111 |
+
Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
|
| 112 |
+
def __init__(self, vocab_size: int, dim: int):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.vocab_size = vocab_size
|
| 115 |
+
self.dim = dim
|
| 116 |
+
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
|
| 117 |
+
self.part_vocab_size = (vocab_size // world_size)
|
| 118 |
+
self.vocab_start_idx = rank * self.part_vocab_size
|
| 119 |
+
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
|
| 120 |
+
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
|
| 121 |
+
|
| 122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
if world_size > 1:
|
| 124 |
+
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
| 125 |
+
x = x - self.vocab_start_idx
|
| 126 |
+
x[mask] = 0
|
| 127 |
+
y = F.embedding(x, self.weight)
|
| 128 |
+
if world_size > 1:
|
| 129 |
+
y[mask] = 0
|
| 130 |
+
dist.all_reduce(y)
|
| 131 |
+
return y
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 135 |
+
"""Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
|
| 136 |
+
For quantized weights, x is first quantized to FP8 via act_quant."""
|
| 137 |
+
assert bias is None
|
| 138 |
+
|
| 139 |
+
if weight.dtype == torch.float4_e2m1fn_x2:
|
| 140 |
+
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
|
| 141 |
+
return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
|
| 142 |
+
elif weight.dtype == torch.float8_e4m3fn:
|
| 143 |
+
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
|
| 144 |
+
return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
|
| 145 |
+
else:
|
| 146 |
+
return F.linear(x, weight)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Linear(nn.Module):
|
| 150 |
+
"""Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
|
| 151 |
+
|
| 152 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.in_features = in_features
|
| 155 |
+
self.out_features = out_features
|
| 156 |
+
# In a W4A16 build every Linear becomes W4A16 regardless of the dtype the
|
| 157 |
+
# original FP8/FP4 model wanted. The non-quant special cases (RMSNorm,
|
| 158 |
+
# embed, attn_sink, etc.) are NOT instances of `Linear`, so they are
|
| 159 |
+
# untouched.
|
| 160 |
+
if w4a16_mode:
|
| 161 |
+
dtype = "w4a16"
|
| 162 |
+
else:
|
| 163 |
+
dtype = dtype or default_dtype
|
| 164 |
+
self.is_w4a16 = (dtype == "w4a16")
|
| 165 |
+
if self.is_w4a16:
|
| 166 |
+
assert in_features % 8 == 0 and in_features % w4a16_group_size == 0
|
| 167 |
+
assert out_features % 8 == 0
|
| 168 |
+
self.group_size = w4a16_group_size
|
| 169 |
+
self.qweight = nn.Parameter(
|
| 170 |
+
torch.empty(in_features // 8, out_features, dtype=torch.int32),
|
| 171 |
+
requires_grad=False,
|
| 172 |
+
)
|
| 173 |
+
self.qzeros = nn.Parameter(
|
| 174 |
+
torch.empty(in_features // self.group_size, out_features // 8, dtype=torch.int32),
|
| 175 |
+
requires_grad=False,
|
| 176 |
+
)
|
| 177 |
+
self.scales = nn.Parameter(
|
| 178 |
+
torch.empty(in_features // self.group_size, out_features, dtype=torch.bfloat16),
|
| 179 |
+
requires_grad=False,
|
| 180 |
+
)
|
| 181 |
+
self.register_parameter("weight", None)
|
| 182 |
+
self.register_parameter("scale", None)
|
| 183 |
+
elif dtype == torch.float4_e2m1fn_x2:
|
| 184 |
+
# FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
|
| 185 |
+
# Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
|
| 186 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
|
| 187 |
+
scale_out_features = out_features
|
| 188 |
+
scale_in_features = in_features // fp4_block_size
|
| 189 |
+
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
|
| 190 |
+
elif dtype == torch.float8_e4m3fn:
|
| 191 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
|
| 192 |
+
scale_out_features = (out_features + block_size - 1) // block_size
|
| 193 |
+
scale_in_features = (in_features + block_size - 1) // block_size
|
| 194 |
+
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
|
| 195 |
+
else:
|
| 196 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
|
| 197 |
+
self.register_parameter("scale", None)
|
| 198 |
+
if bias:
|
| 199 |
+
self.bias = nn.Parameter(torch.empty(out_features))
|
| 200 |
+
else:
|
| 201 |
+
self.register_parameter("bias", None)
|
| 202 |
+
|
| 203 |
+
def init_woq(self, QuantLinear):
|
| 204 |
+
"""Create a QuantLinear from loaded GPTQ parameters."""
|
| 205 |
+
if not self.is_w4a16:
|
| 206 |
+
return
|
| 207 |
+
# Marlin requires out_features % 64 == 0; fall back to manual dequant
|
| 208 |
+
if self.out_features % 64 != 0:
|
| 209 |
+
self._woq = None
|
| 210 |
+
return
|
| 211 |
+
dev = self.qweight.device
|
| 212 |
+
layer = QuantLinear(
|
| 213 |
+
bits=4, group_size=self.group_size,
|
| 214 |
+
in_features=self.in_features, out_features=self.out_features,
|
| 215 |
+
bias=False, desc_act=False, sym=True, register_buffers=True,
|
| 216 |
+
)
|
| 217 |
+
layer = layer.to(dev)
|
| 218 |
+
layer.qweight.copy_(self.qweight.data)
|
| 219 |
+
layer.qzeros.copy_(self.qzeros.data)
|
| 220 |
+
layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
|
| 221 |
+
layer.g_idx.copy_(torch.arange(self.in_features, dtype=torch.int32, device=dev) // self.group_size)
|
| 222 |
+
layer.post_init()
|
| 223 |
+
self._woq = layer
|
| 224 |
+
# Free original parameters to save memory
|
| 225 |
+
self.qweight = None
|
| 226 |
+
self.qzeros = None
|
| 227 |
+
self.scales = None
|
| 228 |
+
|
| 229 |
+
def get_weight(self) -> torch.Tensor:
|
| 230 |
+
"""Return the dequantised BF16 weight [out, in]. For non-W4A16 modes
|
| 231 |
+
returns ``self.weight`` unchanged. Used only for wo_a einsum path."""
|
| 232 |
+
if self.is_w4a16:
|
| 233 |
+
if self._woq is not None:
|
| 234 |
+
return dequantize_w4a16(self._woq.qweight, self._woq.qzeros, self._woq.scales, self.group_size)
|
| 235 |
+
return dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
|
| 236 |
+
return self.weight
|
| 237 |
+
|
| 238 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
if self.is_w4a16:
|
| 240 |
+
if hasattr(self, '_woq') and self._woq is not None:
|
| 241 |
+
y = self._woq(x.to(torch.bfloat16))
|
| 242 |
+
else:
|
| 243 |
+
w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
|
| 244 |
+
y = F.linear(x.to(w.dtype), w)
|
| 245 |
+
if self.bias is not None:
|
| 246 |
+
y = y + self.bias
|
| 247 |
+
return y.type_as(x)
|
| 248 |
+
return linear(x, self.weight, self.bias)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class ColumnParallelLinear(Linear):
|
| 252 |
+
"""Shards output dim across TP ranks. No all-reduce needed on output."""
|
| 253 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 254 |
+
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
|
| 255 |
+
self.part_out_features = out_features // world_size
|
| 256 |
+
super().__init__(in_features, self.part_out_features, bias, dtype)
|
| 257 |
+
|
| 258 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 259 |
+
if self.is_w4a16:
|
| 260 |
+
return Linear.forward(self, x)
|
| 261 |
+
return linear(x, self.weight, self.bias)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class RowParallelLinear(Linear):
|
| 265 |
+
"""Shards input dim across TP ranks. All-reduce on output to sum partial results."""
|
| 266 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 267 |
+
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
|
| 268 |
+
self.part_in_features = in_features // world_size
|
| 269 |
+
super().__init__(self.part_in_features, out_features, bias, dtype)
|
| 270 |
+
|
| 271 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 272 |
+
if self.is_w4a16:
|
| 273 |
+
if hasattr(self, '_woq') and self._woq is not None:
|
| 274 |
+
y = self._woq(x.to(torch.bfloat16))
|
| 275 |
+
else:
|
| 276 |
+
w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
|
| 277 |
+
y = F.linear(x.to(w.dtype), w)
|
| 278 |
+
else:
|
| 279 |
+
y = linear(x, self.weight, None)
|
| 280 |
+
if world_size > 1:
|
| 281 |
+
y = y.float()
|
| 282 |
+
dist.all_reduce(y)
|
| 283 |
+
if self.bias is not None:
|
| 284 |
+
y += self.bias
|
| 285 |
+
return y.type_as(x)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class RMSNorm(nn.Module):
|
| 289 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 290 |
+
super().__init__()
|
| 291 |
+
self.dim = dim
|
| 292 |
+
self.eps = eps
|
| 293 |
+
# rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
|
| 294 |
+
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
| 295 |
+
|
| 296 |
+
def forward(self, x: torch.Tensor):
|
| 297 |
+
dtype = x.dtype
|
| 298 |
+
x = x.float()
|
| 299 |
+
var = x.square().mean(-1, keepdim=True)
|
| 300 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 301 |
+
return (self.weight * x).to(dtype)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@lru_cache(2)
|
| 305 |
+
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
|
| 306 |
+
"""Precomputes complex exponentials for rotary embeddings with YaRN scaling.
|
| 307 |
+
When original_seq_len > 0, applies frequency interpolation with a smooth
|
| 308 |
+
linear ramp between beta_fast and beta_slow correction ranges."""
|
| 309 |
+
|
| 310 |
+
def find_correction_dim(num_rotations, dim, base, max_seq_len):
|
| 311 |
+
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
|
| 312 |
+
|
| 313 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
|
| 314 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
|
| 315 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
|
| 316 |
+
return max(low, 0), min(high, dim-1)
|
| 317 |
+
|
| 318 |
+
def linear_ramp_factor(min, max, dim):
|
| 319 |
+
if min == max:
|
| 320 |
+
max += 0.001
|
| 321 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 322 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 323 |
+
return ramp_func
|
| 324 |
+
|
| 325 |
+
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 326 |
+
if original_seq_len > 0:
|
| 327 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
|
| 328 |
+
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
|
| 329 |
+
freqs = freqs / factor * (1 - smooth) + freqs * smooth
|
| 330 |
+
|
| 331 |
+
t = torch.arange(seqlen)
|
| 332 |
+
freqs = torch.outer(t, freqs)
|
| 333 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 334 |
+
return freqs_cis
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
|
| 338 |
+
"""Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
|
| 339 |
+
y = x
|
| 340 |
+
x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
|
| 341 |
+
if inverse:
|
| 342 |
+
freqs_cis = freqs_cis.conj()
|
| 343 |
+
if x.ndim == 3:
|
| 344 |
+
freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
|
| 345 |
+
else:
|
| 346 |
+
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
| 347 |
+
x = torch.view_as_real(x * freqs_cis).flatten(-2)
|
| 348 |
+
y.copy_(x)
|
| 349 |
+
return y
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
| 353 |
+
"""Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
|
| 354 |
+
assert x.dtype == torch.bfloat16
|
| 355 |
+
from fast_hadamard_transform import hadamard_transform
|
| 356 |
+
return hadamard_transform(x, scale=x.size(-1) ** -0.5)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@lru_cache(1)
|
| 360 |
+
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
|
| 361 |
+
if start_pos >= window_size - 1:
|
| 362 |
+
start_pos %= window_size
|
| 363 |
+
matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
|
| 364 |
+
elif start_pos > 0:
|
| 365 |
+
matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
|
| 366 |
+
else:
|
| 367 |
+
base = torch.arange(seqlen).unsqueeze(1)
|
| 368 |
+
matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
|
| 369 |
+
matrix = torch.where(matrix > base, -1, matrix)
|
| 370 |
+
return matrix.unsqueeze(0).expand(bsz, -1, -1)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@lru_cache(2)
|
| 374 |
+
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
|
| 375 |
+
if start_pos > 0:
|
| 376 |
+
matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
|
| 377 |
+
else:
|
| 378 |
+
matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
|
| 379 |
+
mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
| 380 |
+
matrix = torch.where(mask, -1, matrix + offset)
|
| 381 |
+
return matrix.unsqueeze(0).expand(bsz, -1, -1)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class Compressor(nn.Module):
|
| 385 |
+
"""Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
|
| 386 |
+
When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
|
| 387 |
+
|
| 388 |
+
def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
|
| 389 |
+
super().__init__()
|
| 390 |
+
self.dim = args.dim
|
| 391 |
+
self.head_dim = head_dim
|
| 392 |
+
self.rope_head_dim = args.rope_head_dim
|
| 393 |
+
self.nope_head_dim = head_dim - args.rope_head_dim
|
| 394 |
+
self.compress_ratio = compress_ratio
|
| 395 |
+
self.overlap = compress_ratio == 4
|
| 396 |
+
self.rotate = rotate
|
| 397 |
+
coff = 1 + self.overlap
|
| 398 |
+
|
| 399 |
+
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
|
| 400 |
+
# wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
|
| 401 |
+
# When overlap, the first half of dims is for overlapping compression, second half for normal.
|
| 402 |
+
self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
|
| 403 |
+
self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
|
| 404 |
+
self.norm = RMSNorm(self.head_dim, args.norm_eps)
|
| 405 |
+
self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
|
| 406 |
+
# State buffers for decode-phase incremental compression.
|
| 407 |
+
# With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
|
| 408 |
+
self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
|
| 409 |
+
self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
|
| 410 |
+
self.freqs_cis: torch.Tensor = None
|
| 411 |
+
|
| 412 |
+
def overlap_transform(self, tensor: torch.Tensor, value=0):
|
| 413 |
+
# tensor: [b,s,r,2d]
|
| 414 |
+
b, s, _, _ = tensor.size()
|
| 415 |
+
ratio, d = self.compress_ratio, self.head_dim
|
| 416 |
+
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
|
| 417 |
+
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
|
| 418 |
+
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
|
| 419 |
+
return new_tensor
|
| 420 |
+
|
| 421 |
+
def forward(self, x: torch.Tensor, start_pos: int):
|
| 422 |
+
assert self.kv_cache is not None
|
| 423 |
+
bsz, seqlen, _ = x.size()
|
| 424 |
+
ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
|
| 425 |
+
dtype = x.dtype
|
| 426 |
+
# compression need fp32
|
| 427 |
+
x = x.float()
|
| 428 |
+
kv = self.wkv(x)
|
| 429 |
+
score = self.wgate(x)
|
| 430 |
+
if start_pos == 0:
|
| 431 |
+
should_compress = seqlen >= ratio
|
| 432 |
+
remainder = seqlen % ratio
|
| 433 |
+
cutoff = seqlen - remainder
|
| 434 |
+
offset = ratio if overlap else 0
|
| 435 |
+
if overlap and cutoff >= ratio:
|
| 436 |
+
self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
|
| 437 |
+
self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
|
| 438 |
+
if remainder > 0:
|
| 439 |
+
kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
|
| 440 |
+
self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
|
| 441 |
+
score = score[:, :cutoff]
|
| 442 |
+
kv = kv.unflatten(1, (-1, ratio))
|
| 443 |
+
score = score.unflatten(1, (-1, ratio)) + self.ape
|
| 444 |
+
if overlap:
|
| 445 |
+
kv = self.overlap_transform(kv, 0)
|
| 446 |
+
score = self.overlap_transform(score, float("-inf"))
|
| 447 |
+
kv = (kv * score.softmax(dim=2)).sum(dim=2)
|
| 448 |
+
else:
|
| 449 |
+
should_compress = (start_pos + 1) % self.compress_ratio == 0
|
| 450 |
+
score += self.ape[start_pos % ratio]
|
| 451 |
+
if overlap:
|
| 452 |
+
self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
|
| 453 |
+
self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
|
| 454 |
+
if should_compress:
|
| 455 |
+
kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
|
| 456 |
+
score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
|
| 457 |
+
kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
|
| 458 |
+
self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
|
| 459 |
+
self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
|
| 460 |
+
else:
|
| 461 |
+
self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
|
| 462 |
+
self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
|
| 463 |
+
if should_compress:
|
| 464 |
+
kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
|
| 465 |
+
if not should_compress:
|
| 466 |
+
return
|
| 467 |
+
kv = self.norm(kv.to(dtype))
|
| 468 |
+
if start_pos == 0:
|
| 469 |
+
freqs_cis = self.freqs_cis[:cutoff:ratio]
|
| 470 |
+
else:
|
| 471 |
+
freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
|
| 472 |
+
apply_rotary_emb(kv[..., -rd:], freqs_cis)
|
| 473 |
+
if self.rotate:
|
| 474 |
+
kv = rotate_activation(kv)
|
| 475 |
+
fp4_act_quant(kv, fp4_block_size, True)
|
| 476 |
+
else:
|
| 477 |
+
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
|
| 478 |
+
if start_pos == 0:
|
| 479 |
+
self.kv_cache[:bsz, :seqlen // ratio] = kv
|
| 480 |
+
else:
|
| 481 |
+
self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
|
| 482 |
+
return kv
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class Indexer(torch.nn.Module):
|
| 486 |
+
"""Selects top-k compressed KV positions for sparse attention via learned scoring.
|
| 487 |
+
Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
|
| 488 |
+
|
| 489 |
+
def __init__(self, args: ModelArgs, compress_ratio: int = 4):
|
| 490 |
+
super().__init__()
|
| 491 |
+
self.dim = args.dim
|
| 492 |
+
self.n_heads = args.index_n_heads
|
| 493 |
+
self.n_local_heads = args.index_n_heads // world_size
|
| 494 |
+
self.head_dim = args.index_head_dim
|
| 495 |
+
self.rope_head_dim = args.rope_head_dim
|
| 496 |
+
self.index_topk = args.index_topk
|
| 497 |
+
self.q_lora_rank = args.q_lora_rank
|
| 498 |
+
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
|
| 499 |
+
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
|
| 500 |
+
self.softmax_scale = self.head_dim ** -0.5
|
| 501 |
+
self.compress_ratio = compress_ratio
|
| 502 |
+
|
| 503 |
+
self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
|
| 504 |
+
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
|
| 505 |
+
self.freqs_cis = None
|
| 506 |
+
|
| 507 |
+
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
|
| 508 |
+
bsz, seqlen, _ = x.size()
|
| 509 |
+
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
| 510 |
+
ratio = self.compress_ratio
|
| 511 |
+
rd = self.rope_head_dim
|
| 512 |
+
end_pos = start_pos + seqlen
|
| 513 |
+
if self.compressor.kv_cache is None:
|
| 514 |
+
self.compressor.kv_cache = self.kv_cache
|
| 515 |
+
self.compressor.freqs_cis = self.freqs_cis
|
| 516 |
+
q = self.wq_b(qr)
|
| 517 |
+
q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
|
| 518 |
+
apply_rotary_emb(q[..., -rd:], freqs_cis)
|
| 519 |
+
q = rotate_activation(q)
|
| 520 |
+
# use fp4 simulation for q and kv in indexer
|
| 521 |
+
fp4_act_quant(q, fp4_block_size, True)
|
| 522 |
+
self.compressor(x, start_pos)
|
| 523 |
+
weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
|
| 524 |
+
# We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
|
| 525 |
+
index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
|
| 526 |
+
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
|
| 527 |
+
if world_size > 1:
|
| 528 |
+
dist.all_reduce(index_score)
|
| 529 |
+
if start_pos == 0:
|
| 530 |
+
mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
| 531 |
+
index_score += torch.where(mask, float("-inf"), 0)
|
| 532 |
+
topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
|
| 533 |
+
if start_pos == 0:
|
| 534 |
+
mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
|
| 535 |
+
topk_idxs = torch.where(mask, -1, topk_idxs + offset)
|
| 536 |
+
else:
|
| 537 |
+
topk_idxs += offset
|
| 538 |
+
return topk_idxs
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class Attention(nn.Module):
|
| 542 |
+
"""Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
|
| 543 |
+
Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
|
| 544 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 545 |
+
super().__init__()
|
| 546 |
+
self.layer_id = layer_id
|
| 547 |
+
self.dim = args.dim
|
| 548 |
+
self.n_heads = args.n_heads
|
| 549 |
+
self.n_local_heads = args.n_heads // world_size
|
| 550 |
+
self.q_lora_rank = args.q_lora_rank
|
| 551 |
+
self.o_lora_rank = args.o_lora_rank
|
| 552 |
+
self.head_dim = args.head_dim
|
| 553 |
+
self.rope_head_dim = args.rope_head_dim
|
| 554 |
+
self.nope_head_dim = args.head_dim - args.rope_head_dim
|
| 555 |
+
self.n_groups = args.o_groups
|
| 556 |
+
self.n_local_groups = self.n_groups // world_size
|
| 557 |
+
self.window_size = args.window_size
|
| 558 |
+
self.compress_ratio = args.compress_ratios[layer_id]
|
| 559 |
+
self.eps = args.norm_eps
|
| 560 |
+
|
| 561 |
+
self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
|
| 562 |
+
self.wq_a = Linear(self.dim, self.q_lora_rank)
|
| 563 |
+
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
|
| 564 |
+
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
|
| 565 |
+
self.wkv = Linear(self.dim, self.head_dim)
|
| 566 |
+
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
| 567 |
+
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
|
| 568 |
+
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
|
| 569 |
+
self.softmax_scale = self.head_dim ** -0.5
|
| 570 |
+
|
| 571 |
+
if self.compress_ratio:
|
| 572 |
+
self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
|
| 573 |
+
if self.compress_ratio == 4:
|
| 574 |
+
self.indexer = Indexer(args, self.compress_ratio)
|
| 575 |
+
else:
|
| 576 |
+
self.indexer = None
|
| 577 |
+
|
| 578 |
+
kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
|
| 579 |
+
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
|
| 580 |
+
if self.compress_ratio:
|
| 581 |
+
original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
|
| 582 |
+
else:
|
| 583 |
+
# disable YaRN and use base rope_theta in pure sliding-window attention
|
| 584 |
+
original_seq_len, rope_theta = 0, args.rope_theta
|
| 585 |
+
freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
|
| 586 |
+
rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
|
| 587 |
+
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
| 588 |
+
|
| 589 |
+
def forward(self, x: torch.Tensor, start_pos: int):
|
| 590 |
+
bsz, seqlen, _ = x.size()
|
| 591 |
+
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
| 592 |
+
win = self.window_size
|
| 593 |
+
ratio = self.compress_ratio
|
| 594 |
+
rd = self.rope_head_dim
|
| 595 |
+
if self.compress_ratio and self.compressor.kv_cache is None:
|
| 596 |
+
self.compressor.kv_cache = self.kv_cache[:, win:]
|
| 597 |
+
self.compressor.freqs_cis = self.freqs_cis
|
| 598 |
+
if self.indexer is not None:
|
| 599 |
+
self.indexer.freqs_cis = self.freqs_cis
|
| 600 |
+
# q
|
| 601 |
+
qr = q = self.q_norm(self.wq_a(x))
|
| 602 |
+
q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
|
| 603 |
+
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
|
| 604 |
+
apply_rotary_emb(q[..., -rd:], freqs_cis)
|
| 605 |
+
|
| 606 |
+
# win kv & topk_idxs
|
| 607 |
+
kv = self.wkv(x)
|
| 608 |
+
kv = self.kv_norm(kv)
|
| 609 |
+
apply_rotary_emb(kv[..., -rd:], freqs_cis)
|
| 610 |
+
# FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
|
| 611 |
+
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
|
| 612 |
+
topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
|
| 613 |
+
if self.compress_ratio:
|
| 614 |
+
offset = kv.size(1) if start_pos == 0 else win
|
| 615 |
+
if self.indexer is not None:
|
| 616 |
+
compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
|
| 617 |
+
else:
|
| 618 |
+
compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
|
| 619 |
+
topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
|
| 620 |
+
topk_idxs = topk_idxs.int()
|
| 621 |
+
|
| 622 |
+
# compress kv & attn
|
| 623 |
+
if start_pos == 0:
|
| 624 |
+
if seqlen <= win:
|
| 625 |
+
self.kv_cache[:bsz, :seqlen] = kv
|
| 626 |
+
else:
|
| 627 |
+
cutoff = seqlen % win
|
| 628 |
+
self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
|
| 629 |
+
if self.compress_ratio:
|
| 630 |
+
if (kv_compress := self.compressor(x, start_pos)) is not None:
|
| 631 |
+
kv = torch.cat([kv, kv_compress], dim=1)
|
| 632 |
+
# We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
|
| 633 |
+
o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
|
| 634 |
+
else:
|
| 635 |
+
self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
|
| 636 |
+
if self.compress_ratio:
|
| 637 |
+
self.compressor(x, start_pos)
|
| 638 |
+
o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
|
| 639 |
+
apply_rotary_emb(o[..., -rd:], freqs_cis, True)
|
| 640 |
+
|
| 641 |
+
# o: apply wo_a per-group projection then wo_b
|
| 642 |
+
# Flatten groups into the feature dim, call wo_a as a normal linear, then reshape back.
|
| 643 |
+
# Equivalent to the per-group einsum when wo_a weight is block-diagonal across groups
|
| 644 |
+
# (always true here since n_local_groups = n_groups/world_size = 1 for 8-GPU deploy).
|
| 645 |
+
o = o.view(bsz, seqlen, self.n_local_groups, -1)
|
| 646 |
+
o = self.wo_a(o.flatten(2)).view(bsz, seqlen, self.n_local_groups, self.o_lora_rank)
|
| 647 |
+
x = self.wo_b(o.flatten(2))
|
| 648 |
+
return x
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class Gate(nn.Module):
|
| 652 |
+
"""MoE gating: computes expert routing scores and selects top-k experts.
|
| 653 |
+
Supports hash-based routing (first n_hash_layers) where expert indices are
|
| 654 |
+
predetermined per token ID, and score-based routing (remaining layers)."""
|
| 655 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 656 |
+
super().__init__()
|
| 657 |
+
self.dim = args.dim
|
| 658 |
+
self.topk = args.n_activated_experts
|
| 659 |
+
self.score_func = args.score_func
|
| 660 |
+
self.route_scale = args.route_scale
|
| 661 |
+
self.hash = layer_id < args.n_hash_layers
|
| 662 |
+
self.is_w4a16 = w4a16_mode
|
| 663 |
+
if self.is_w4a16:
|
| 664 |
+
in_f, out_f = args.dim, args.n_routed_experts
|
| 665 |
+
assert in_f % w4a16_group_size == 0 and out_f % 8 == 0
|
| 666 |
+
self.group_size = w4a16_group_size
|
| 667 |
+
self.qweight = nn.Parameter(
|
| 668 |
+
torch.empty(in_f // 8, out_f, dtype=torch.int32), requires_grad=False)
|
| 669 |
+
self.qzeros = nn.Parameter(
|
| 670 |
+
torch.empty(in_f // self.group_size, out_f // 8, dtype=torch.int32), requires_grad=False)
|
| 671 |
+
self.scales = nn.Parameter(
|
| 672 |
+
torch.empty(in_f // self.group_size, out_f, dtype=torch.bfloat16), requires_grad=False)
|
| 673 |
+
self.register_parameter("weight", None)
|
| 674 |
+
else:
|
| 675 |
+
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
| 676 |
+
if self.hash:
|
| 677 |
+
self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
|
| 678 |
+
self.bias = None
|
| 679 |
+
else:
|
| 680 |
+
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
|
| 681 |
+
|
| 682 |
+
def init_woq(self, QuantLinear):
|
| 683 |
+
if not self.is_w4a16:
|
| 684 |
+
return
|
| 685 |
+
dev = self.qweight.device
|
| 686 |
+
in_f, out_f = self.dim, self.qweight.shape[1]
|
| 687 |
+
if out_f % 64 != 0:
|
| 688 |
+
self._woq = None
|
| 689 |
+
return
|
| 690 |
+
layer = QuantLinear(
|
| 691 |
+
bits=4, group_size=self.group_size,
|
| 692 |
+
in_features=in_f, out_features=out_f,
|
| 693 |
+
bias=False, desc_act=False, sym=True, register_buffers=True,
|
| 694 |
+
)
|
| 695 |
+
layer = layer.to(dev)
|
| 696 |
+
layer.qweight.copy_(self.qweight.data)
|
| 697 |
+
layer.qzeros.copy_(self.qzeros.data)
|
| 698 |
+
layer.scales.copy_(self.scales.to(layer.scales.dtype).data)
|
| 699 |
+
layer.g_idx.copy_(torch.arange(in_f, dtype=torch.int32, device=dev) // self.group_size)
|
| 700 |
+
layer.post_init()
|
| 701 |
+
self._woq = layer
|
| 702 |
+
self.qweight = None
|
| 703 |
+
self.qzeros = None
|
| 704 |
+
self.scales = None
|
| 705 |
+
|
| 706 |
+
def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 707 |
+
if self.is_w4a16:
|
| 708 |
+
if hasattr(self, '_woq') and self._woq is not None:
|
| 709 |
+
scores = self._woq(x.to(torch.bfloat16)).float()
|
| 710 |
+
else:
|
| 711 |
+
w = dequantize_w4a16(self.qweight, self.qzeros, self.scales, self.group_size)
|
| 712 |
+
scores = F.linear(x.to(w.dtype), w).float()
|
| 713 |
+
else:
|
| 714 |
+
scores = linear(x.float(), self.weight.float())
|
| 715 |
+
if self.score_func == "softmax":
|
| 716 |
+
scores = scores.softmax(dim=-1)
|
| 717 |
+
elif self.score_func == "sigmoid":
|
| 718 |
+
scores = scores.sigmoid()
|
| 719 |
+
else:
|
| 720 |
+
scores = F.softplus(scores).sqrt()
|
| 721 |
+
original_scores = scores
|
| 722 |
+
# Bias shifts scores for expert selection (topk) but does not affect routing weights.
|
| 723 |
+
if self.bias is not None:
|
| 724 |
+
scores = scores + self.bias
|
| 725 |
+
if self.hash:
|
| 726 |
+
indices = self.tid2eid[input_ids]
|
| 727 |
+
else:
|
| 728 |
+
indices = scores.topk(self.topk, dim=-1)[1]
|
| 729 |
+
weights = original_scores.gather(1, indices)
|
| 730 |
+
if self.score_func != "softmax":
|
| 731 |
+
weights /= weights.sum(dim=-1, keepdim=True)
|
| 732 |
+
weights *= self.route_scale
|
| 733 |
+
return weights, indices
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
class Expert(nn.Module):
|
| 737 |
+
"""Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
|
| 738 |
+
def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
|
| 739 |
+
super().__init__()
|
| 740 |
+
self.w1 = Linear(dim, inter_dim, dtype=dtype)
|
| 741 |
+
self.w2 = Linear(inter_dim, dim, dtype=dtype)
|
| 742 |
+
self.w3 = Linear(dim, inter_dim, dtype=dtype)
|
| 743 |
+
self.swiglu_limit = swiglu_limit
|
| 744 |
+
|
| 745 |
+
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 746 |
+
dtype = x.dtype
|
| 747 |
+
gate = self.w1(x).float()
|
| 748 |
+
up = self.w3(x).float()
|
| 749 |
+
if self.swiglu_limit > 0:
|
| 750 |
+
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
|
| 751 |
+
gate = torch.clamp(gate, max=self.swiglu_limit)
|
| 752 |
+
x = F.silu(gate) * up
|
| 753 |
+
if weights is not None:
|
| 754 |
+
x = weights * x
|
| 755 |
+
return self.w2(x.to(dtype))
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class MoE(nn.Module):
|
| 759 |
+
"""Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
|
| 760 |
+
Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
|
| 761 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 762 |
+
super().__init__()
|
| 763 |
+
self.layer_id = layer_id
|
| 764 |
+
self.dim = args.dim
|
| 765 |
+
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
|
| 766 |
+
self.n_routed_experts = args.n_routed_experts
|
| 767 |
+
self.n_local_experts = args.n_routed_experts // world_size
|
| 768 |
+
self.n_activated_experts = args.n_activated_experts
|
| 769 |
+
self.experts_start_idx = rank * self.n_local_experts
|
| 770 |
+
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
| 771 |
+
self.gate = Gate(layer_id, args)
|
| 772 |
+
expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
|
| 773 |
+
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
|
| 774 |
+
for i in range(self.n_routed_experts)])
|
| 775 |
+
assert args.n_shared_experts == 1
|
| 776 |
+
# no swiglu_limit
|
| 777 |
+
self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)
|
| 778 |
+
|
| 779 |
+
def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
| 780 |
+
shape = x.size()
|
| 781 |
+
x = x.view(-1, self.dim)
|
| 782 |
+
weights, indices = self.gate(x, input_ids.flatten())
|
| 783 |
+
y = torch.zeros_like(x, dtype=torch.float32)
|
| 784 |
+
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
|
| 785 |
+
for i in range(self.experts_start_idx, self.experts_end_idx):
|
| 786 |
+
if counts[i] == 0:
|
| 787 |
+
continue
|
| 788 |
+
expert = self.experts[i]
|
| 789 |
+
idx, top = torch.where(indices == i)
|
| 790 |
+
y[idx] += expert(x[idx], weights[idx, top, None])
|
| 791 |
+
if world_size > 1:
|
| 792 |
+
dist.all_reduce(y)
|
| 793 |
+
y += self.shared_experts(x)
|
| 794 |
+
return y.type_as(x).view(shape)
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class Block(nn.Module):
|
| 798 |
+
"""Transformer block with Hyper-Connections (HC) mixing.
|
| 799 |
+
Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
|
| 800 |
+
hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
|
| 801 |
+
hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
|
| 802 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 803 |
+
super().__init__()
|
| 804 |
+
self.layer_id = layer_id
|
| 805 |
+
self.norm_eps = args.norm_eps
|
| 806 |
+
self.attn = Attention(layer_id, args)
|
| 807 |
+
self.ffn = MoE(layer_id, args)
|
| 808 |
+
self.attn_norm = RMSNorm(args.dim, self.norm_eps)
|
| 809 |
+
self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
|
| 810 |
+
self.hc_mult = hc_mult = args.hc_mult
|
| 811 |
+
self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
|
| 812 |
+
self.hc_eps = args.hc_eps
|
| 813 |
+
mix_hc = (2 + hc_mult) * hc_mult
|
| 814 |
+
hc_dim = hc_mult * args.dim
|
| 815 |
+
with set_dtype(torch.float32):
|
| 816 |
+
self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
|
| 817 |
+
self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
|
| 818 |
+
self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
|
| 819 |
+
self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
|
| 820 |
+
self.hc_attn_scale = nn.Parameter(torch.empty(3))
|
| 821 |
+
self.hc_ffn_scale = nn.Parameter(torch.empty(3))
|
| 822 |
+
|
| 823 |
+
def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
|
| 824 |
+
# x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
|
| 825 |
+
shape, dtype = x.size(), x.dtype
|
| 826 |
+
x = x.flatten(2).float()
|
| 827 |
+
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
|
| 828 |
+
mixes = F.linear(x, hc_fn) * rsqrt
|
| 829 |
+
pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
|
| 830 |
+
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
|
| 831 |
+
return y.to(dtype), post, comb
|
| 832 |
+
|
| 833 |
+
def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
|
| 834 |
+
# x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
|
| 835 |
+
y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
|
| 836 |
+
return y.type_as(x)
|
| 837 |
+
|
| 838 |
+
def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
|
| 839 |
+
residual = x
|
| 840 |
+
x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
|
| 841 |
+
x = self.attn_norm(x)
|
| 842 |
+
x = self.attn(x, start_pos)
|
| 843 |
+
x = self.hc_post(x, residual, post, comb)
|
| 844 |
+
|
| 845 |
+
residual = x
|
| 846 |
+
x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
|
| 847 |
+
x = self.ffn_norm(x)
|
| 848 |
+
x = self.ffn(x, input_ids)
|
| 849 |
+
x = self.hc_post(x, residual, post, comb)
|
| 850 |
+
return x
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
class ParallelHead(nn.Module):
|
| 854 |
+
|
| 855 |
+
def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
|
| 856 |
+
super().__init__()
|
| 857 |
+
self.vocab_size = vocab_size
|
| 858 |
+
self.dim = dim
|
| 859 |
+
self.norm_eps = norm_eps
|
| 860 |
+
self.hc_eps = hc_eps
|
| 861 |
+
self.part_vocab_size = (vocab_size // world_size)
|
| 862 |
+
# lm_head is always stored as bf16 (even in W4A16 checkpoints); use fp32 for logit precision
|
| 863 |
+
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
|
| 864 |
+
|
| 865 |
+
def get_logits(self, x):
|
| 866 |
+
return F.linear(x[:, -1].float(), self.weight)
|
| 867 |
+
|
| 868 |
+
def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
|
| 869 |
+
# x: [b,s,hc,d]
|
| 870 |
+
x = self.hc_head(x, hc_fn, hc_scale, hc_base)
|
| 871 |
+
logits = self.get_logits(norm(x))
|
| 872 |
+
if world_size > 1:
|
| 873 |
+
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
|
| 874 |
+
dist.all_gather(all_logits, logits)
|
| 875 |
+
logits = torch.cat(all_logits, dim=-1)
|
| 876 |
+
return logits
|
| 877 |
+
|
| 878 |
+
def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
|
| 879 |
+
shape, dtype = x.size(), x.dtype
|
| 880 |
+
x = x.flatten(2).float()
|
| 881 |
+
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
|
| 882 |
+
mixes = F.linear(x, hc_fn) * rsqrt
|
| 883 |
+
pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
|
| 884 |
+
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
|
| 885 |
+
return y.to(dtype)
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
class MTPBlock(Block):
|
| 889 |
+
|
| 890 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 891 |
+
super().__init__(layer_id, args)
|
| 892 |
+
self.e_proj = Linear(args.dim, args.dim)
|
| 893 |
+
self.h_proj = Linear(args.dim, args.dim)
|
| 894 |
+
self.enorm = RMSNorm(args.dim, args.norm_eps)
|
| 895 |
+
self.hnorm = RMSNorm(args.dim, args.norm_eps)
|
| 896 |
+
self.norm = RMSNorm(args.dim, args.norm_eps)
|
| 897 |
+
self.hc_mult = hc_mult = args.hc_mult
|
| 898 |
+
hc_dim = hc_mult * args.dim
|
| 899 |
+
with set_dtype(torch.float32):
|
| 900 |
+
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
|
| 901 |
+
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
|
| 902 |
+
self.hc_head_scale = nn.Parameter(torch.empty(1))
|
| 903 |
+
self.embed: ParallelEmbedding = None
|
| 904 |
+
self.head: ParallelHead = None
|
| 905 |
+
|
| 906 |
+
@torch.inference_mode()
|
| 907 |
+
def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
|
| 908 |
+
# x: [b,s,hc,d]
|
| 909 |
+
assert self.embed is not None and self.head is not None
|
| 910 |
+
e = self.embed(input_ids)
|
| 911 |
+
e = self.enorm(e)
|
| 912 |
+
x = self.hnorm(x)
|
| 913 |
+
x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
|
| 914 |
+
x = super().forward(x, start_pos, input_ids)
|
| 915 |
+
logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
|
| 916 |
+
return logits
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
class Transformer(nn.Module):
|
| 920 |
+
"""Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
|
| 921 |
+
Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
|
| 922 |
+
def __init__(self, args: ModelArgs):
|
| 923 |
+
global world_size, rank, default_dtype, scale_fmt, scale_dtype, w4a16_mode
|
| 924 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
| 925 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
| 926 |
+
w4a16_mode = (args.dtype == "w4a16")
|
| 927 |
+
if w4a16_mode:
|
| 928 |
+
default_dtype = torch.bfloat16
|
| 929 |
+
scale_fmt = None
|
| 930 |
+
scale_dtype = torch.float32
|
| 931 |
+
else:
|
| 932 |
+
default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
|
| 933 |
+
scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
|
| 934 |
+
scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
|
| 935 |
+
super().__init__()
|
| 936 |
+
self.max_seq_len = args.max_seq_len
|
| 937 |
+
self.norm_eps = args.norm_eps
|
| 938 |
+
self.hc_eps = args.hc_eps
|
| 939 |
+
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
|
| 940 |
+
self.layers = torch.nn.ModuleList()
|
| 941 |
+
for layer_id in range(args.n_layers):
|
| 942 |
+
self.layers.append(Block(layer_id, args))
|
| 943 |
+
self.norm = RMSNorm(args.dim, self.norm_eps)
|
| 944 |
+
self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
|
| 945 |
+
self.mtp = torch.nn.ModuleList()
|
| 946 |
+
for layer_id in range(args.n_mtp_layers):
|
| 947 |
+
self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
|
| 948 |
+
self.mtp[-1].embed = self.embed
|
| 949 |
+
self.mtp[-1].head = self.head
|
| 950 |
+
self.hc_mult = hc_mult = args.hc_mult
|
| 951 |
+
hc_dim = hc_mult * args.dim
|
| 952 |
+
with set_dtype(torch.float32):
|
| 953 |
+
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
|
| 954 |
+
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
|
| 955 |
+
self.hc_head_scale = nn.Parameter(torch.empty(1))
|
| 956 |
+
|
| 957 |
+
def init_woq_layers(self):
|
| 958 |
+
"""After load_model(), convert all W4A16 parameters into QuantLinear layers."""
|
| 959 |
+
# from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear as QuantLinear
|
| 960 |
+
from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear as QuantLinear
|
| 961 |
+
for module in self.modules():
|
| 962 |
+
if hasattr(module, 'init_woq') and module is not self:
|
| 963 |
+
module.init_woq(QuantLinear)
|
| 964 |
+
torch.cuda.empty_cache()
|
| 965 |
+
|
| 966 |
+
@torch.inference_mode()
|
| 967 |
+
def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
|
| 968 |
+
h = self.embed(input_ids)
|
| 969 |
+
# Expand to hc_mult copies for Hyper-Connections
|
| 970 |
+
h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
|
| 971 |
+
for layer in self.layers:
|
| 972 |
+
h = layer(h, start_pos, input_ids)
|
| 973 |
+
logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
|
| 974 |
+
return logits
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
if __name__ == "__main__":
|
| 978 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 979 |
+
torch.set_default_device("cuda")
|
| 980 |
+
torch.manual_seed(0)
|
| 981 |
+
args = ModelArgs(n_hash_layers=0)
|
| 982 |
+
x = torch.randint(0, args.vocab_size, (2, 128))
|
| 983 |
+
model = Transformer(args)
|
| 984 |
+
|
| 985 |
+
print(model(x).size())
|
| 986 |
+
for i in range(128, 150):
|
| 987 |
+
print(i, model(x[:, 0:1], i).size())
|
| 988 |
+
|
| 989 |
+
h = torch.randn(2, 128, args.hc_mult, args.dim)
|
| 990 |
+
mtp = model.mtp[0]
|
| 991 |
+
print(mtp(h, 0, x).size())
|
| 992 |
+
print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
|
inference/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.10.0
|
| 2 |
+
transformers>=5.0.0
|
| 3 |
+
safetensors>=0.7.0
|
| 4 |
+
fast_hadamard_transform
|
| 5 |
+
tilelang==0.1.8
|
| 6 |
+
gptqmodel==6.0.3
|
model-00001-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ef64d991b80c86f24bd78d1bac9d452d95bc78e6b1d8feb6a182dae0240c7e5
|
| 3 |
+
size 1853358176
|
model-00002-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7d7f2862248d66164e5bf9007a9da3b0f1d6c0b9d5143c43dbc2d9ad4a7ff12
|
| 3 |
+
size 13390865040
|
model-00003-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:881002f38770cf6e0be5c24801d11c595674c3c4d7ff3485c45bf06d848e90a7
|
| 3 |
+
size 13390865040
|
model-00004-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c1c824e837c5c78d2467c70222e5b690c91a9bc772f6662ad48b39db06d116d
|
| 3 |
+
size 13403120848
|
model-00005-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c7cb7212f3374b69f29e1198b572392084135bdd0bda38282f8bd9a0c8ed46b
|
| 3 |
+
size 13384661096
|
model-00006-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:683444102f6eeae98ccaa00d1aada1e44253a80541a17d457382bd070947ac59
|
| 3 |
+
size 13396916904
|
model-00007-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:863e5634abb25b3cf83625ff90f3c5783b5619435ae184112f05c2d65c407a84
|
| 3 |
+
size 13384661096
|
model-00008-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a86efce4fc51c41911f78480863ba467d97f8a65ce03365d3f0a002b892674ad
|
| 3 |
+
size 13396916904
|
model-00009-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6cdf0891b172d21168999bd0adb520b820a078351ee15a77d7b07867908347cd
|
| 3 |
+
size 13384661096
|
model-00010-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2615fd85e6a5913201d05b1cefe8e664dc910b88531ff1ae46da21ed41a9558e
|
| 3 |
+
size 13396916904
|
model-00011-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4922eb51e18579581a83bce28ab880b8dd5390cd4d03dd0c24e4b2c4f6fd3dc6
|
| 3 |
+
size 13384661096
|
model-00012-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1b289316453fe719094a30e5cea8e577787dc7560512f872db669962addac88
|
| 3 |
+
size 13396920416
|
model-00013-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53c5b20d54738bdcfc827f165a6543848b6ccf384ef66bf9402ea7244ef67c52
|
| 3 |
+
size 13384664600
|
model-00014-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7423ce801964f2af7f4056531432fcc1a140a7cd33a2a5b530fb30e49609e618
|
| 3 |
+
size 13396920416
|
model-00015-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d13483f05edd67b110bcf0cefe7a84f9bfe94ce47e654af4e60238955fa50989
|
| 3 |
+
size 13384664600
|
model-00016-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75704e031c9be328e8133ee2c6603a430bcfd61979f5a561964bf29d0134104b
|
| 3 |
+
size 13396920416
|
model-00017-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7f0c6876bade6ea29e2662ebae000ce3cfde96887c4ca66a90db77ee530db8b
|
| 3 |
+
size 13384664600
|
model-00018-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4abe0458401fc5e990e3699dbff2b98e05355f66bcc5d974ece79e7e14672765
|
| 3 |
+
size 13396920416
|
model-00019-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:29238f11c48bbb56de1c513e9806edcaf63a142b191e36668b64d93df29b5238
|
| 3 |
+
size 13384664600
|
model-00020-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bd7d5d0e0337dc0af3b68287c237c6ac6c922596a64c88320fa8c62ddc32ceb
|
| 3 |
+
size 13396920416
|
model-00021-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:572a2a0e0910421cd7f4519da30c604cddf4a1aa61fe50e51788ac8f3e0875d7
|
| 3 |
+
size 13384664600
|
model-00022-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:077bce599585a64769465c6ef3614d4717184ce7f776162ee63884cc38e5ded3
|
| 3 |
+
size 13396920416
|
model-00023-of-00064.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5ea118977b28e91802ff0df2cfd262c42eb2b439d6bd019bb2f830ea9ee7de2
|
| 3 |
+
size 13384664600
|