add inference code
Browse files- inference/README.md +26 -0
- inference/__pycache__/encoding_dsv4.cpython-312.pyc +0 -0
- inference/__pycache__/kernel.cpython-312.pyc +0 -0
- inference/__pycache__/model.cpython-312.pyc +0 -0
- inference/config.json +35 -0
- inference/convert.py +168 -0
- inference/generate.py +151 -0
- inference/kernel.py +536 -0
- inference/model.py +828 -0
- inference/requirements.txt +5 -0
inference/README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference code for DeepSeek models
|
| 2 |
+
|
| 3 |
+
First, convert the Hugging Face model weight files to the format used by this project.
|
| 4 |
+
```bash
|
| 5 |
+
export EXPERTS=384
|
| 6 |
+
export MP=8
|
| 7 |
+
export CONFIG=config.json
|
| 8 |
+
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
Then chat with DeepSeek model at will!
|
| 12 |
+
```bash
|
| 13 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Or batch inference from file.
|
| 17 |
+
```bash
|
| 18 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Or run multi-node inference.
|
| 22 |
+
```bash
|
| 23 |
+
torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
To use fp8 instead, remove `"expert_dtype": "fp4"` from `config.json` and pass `--expert-dtype fp8` to `convert.py`.
|
inference/__pycache__/encoding_dsv4.cpython-312.pyc
ADDED
|
Binary file (28.6 kB). View file
|
|
|
inference/__pycache__/kernel.cpython-312.pyc
ADDED
|
Binary file (35.4 kB). View file
|
|
|
inference/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (63.5 kB). View file
|
|
|
inference/config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 129280,
|
| 3 |
+
"dim": 7168,
|
| 4 |
+
"moe_inter_dim": 3072,
|
| 5 |
+
"n_layers": 61,
|
| 6 |
+
"n_hash_layers": 3,
|
| 7 |
+
"n_heads": 128,
|
| 8 |
+
"n_routed_experts": 384,
|
| 9 |
+
"n_shared_experts": 1,
|
| 10 |
+
"n_activated_experts": 6,
|
| 11 |
+
"score_func": "sqrtsoftplus",
|
| 12 |
+
"route_scale": 2.5,
|
| 13 |
+
"swiglu_limit": 10.0,
|
| 14 |
+
"q_lora_rank": 1536,
|
| 15 |
+
"head_dim": 512,
|
| 16 |
+
"rope_head_dim": 64,
|
| 17 |
+
"o_groups": 16,
|
| 18 |
+
"o_lora_rank": 1024,
|
| 19 |
+
"window_size": 128,
|
| 20 |
+
"original_seq_len": 65536,
|
| 21 |
+
"rope_theta": 10000,
|
| 22 |
+
"rope_factor": 16,
|
| 23 |
+
"beta_fast": 32,
|
| 24 |
+
"beta_slow": 1,
|
| 25 |
+
"index_n_heads": 64,
|
| 26 |
+
"index_head_dim": 128,
|
| 27 |
+
"index_topk": 512,
|
| 28 |
+
"hc_mult": 4,
|
| 29 |
+
"hc_sinkhorn_iters": 20,
|
| 30 |
+
"dtype": "fp8",
|
| 31 |
+
"scale_fmt": "ue8m0",
|
| 32 |
+
"expert_dtype": "fp4",
|
| 33 |
+
"compress_rope_theta": 160000,
|
| 34 |
+
"compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
|
| 35 |
+
}
|
inference/convert.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from glob import glob
|
| 5 |
+
from tqdm import tqdm, trange
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import safe_open, save_file
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
FP4_TABLE = torch.tensor([
|
| 12 |
+
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
| 13 |
+
0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
|
| 14 |
+
], dtype=torch.float32)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 18 |
+
"""
|
| 19 |
+
Casts a tensor from e2m1fn to e4m3fn losslessly.
|
| 20 |
+
"""
|
| 21 |
+
assert x.dtype == torch.int8
|
| 22 |
+
assert x.ndim == 2
|
| 23 |
+
out_dim, in_dim = x.size()
|
| 24 |
+
in_dim *= 2
|
| 25 |
+
fp8_block_size = 128
|
| 26 |
+
fp4_block_size = 32
|
| 27 |
+
assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
|
| 28 |
+
assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
|
| 29 |
+
|
| 30 |
+
x = x.view(torch.uint8)
|
| 31 |
+
low = x & 0x0F
|
| 32 |
+
high = (x >> 4) & 0x0F
|
| 33 |
+
x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
|
| 34 |
+
|
| 35 |
+
# max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
|
| 36 |
+
# 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
|
| 37 |
+
MAX_OFFSET_BITS = 6
|
| 38 |
+
|
| 39 |
+
bOut = out_dim // fp8_block_size
|
| 40 |
+
bIn = in_dim // fp8_block_size
|
| 41 |
+
# bOut, bIn, 128, 128
|
| 42 |
+
x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
|
| 43 |
+
# bOut, bIn, 128*4
|
| 44 |
+
scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
|
| 45 |
+
## bOut, bIn, 1
|
| 46 |
+
scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
|
| 47 |
+
# bOut, bIn, 128*4
|
| 48 |
+
offset = scale / scale_max_offset_bits
|
| 49 |
+
# bOut, bIn, 128, 128
|
| 50 |
+
offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
|
| 51 |
+
x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
|
| 52 |
+
return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Parameter-name translation table: maps a checkpoint key component to
# (new name, shard dimension). The shard dimension is the axis split across
# model-parallel ranks in `main`; None means the tensor is replicated.
mapping = {
    # Hugging Face-style parameter names.
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),

    # Names already in this project's format (identity renames; entry only
    # records the shard dimension).
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        expert_dtype (str | None): "fp8" converts FP4 expert weights to FP8 via
            cast_e2m1fn_to_e4m3fn; any other value keeps them packed as FP4.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    # One output state dict per model-parallel rank.
    state_dicts = [{} for _ in range(mp)]

    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                # Drop MTP embedding/head tensors; they are not loaded.
                if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
                    continue
                # Normalize checkpoint naming to this project's conventions.
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                # These parameters have no trailing ".weight", so the mapping
                # key is the last path component instead of the second-to-last.
                if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                if key in mapping:
                    new_key, dim = mapping[key]
                else:
                    # Unknown keys pass through unrenamed and unsharded.
                    new_key, dim = key, None
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        # Routed experts are assigned whole to one rank.
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        # Shard the tensor evenly along its parallel dimension.
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True)

    for i in trange(mp):
        names = list(state_dicts[i].keys())
        for name in names:
            if name.endswith("wo_a.weight"):
                # Dequantize wo_a to BF16: multiply each 128x128 block by its scale.
                weight = state_dicts[i][name]
                scale = state_dicts[i].pop(name.replace("weight", "scale"))
                weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
                state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
            elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
                # int8 storage here holds packed FP4 expert weights.
                if expert_dtype == "fp8":
                    scale_name = name.replace("weight", "scale")
                    weight = state_dicts[i].pop(name)
                    scale = state_dicts[i].pop(scale_name)
                    state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
                else:
                    # Reinterpret the packed bytes as the FP4 dtype, no copy.
                    state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    # Copy tokenizer files alongside the converted shards when present.
    for file in ["tokenizer.json", "tokenizer_config.json"]:
        old_file_path = os.path.join(hf_ckpt_path, file)
        new_file_path = os.path.join(save_path, file)
        if os.path.exists(old_file_path):
            shutil.copyfile(old_file_path, new_file_path)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
    # CLI entry point: parse shard/convert options and run the conversion.
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    # None keeps expert weights in their checkpoint format (packed FP4).
    parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
|
inference/generate.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from safetensors.torch import load_model
|
| 10 |
+
|
| 11 |
+
from model import Transformer, ModelArgs
|
| 12 |
+
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def sample(logits, temperature: float = 1.0):
    """Draw one token id per row with the Gumbel-max trick.

    Equivalent to multinomial sampling over softmax(logits / temperature), but
    stays on-device: it avoids the GPU-to-CPU sync of torch.multinomial.
    """
    scaled = logits / max(temperature, 1e-5)
    probs = torch.softmax(scaled, dim=-1, dtype=torch.float32)
    noise = torch.empty_like(probs).exponential_(1)
    return probs.div_(noise).argmax(dim=-1)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """Batch generation over variable-length prompts (padded with -1 on the right).

    The first forward pass processes tokens [0:min_prompt_len] at once (prefill
    phase). Subsequent passes generate one token at a time (decode phase). For
    positions still within a longer prompt, the ground-truth prompt token
    overrides the model's prediction.

    Args:
        model: Transformer with incremental decoding via a position offset
            argument to forward (assumed to cache past state — confirm in model.py).
        prompt_tokens: One token-id list per batch element.
        max_new_tokens: Cap on generated tokens per sequence.
        eos_id: Token id that terminates a completion.
        temperature: > 0 samples via Gumbel-max; <= 0 uses greedy argmax.

    Returns:
        One completion token list per prompt, truncated at (and including) eos_id.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    # -1 marks positions not covered by any prompt (i.e. to be generated).
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
    prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens))
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        # First iteration consumes the whole common prefix; later ones one token.
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        # Inside a prompt, keep the ground-truth token instead of the prediction.
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        # A sequence is finished only when eos is emitted outside its prompt.
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = cur_pos
        if finished.all():
            break
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        # Strip the prompt, then cut at the first eos (kept in the output).
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
            toks.append(eos_id)
        completion_tokens.append(toks)
    return completion_tokens
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """Load this rank's model shard and run interactive chat or batch completion.

    Args:
        ckpt_path: Directory holding model{rank}-mp{world_size}.safetensors and
            the tokenizer files.
        config: Path to a JSON file whose fields construct ModelArgs.
        input_file: File of prompts separated by blank lines (batch mode).
        interactive: If True, run a REPL; "/exit" quits, "/clear" resets history.
        max_new_tokens: Cap on generated tokens per completion.
        temperature: Sampling temperature passed through to generate().
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl")
    # Silence printing on non-zero ranks so output appears exactly once.
    global print
    if rank != 0:
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.cuda.memory._set_allocator_settings("expandable_segments:True")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(33377335)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    if interactive:
        args.max_batch_size = 1
    print(args)
    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print("load model")
    # Each rank loads only its own model-parallel shard.
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
    torch.set_default_device("cuda")
    print("I'm DeepSeek 👋")

    if interactive:
        messages = []
        while True:
            # Rank 0 reads stdin and broadcasts the prompt to the other ranks.
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            completion = tokenizer.decode(completion_tokens[0])
            print(completion)
            # Keep the parsed assistant reply so the next turn has full history.
            messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
    else:
        # Batch mode: prompts are blank-line-separated, each a single user turn.
        with open(input_file) as f:
            prompts = f.read().split("\n\n")
        prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    if world_size > 1:
        dist.destroy_process_group()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
    # CLI entry point; requires either --interactive or --input-file.
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.6)
    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
|
inference/kernel.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tilelang
|
| 3 |
+
import tilelang.language as T
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Suppress tilelang INFO-level output during JIT compilation.
tilelang.set_log_level("WARNING")

# Lowering passes disabled for every kernel in this module.
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}

# tilelang dtype name shorthands used throughout this module.
FP8 = "float8_e4m3"
FP4 = "float4_e2m1fn"
FE8M0 = "float8_e8m0fnu"
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def fast_log2_ceil(x):
    """ceil(log2(x)) read straight from the IEEE 754 bit pattern of x.

    Avoids slow log/ceil device intrinsics: the biased exponent gives
    floor(log2), and any nonzero mantissa bumps the result up by one.
    """
    raw = T.reinterpret("uint32", x)
    biased_exp = (raw >> 23) & 0xFF
    mantissa = raw & ((1 << 23) - 1)
    round_up = T.if_then_else(mantissa != 0, 1, 0)
    return T.Cast("int32", biased_exp - 127 + round_up)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def fast_pow2(x):
    """2**x for integer x, assembled directly as an IEEE 754 float bit pattern."""
    return T.reinterpret("float32", (x + 127) << 23)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fast_round_scale(amax, fp8_max_inv):
    """Round amax * fp8_max_inv up to the next power of two (MXFP-style scale)."""
    exponent = fast_log2_ceil(amax * fp8_max_inv)
    return fast_pow2(exponent)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
    N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
    round_scale=False, inplace=False
):
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.

    Each row is split into groups of `block_size` elements; every group gets a
    scale amax/448 (optionally rounded up to a power of two when round_scale).

    Args:
        N: Row width; must be divisible by block_size (checked by the caller).
        block_size: Elements per quantization group.
        in_dtype / out_dtype: tilelang dtype names for input and quantized output.
        scale_dtype: Storage dtype for the per-group scales.
        round_scale: Round scales up to a power of two (MXFP-style).
        inplace: Emit quant+dequant values in in_dtype instead of FP8 codes.
    """
    M = T.symbolic("M")
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale or inplace else 2
    blk_m = 32
    group_size = block_size
    # Internal computation in FP32; scale_dtype controls output storage format.
    compute_dtype = FP32
    out_dtype = in_dtype if inplace else out_dtype

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), compute_dtype)
            s_local = T.alloc_fragment((blk_m,), compute_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    # Floor the amax so all-zero groups don't divide by zero.
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                if inplace:
                    for i, j in T.Parallel(blk_m, group_size):
                        # Simulate quantization: round to FP8, then scale back
                        # to in_dtype. BUGFIX: the inner cast must target FP8
                        # explicitly — out_dtype equals in_dtype here, which
                        # previously skipped the FP8 rounding entirely (compare
                        # fp4_quant_kernel, which casts to the literal FP4).
                        y_local[i, j] = T.Cast(
                            out_dtype,
                            T.Cast(compute_dtype, T.Cast(FP8, T.clamp(
                                x_local[i, j] / s_local[i], fp8_min, fp8_max
                            ))) * s_local[i],
                        )
                else:
                    for i, j in T.Parallel(blk_m, group_size):
                        y_local[i, j] = T.clamp(
                            x_local[i, j] / s_local[i], fp8_min, fp8_max
                        )
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
    scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
    When scale_fmt is set, scales are rounded to power-of-2 (MXFP).

    Args:
        x: Input tensor; its last dimension must be divisible by block_size.
        block_size: Elements per quantization group along the last dimension.
        scale_fmt: Any non-None value enables power-of-2 scale rounding.
        scale_dtype: Storage dtype of the scale tensor (float32 or float8_e8m0fnu).
        inplace: Overwrite x with quant+dequant values instead of returning codes.

    Returns:
        x itself when inplace=True, otherwise a tuple (y, s) with y the
        float8_e4m3fn codes and s the per-group scales.
        (Fixed: the annotation previously claimed a bare Tensor in both cases.)
    """
    N = x.size(-1)
    assert N % block_size == 0
    tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    z = x.contiguous()
    y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
    s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
    kernel = act_quant_kernel(
        N, block_size, scale_dtype=tl_dtype,
        round_scale=scale_fmt is not None, inplace=inplace,
    )
    kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    if inplace:
        x.copy_(y)
        return x
    return y, s
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@tilelang.jit(pass_configs=pass_configs)
def fp4_quant_kernel(
    N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
):
    """Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant.

    Each row is split into groups of `block_size` elements; each group gets a
    power-of-two scale ceil-rounded from amax/6 (6.0 is the largest e2m1 value).
    """
    M = T.symbolic("M")
    fp4_max = 6.0
    fp4_max_inv = 1.0 / fp4_max
    blk_m = 32
    group_size = block_size
    # All scale math is done in FP32; scale_dtype only controls storage.
    compute_dtype = FP32
    out_dtype = in_dtype if inplace else FP4

    @T.prim_func
    def fp4_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), compute_dtype)
            s_local = T.alloc_fragment((blk_m,), compute_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=2):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    # Floor at 6 * 2^-126 so the rounded scale stays a normal float.
                    amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
                    s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
                if inplace:
                    for i, j in T.Parallel(blk_m, group_size):
                        # Round to FP4 and scale back up to in_dtype (quant+dequant).
                        y_local[i, j] = T.Cast(
                            out_dtype,
                            T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
                                x_local[i, j] / s_local[i], -fp4_max, fp4_max
                            ))) * s_local[i],
                        )
                else:
                    for i, j in T.Parallel(blk_m, group_size):
                        y_local[i, j] = T.clamp(
                            x_local[i, j] / s_local[i], -fp4_max, fp4_max
                        )
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return fp4_quant_kernel_
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def fp4_act_quant(
    x: torch.Tensor, block_size: int = 32, inplace: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
    """Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16.

    Args:
        x: Input tensor; its last dimension must be divisible by block_size.
        block_size: Elements per quantization group along the last dimension.
        inplace: Overwrite x with quant+dequant values instead of returning codes.

    Returns:
        x itself when inplace=True, otherwise a tuple (y, s) with y the packed
        float4_e2m1fn_x2 codes (last dim halved) and s the float8_e8m0fnu scales.
        (Fixed: the annotation previously claimed a bare Tensor in both cases.)
    """
    N = x.size(-1)
    assert N % block_size == 0
    z = x.contiguous()
    # Packed FP4 stores two codes per byte, so the output's last dim is N // 2.
    y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
    s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
    kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
    kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
    if inplace:
        x.copy_(y)
        return x
    return y, s
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    """FP8 x FP8 block-scaled GEMM: C[M, N] = A[M, K] @ B[N, K]^T.

    A carries one scale per (row, 128-wide K group); B carries one scale per
    (128-wide N group, 128-wide K group). M is symbolic, so a single compiled
    kernel serves any number of rows.
    """
    assert out_dtype in [BF16, FP32]

    M = T.symbolic("M")
    group_size = 128
    block_M = 32
    block_N = 128
    block_K = 128

    @T.prim_func
    def fp8_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP8],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
        scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
    ):
        # One CTA per (block_N x block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), FP32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
            T.clear(C_local)
            T.clear(C_local_accum)

            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Cast scales to FP32 for computation; scales_b has one value per block_N group
                # (block_N == group_size, so one B scale covers the whole tile's N extent).
                Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Separate accumulator for scale-corrected results (2x accumulation precision)
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                # Reset the raw accumulator so the next K step starts from zero.
                T.clear(C_local)
            # Stage through shared memory for a coalesced global store.
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp8_gemm_kernel_
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def fp8_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    K = a.size(-1)
    N = b.size(0)
    M = a.numel() // K
    # Map the torch scale dtype onto the TileLang dtype token.
    scale_tl = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    out = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    gemm = fp8_gemm_kernel(N, K, scale_dtype=scale_tl)
    gemm(a.view(M, K), b, out.view(M, N), a_s.view(M, -1), b_s)
    return out
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h: int, d: int, scale=None):
    """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
    For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
    with numerically stable running max/sum, and includes a learnable attn_sink bias."""
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")
    topk = T.symbolic("topk")
    if scale is None:
        scale = (1.0 / d) ** 0.5  # default 1/sqrt(d) softmax scaling

    num_stages = 2
    threads = 256
    block = 64  # KV positions processed per online-softmax step
    num_blocks = tilelang.cdiv(topk, block)

    @T.prim_func
    def sparse_attn_kernel_(
        q: T.Tensor[(b, m, h, d), BF16],
        kv: T.Tensor[(b, n, d), BF16],
        o: T.Tensor[(b, m, h, d), BF16],
        attn_sink: T.Tensor[(h,), FP32],
        topk_idxs: T.Tensor[(b, m, topk), INT32],
    ):
        # One CTA per (seq position, batch) pair; all h heads handled inside.
        with T.Kernel(m, b, threads=threads) as (bx, by):
            q_shared = T.alloc_shared((h, d), BF16)
            kv_shared = T.alloc_shared((block, d), BF16)
            o_shared = T.alloc_shared((h, d), BF16)
            acc_s_cast = T.alloc_shared((h, block), BF16)

            idxs = T.alloc_fragment(block, INT32)
            acc_s = T.alloc_fragment((h, block), FP32)      # per-step scores
            acc_o = T.alloc_fragment((h, d), FP32)          # running (unnormalized) output
            scores_max = T.alloc_fragment(h, FP32)          # running max per head
            scores_max_prev = T.alloc_fragment(h, FP32)
            scores_scale = T.alloc_fragment(h, FP32)        # rescale factor when max grows
            scores_sum = T.alloc_fragment(h, FP32)
            sum_exp = T.alloc_fragment(h, FP32)             # running softmax denominator

            T.clear(acc_o)
            T.clear(sum_exp)
            T.fill(scores_max, -T.infinity(FP32))
            T.copy(q[by, bx, :, :], q_shared)

            for t in T.Pipelined(num_blocks, num_stages=num_stages):
                # Gather this step's top-k indices; out-of-range slots become -1.
                for i in T.Parallel(block):
                    idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
                # Gather the selected KV rows (zeros for invalid slots).
                for i, j in T.Parallel(block, d):
                    kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
                # Pre-mask invalid columns with -inf so they vanish after exp().
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
                T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] *= scale
                # Online softmax: update running max, rescale old state, accumulate.
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(h):
                    scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(h):
                    sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)  # BF16 cast for the tensor-core GEMM
                for i, j in T.Parallel(h, d):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Fold the learnable sink logit into the denominator, then normalize.
            for i in T.Parallel(h):
                sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
            for i, j in T.Parallel(h, d):
                acc_o[i, j] /= sum_exp[i]
            T.copy(acc_o, o_shared)
            T.copy(o_shared, o[by, bx, :, :])

    return sparse_attn_kernel_
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def sparse_attn(
    q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
    """Sparse attention wrapper: pads head count up to 16 for the kernel,
    runs sparse_attn_kernel, then strips the padded heads from the output."""
    bsz, seqlen, n_heads, head_dim = q.size()
    padded = n_heads < 16
    if padded:
        # Kernel efficiency requires >= 16 heads; pad with zero heads/sinks.
        extra = 16 - n_heads
        q = torch.cat([q, q.new_zeros(bsz, seqlen, extra, head_dim)], dim=2)
        attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(extra)])
    out = torch.empty_like(q)
    kernel = sparse_attn_kernel(q.size(2), head_dim, softmax_scale)
    kernel(q, kv, out, attn_sink, topk_idxs)
    if padded:
        out = out.narrow(2, 0, n_heads).contiguous()
    return out
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
    """Splits a (2+hc)*hc mixing vector into `pre` gates, `post` gates, and an
    hc x hc combination matrix normalized toward doubly-stochastic form via
    Sinkhorn iterations (alternating row/column normalization).

    Layout of each row of `mixes`: [pre (hc) | post (hc) | comb (hc*hc)].
    """
    n = T.symbolic("n")
    mix_hc = (2 + hc) * hc
    threads = 64

    @T.prim_func
    def hc_split_sinkhorn_kernel_(
        mixes: T.Tensor[(n, mix_hc), FP32],
        hc_scale: T.Tensor[(3,), FP32],
        hc_base: T.Tensor[(mix_hc,), FP32],
        pre: T.Tensor[(n, hc), FP32],
        post: T.Tensor[(n, hc), FP32],
        comb: T.Tensor[(n, hc, hc), FP32],
    ):
        # One CTA per token (row of mixes).
        with T.Kernel(n, threads=threads) as i:
            mixes_shared = T.alloc_shared(mix_hc, FP32)
            comb_frag = T.alloc_fragment((hc, hc), FP32)
            T.copy(mixes[i, :], mixes_shared)

            # pre gate: sigmoid affine of the first hc entries, kept > 0 by eps.
            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            # post gate: 2*sigmoid of the next hc entries (range (0, 2)).
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
            # comb logits: affine of the remaining hc*hc entries.
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]

            row_sum = T.alloc_fragment(hc, FP32)
            col_sum = T.alloc_fragment(hc, FP32)

            # comb = comb.softmax(-1) + eps
            row_max = T.alloc_fragment(hc, FP32)
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps

            # comb = comb / (comb.sum(-2) + eps)
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            # Remaining Sinkhorn iterations: the softmax+column step above
            # counts as iteration 1, hence sinkhorn_iters - 1 here.
            for _ in T.serial(sinkhorn_iters - 1):
                # comb = comb / (comb.sum(-1) + eps)
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            T.copy(comb_frag, comb[i, :, :])

    return hc_split_sinkhorn_kernel_
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
    """Host wrapper around hc_split_sinkhorn_kernel.

    Splits (b, s, (2+hc)*hc) mixing activations into per-token `pre` and
    `post` gates plus an hc x hc Sinkhorn-normalized combination matrix.
    """
    bsz, seqlen, _ = mixes.size()
    mix_width = (2 + hc_mult) * hc_mult
    pre = mixes.new_empty(bsz, seqlen, hc_mult)
    post = mixes.new_empty(bsz, seqlen, hc_mult)
    comb = mixes.new_empty(bsz, seqlen, hc_mult, hc_mult)
    kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
    kernel(
        mixes.view(-1, mix_width),
        hc_scale,
        hc_base,
        pre.view(-1, hc_mult),
        post.view(-1, hc_mult),
        comb.view(-1, hc_mult, hc_mult),
    )
    return pre, post, comb
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@tilelang.jit(pass_configs=pass_configs)
def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    """FP8 act x FP4 weight GEMM kernel.

    C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T

    Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
    Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale

    B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
    The FP4 values are packed along the K (last) dimension.

    Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
    cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
    Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
    """
    M = T.symbolic("M")  # symbolic rows: one compiled kernel per (N, K, dtypes)
    act_group_size = 128
    weight_group_size = 32
    block_M = 32
    block_N = 128
    block_K = 32  # matches weight_group_size for simple scale handling
    n_sub = act_group_size // block_K  # 4 sub-blocks per act scale group

    @T.prim_func
    def fp4_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP4],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
        scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
    ):
        # One CTA per (block_N x block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
            scale_a_frag = T.alloc_fragment((block_M,), FP32)
            scale_b_frag = T.alloc_fragment((block_N,), FP32)

            T.use_swizzle(panel_size=10)  # improve L2 reuse across CTAs
            T.clear(C_local)
            T.clear(C_local_accum)

            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
                # FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
                for i, j in T.Parallel(block_N, block_K):
                    B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))

                # Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
                for i in T.Parallel(block_N):
                    scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])

                # Act scale: per 128 on K, indexed by k // 4
                for i in T.Parallel(block_M):
                    scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

                # Fold both scales into a separate accumulator, then reset the
                # raw accumulator for the next K step.
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
                T.clear(C_local)

            # Stage through shared memory for a coalesced global store.
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp4_gemm_kernel_
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def fp4_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
    A has per-128 act scale; B has per-32 E8M0 weight scale.
    B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    K = a.size(-1)
    N = b.size(0)
    M = a.numel() // K
    # Map the torch scale dtype onto the TileLang dtype token.
    scale_tl = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    out = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    gemm = fp4_gemm_kernel(N, K, scale_dtype=scale_tl)
    gemm(a.view(M, K), b, out.view(M, N), a_s.view(M, -1), b_s)
    return out
|
inference/model.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Optional, Literal
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
world_size = 1
|
| 16 |
+
rank = 0
|
| 17 |
+
block_size = 128
|
| 18 |
+
fp4_block_size = 32
|
| 19 |
+
default_dtype = torch.bfloat16
|
| 20 |
+
scale_fmt = None
|
| 21 |
+
scale_dtype = torch.float32
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@contextmanager
def set_dtype(dtype):
    """Temporarily switch the torch default dtype; the previous value is
    restored on exit even when the body raises."""
    saved = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved)
|
| 33 |
+
|
| 34 |
+
@dataclass
class ModelArgs:
    """Model hyperparameters. Field names match the config JSON keys."""
    max_batch_size: int = 4
    max_seq_len: int = 4096
    # Weight storage format for dense layers; "fp8" enables block-scaled FP8 GEMMs.
    dtype: Literal["bf16", "fp8"] = "fp8"
    # Activation-scale encoding; "ue8m0" presumably selects E8M0 scales — confirm in kernel.py.
    scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
    # Optional FP4 storage for expert weights (None = same as dtype).
    expert_dtype: Literal[None, "fp4"] = None
    scale_dtype: Literal["fp32", "fp8"] = "fp8"
    vocab_size: int = 129280
    dim: int = 4096
    moe_inter_dim: int = 4096
    n_layers: int = 7
    n_hash_layers: int = 0
    n_mtp_layers: int = 1
    n_heads: int = 64
    # moe
    n_routed_experts: int = 8
    n_shared_experts: int = 1
    n_activated_experts: int = 2
    score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
    route_scale: float = 1.
    swiglu_limit: float = 0.
    # mqa
    q_lora_rank: int = 1024
    head_dim: int = 512
    rope_head_dim: int = 64
    norm_eps: float = 1e-6
    o_groups: int = 8
    o_lora_rank: int = 1024
    window_size: int = 128
    # Per-layer KV compression ratios (0 = no compression for that layer).
    compress_ratios: Tuple[int, ...] = (0, 0, 4, 128, 4, 128, 4, 0)
    # yarn
    compress_rope_theta: float = 40000.0
    # original_seq_len == 0 disables YaRN frequency interpolation (see precompute_freqs_cis).
    original_seq_len: int = 0
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 512
    # hc
    hc_mult: int = 4
    hc_sinkhorn_iters: int = 20
    hc_eps: float = 1e-6
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ParallelEmbedding(nn.Module):
    """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
    Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        # This rank owns token ids [vocab_start_idx, vocab_end_idx).
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            # Remap ids to local rows; foreign ids are clamped to row 0 so the
            # lookup stays in-bounds (x - start makes a copy, input is untouched).
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            # Zero the rows fetched for foreign ids, then sum shards across ranks.
            y[mask] = 0
            dist.all_reduce(y)
        return y
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def linear(x: torch.Tensor, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Weight-dtype dispatch: FP4 and FP8 weights go through act_quant + the
    matching block-scaled GEMM; anything else falls back to F.linear.
    Bias is unsupported here (always None)."""
    assert bias is None

    wdtype = weight.dtype
    if wdtype == torch.float4_e2m1fn_x2:
        xq, xs = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp4_gemm(xq, xs, weight, weight.scale, scale_dtype)
    if wdtype == torch.float8_e4m3fn:
        xq, xs = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp8_gemm(xq, xs, weight, weight.scale, scale_dtype)
    return F.linear(x, weight)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Linear(nn.Module):
    """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling.

    For quantized dtypes the scale tensor is registered as the parameter
    ``self.scale`` (so checkpoints load it) and also aliased onto the weight as
    ``weight.scale`` so the module-free ``linear()`` dispatcher can reach it.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Fall back to the module-level default weight dtype when unspecified.
        dtype = dtype or default_dtype
        if dtype == torch.float4_e2m1fn_x2:
            # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
            # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
            self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
            scale_out_features = out_features
            scale_in_features = in_features // fp4_block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        elif dtype == torch.float8_e4m3fn:
            # FP8: one scale per 128x128 weight block (ceil-divided in both dims).
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        else:
            # Unquantized (e.g. BF16) path: no scale parameter.
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class ColumnParallelLinear(Linear):
    """Shards the output dim across TP ranks; each rank computes its own
    out_features // world_size slice. No all-reduce is needed on the output —
    downstream consumers use the shard directly (or a RowParallelLinear
    reduces later).

    The inherited Linear.forward is used as-is; the previous byte-identical
    forward() override was redundant and has been removed.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class RowParallelLinear(Linear):
    """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight, None)
        if world_size > 1:
            # Reduce partial sums in fp32 to limit low-precision accumulation error.
            y = y.float()
            dist.all_reduce(y)
        # Bias is added once, after the reduction (each rank holds the full bias).
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    Computation runs in fp32 and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # The checkpoint stores this in bf16; it is held in fp32 here for convenience.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        xf = x.float()
        inv_rms = torch.rsqrt(xf.square().mean(-1, keepdim=True) + self.eps)
        normed = xf * inv_rms
        return (self.weight * normed).to(in_dtype)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
    """Build complex rotary-embedding factors of shape (seqlen, dim // 2).

    When original_seq_len > 0, YaRN-style scaling is applied: low-frequency
    components are interpolated by `factor`, with a smooth linear ramp between
    the beta_fast / beta_slow correction dimensions. Results are memoized.
    """

    def correction_dim(num_rotations, d, b, max_len):
        # Dimension index at which `num_rotations` full rotations occur at max_len.
        return d * math.log(max_len / (num_rotations * 2 * math.pi)) / (2 * math.log(b))

    def correction_range(low_rot, high_rot, d, b, max_len):
        lo = math.floor(correction_dim(low_rot, d, b, max_len))
        hi = math.ceil(correction_dim(high_rot, d, b, max_len))
        return max(lo, 0), min(hi, d - 1)

    def ramp(lo, hi, d):
        # Linear ramp from 0 (at lo) to 1 (at hi), clamped to [0, 1].
        if lo == hi:
            hi += 0.001
        return torch.clamp((torch.arange(d, dtype=torch.float32) - lo) / (hi - lo), 0, 1)

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if original_seq_len > 0:
        lo, hi = correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
        smooth = 1 - ramp(lo, hi, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    positions = torch.arange(seqlen)
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
    """Rotate adjacent pairs of x by freqs_cis, writing back into x in-place.

    `inverse=True` uses the complex conjugate, undoing a prior rotation.
    Returns x itself.
    """
    out = x
    pairs = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
    rot = freqs_cis.conj() if inverse else freqs_cis
    if pairs.ndim == 3:
        rot = rot.view(1, pairs.size(1), pairs.size(-1))
    else:
        # 4-D case: broadcast over the heads dimension.
        rot = rot.view(1, pairs.size(1), 1, pairs.size(-1))
    rotated = torch.view_as_real(pairs * rot).flatten(-2)
    out.copy_(rotated)
    return out
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
    # The kernel is only invoked on bf16 activations at this call site.
    assert x.dtype == torch.bfloat16
    # Imported lazily so the module can load without the optional
    # fast_hadamard_transform dependency when rotation is never used.
    from fast_hadamard_transform import hadamard_transform
    # scale = 1/sqrt(d) keeps the transform orthonormal (norm-preserving).
    return hadamard_transform(x, scale=x.size(-1) ** -0.5)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@lru_cache(1)
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
    """Build indices into the sliding-window KV-cache slots for each query token.

    Returns an int64 tensor of shape [bsz, seqlen, window_size] during prefill
    (start_pos == 0) or [bsz, 1, window_size] during decode; -1 marks slots that
    hold no valid entry yet. Memoizes the most recent call.
    """
    if start_pos >= window_size - 1:
        # Steady-state decode: the circular buffer is full. Order slots from
        # oldest to newest around the write pivot.
        pivot = start_pos % window_size
        idxs = (torch.arange(window_size) + pivot + 1) % window_size
    elif start_pos > 0:
        # Warm-up decode: only slots 0..start_pos are populated; pad with -1.
        idxs = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
    else:
        # Prefill: row t references the last `window_size` positions <= t;
        # future positions are masked out with -1.
        pos = torch.arange(seqlen).unsqueeze(1)
        idxs = (pos - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
        idxs = idxs.masked_fill(idxs > pos, -1)
    return idxs.unsqueeze(0).expand(bsz, -1, -1)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@lru_cache(2)
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
    """Indices of all causally visible compressed-KV slots, shifted by `offset`
    (the compressed region's start inside the combined cache); -1 = masked.

    Shape is [bsz, seqlen, seqlen // ratio] during prefill and
    [bsz, 1, (start_pos + 1) // ratio] during decode.
    """
    if start_pos > 0:
        # Decode: every fully formed compressed block so far is visible.
        idxs = torch.arange(0, (start_pos + 1) // ratio) + offset
    else:
        # Prefill: token t sees block j only once block j is complete,
        # i.e. j < (t + 1) // ratio; later blocks are masked with -1.
        idxs = torch.arange(seqlen // ratio).repeat(seqlen, 1)
        visible = torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
        idxs = torch.where(idxs >= visible, -1, idxs + offset)
    return idxs.unsqueeze(0).expand(bsz, -1, -1)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class Compressor(nn.Module):
    """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
    When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""

    def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
        super().__init__()
        self.dim = args.dim
        self.head_dim = head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = head_dim - args.rope_head_dim
        self.compress_ratio = compress_ratio
        # Overlapping-window variant is enabled only for ratio 4.
        self.overlap = compress_ratio == 4
        # When True, outputs get a Hadamard rotation + FP4 quant (used by Indexer).
        self.rotate = rotate
        # coff = 2 with overlap (two windows worth of channels), else 1.
        coff = 1 + self.overlap

        # Learned additive position embedding over the `compress_ratio` slots of a block.
        self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate in the checkpoint are stored in bf16, while the parameters here
        # are stored in fp32 for convenience (compression math runs in fp32 below).
        # When overlap, the first half of dims is for overlapping compression, second half for normal.
        self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.norm = RMSNorm(self.head_dim, args.norm_eps)
        self.kv_cache: torch.Tensor = None  # assigned lazily from Attention.kv_cache
        # State buffers for decode-phase incremental compression.
        # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
        # score_state starts at -inf so empty slots contribute zero after softmax.
        self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
        self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
        self.freqs_cis: torch.Tensor = None  # assigned lazily alongside kv_cache

    def overlap_transform(self, tensor: torch.Tensor, value=0):
        # tensor: [b,s,r,2d] -> [b,s,2r,d]: for each block, stack the previous
        # block's overlap-half channels (first d dims) on top of the current
        # block's normal-half channels (last d dims). `value` fills the missing
        # predecessor of the first block (0 for kv, -inf for scores).
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor

    def forward(self, x: torch.Tensor, start_pos: int):
        # Returns the newly compressed kv blocks (also written to kv_cache),
        # or None when no block completes at this step.
        assert self.kv_cache is not None
        bsz, seqlen, _ = x.size()
        ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
        dtype = x.dtype
        # compression needs fp32
        x = x.float()
        kv = self.wkv(x)
        score = self.wgate(x)
        if start_pos == 0:
            # Prefill: compress all complete blocks at once; stash the tail
            # (last full block for overlap, plus any partial block) in the state buffers.
            should_compress = seqlen >= ratio
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            offset = ratio if overlap else 0
            if overlap and cutoff >= ratio:
                # Seed the overlap half of the state with the last complete block.
                self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
                self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
            if remainder > 0:
                # Park the partial trailing block in the current-window state.
                kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
                self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
                score = score[:, :cutoff]
            # Group tokens into blocks of `ratio` and pool with softmax-gated weights.
            kv = kv.unflatten(1, (-1, ratio))
            score = score.unflatten(1, (-1, ratio)) + self.ape
            if overlap:
                kv = self.overlap_transform(kv, 0)
                score = self.overlap_transform(score, float("-inf"))
            kv = (kv * score.softmax(dim=2)).sum(dim=2)
        else:
            # Decode: accumulate one token per call; emit a block every `ratio` steps.
            should_compress = (start_pos + 1) % self.compress_ratio == 0
            score += self.ape[start_pos % ratio]
            if overlap:
                self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    # Pool over previous window's overlap channels + current window's
                    # normal channels, then shift current -> overlap for the next block.
                    kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
                    score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
                    kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
                    self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
                    self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
            else:
                self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
        if not should_compress:
            return
        kv = self.norm(kv.to(dtype))
        # Rotary position for a compressed block is the position of its first token.
        if start_pos == 0:
            freqs_cis = self.freqs_cis[:cutoff:ratio]
        else:
            freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        if self.rotate:
            # Indexer path: spread information with a Hadamard rotation, then FP4-simulate.
            kv = rotate_activation(kv)
            fp4_act_quant(kv, fp4_block_size, True)
        else:
            # Attention path: FP8-simulate the non-rope dims only (rope dims keep precision).
            act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        if start_pos == 0:
            self.kv_cache[:bsz, :seqlen // ratio] = kv
        else:
            self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
        return kv
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Indexer(torch.nn.Module):
    """Selects top-k compressed KV positions for sparse attention via learned scoring.
    Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""

    def __init__(self, args: ModelArgs, compress_ratio: int = 4):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.index_n_heads
        # Index heads are sharded across tensor-parallel ranks.
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim = args.index_head_dim
        self.rope_head_dim = args.rope_head_dim
        self.index_topk = args.index_topk
        self.q_lora_rank = args.q_lora_rank
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        # Per-head scalar weights for combining head scores; computed in bf16.
        self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
        self.softmax_scale = self.head_dim ** -0.5
        self.compress_ratio = compress_ratio

        # rotate=True: the indexer's compressed KV goes through Hadamard + FP4 quant.
        self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
        self.freqs_cis = None  # assigned by the owning Attention on first forward

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
        # x: hidden states [b,s,dim]; qr: shared low-rank query [b,s,q_lora_rank];
        # offset: where the compressed region starts inside the combined KV cache.
        # Returns int64 indices of selected compressed slots (+offset), -1 = masked.
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        end_pos = start_pos + seqlen
        if self.compressor.kv_cache is None:
            # Lazy wiring: share this module's cache/freqs with its Compressor.
            self.compressor.kv_cache = self.kv_cache
            self.compressor.freqs_cis = self.freqs_cis
        q = self.wq_b(qr)
        q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        q = rotate_activation(q)
        # use fp4 simulation for q and kv in indexer
        fp4_act_quant(q, fp4_block_size, True)
        # Update the indexer's compressed KV cache with the new tokens.
        self.compressor(x, start_pos)
        weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
        # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
        index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
        # ReLU per head, then combine heads with the learned per-head weights.
        index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
        if world_size > 1:
            # Sum partial head contributions across TP ranks.
            dist.all_reduce(index_score)
        if start_pos == 0:
            # Causal mask during prefill: block j visible to token t only if j < (t+1)//ratio.
            mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            index_score += torch.where(mask, float("-inf"), 0)
        topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
        if start_pos == 0:
            # Re-mask selections that fell on -inf (future) blocks, then shift by offset.
            mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            topk_idxs = torch.where(mask, -1, topk_idxs + offset)
        else:
            topk_idxs += offset
        return topk_idxs
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class Attention(nn.Module):
    """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
    Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.o_lora_rank = args.o_lora_rank
        self.head_dim = args.head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = args.head_dim - args.rope_head_dim
        self.n_groups = args.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = args.window_size
        # 0/None disables compression for this layer (pure sliding-window attention).
        self.compress_ratio = args.compress_ratios[layer_id]
        self.eps = args.norm_eps

        # Learned per-head attention-sink logit.
        self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wkv = Linear(self.dim, self.head_dim)
        self.kv_norm = RMSNorm(self.head_dim, self.eps)
        # Grouped low-rank output projection: heads -> groups -> o_lora_rank -> dim.
        self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
        self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
        self.softmax_scale = self.head_dim ** -0.5

        if self.compress_ratio:
            self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
            # Learned indexer only for ratio 4; otherwise deterministic block indices.
            if self.compress_ratio == 4:
                self.indexer = Indexer(args, self.compress_ratio)
            else:
                self.indexer = None

        # Combined cache: [0, window_size) = sliding window, [window_size, ...) = compressed.
        kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
        if self.compress_ratio:
            original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
        else:
            # disable YaRN and use base rope_theta in pure sliding-window attention
            original_seq_len, rope_theta = 0, args.rope_theta
        freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
                                         rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int):
        # x: [b,s,dim]; start_pos == 0 means prefill, otherwise single-token decode.
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        win = self.window_size
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        if self.compress_ratio and self.compressor.kv_cache is None:
            # Lazy wiring on first call: compressor writes into the tail of our cache.
            self.compressor.kv_cache = self.kv_cache[:, win:]
            self.compressor.freqs_cis = self.freqs_cis
            if self.indexer is not None:
                self.indexer.freqs_cis = self.freqs_cis
        # q
        qr = q = self.q_norm(self.wq_a(x))  # qr is reused by the indexer
        q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
        # Per-head RMS normalization of q (no learned scale).
        q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
        apply_rotary_emb(q[..., -rd:], freqs_cis)

        # win kv & topk_idxs
        kv = self.wkv(x)
        kv = self.kv_norm(kv)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
        act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
        if self.compress_ratio:
            # Offset of the compressed region: right after the fresh kv during prefill,
            # after the window slots during decode.
            offset = kv.size(1) if start_pos == 0 else win
            if self.indexer is not None:
                compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
            else:
                compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
            topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
        topk_idxs = topk_idxs.int()

        # compress kv & attn
        if start_pos == 0:
            # Write the last `win` tokens into the circular window cache.
            if seqlen <= win:
                self.kv_cache[:bsz, :seqlen] = kv
            else:
                cutoff = seqlen % win
                self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
            if self.compress_ratio:
                if (kv_compress := self.compressor(x, start_pos)) is not None:
                    kv = torch.cat([kv, kv_compress], dim=1)
            # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
            o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
        else:
            self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
            if self.compress_ratio:
                # May complete a compressed block; it is written directly into kv_cache.
                self.compressor(x, start_pos)
            o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
        # De-rotate the rope dims of the attention output (inverse rotary).
        apply_rotary_emb(o[..., -rd:], freqs_cis, True)

        # o
        o = o.view(bsz, seqlen, self.n_local_groups, -1)
        wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
        # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
        # but using BF16 for simplicity.
        o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
        x = self.wo_b(o.flatten(2))
        return x
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class Gate(nn.Module):
    """MoE gating: computes expert routing scores and selects top-k experts.
    Supports hash-based routing (first n_hash_layers) where expert indices are
    predetermined per token ID, and score-based routing (remaining layers)."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        # Early layers route by token id; later layers route by learned scores.
        self.hash = layer_id < args.n_hash_layers
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        if self.hash:
            # Frozen token-id -> expert-id lookup table; no selection bias needed.
            # NOTE(review): tid2eid is int32 while torch.gather below expects int64
            # indices — confirm the hash path converts or is exercised elsewhere.
            self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))

    def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (routing weights, expert indices) for each token in x."""
        raw = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            raw = raw.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            raw = raw.sigmoid()
        else:
            raw = F.softplus(raw).sqrt()
        # Bias shifts scores for expert selection (topk) but does not affect routing weights.
        selection = raw if self.bias is None else raw + self.bias
        if self.hash:
            indices = self.tid2eid[input_ids]
        else:
            indices = selection.topk(self.topk, dim=-1)[1]
        weights = raw.gather(1, indices)
        if self.score_func != "softmax":
            # Non-softmax scores are not normalized; normalize over the selected experts.
            weights = weights / weights.sum(dim=-1, keepdim=True)
        weights = weights * self.route_scale
        return weights, indices
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class Expert(nn.Module):
    """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""

    def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
        super().__init__()
        # w1: gate projection, w3: up projection, w2: down projection.
        self.w1 = Linear(dim, inter_dim, dtype=dtype)
        self.w2 = Linear(inter_dim, dim, dtype=dtype)
        self.w3 = Linear(dim, inter_dim, dtype=dtype)
        # 0 disables clamping; > 0 bounds the pre-activation values.
        self.swiglu_limit = swiglu_limit

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """SwiGLU forward; `weights` (if given) scales the hidden activations."""
        in_dtype = x.dtype
        gate = self.w1(x).float()
        up = self.w3(x).float()
        if self.swiglu_limit > 0:
            # Symmetric clamp on the up path, one-sided on the gate path.
            up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
            gate = gate.clamp(max=self.swiglu_limit)
        hidden = F.silu(gate) * up
        if weights is not None:
            hidden = weights * hidden
        return self.w2(hidden.to(in_dtype))
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
class MoE(nn.Module):
    """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
    Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        # This rank owns experts [experts_start_idx, experts_end_idx).
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(layer_id, args)
        expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
        # Non-local expert slots are None placeholders so global indices still line up.
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        assert args.n_shared_experts == 1
        # no swiglu_limit for the shared expert (matches checkpoint training setup)
        self.shared_experts = Expert(args.dim, args.moe_inter_dim)

    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # input_ids are needed by hash-routing gates; flattened to match tokens.
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x, input_ids.flatten())
        # Accumulate expert outputs in fp32 for numerical stability.
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue  # no token on this rank routed to expert i
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            # Apply the routing weight inside the expert (scales hidden activations).
            y[idx] += expert(x[idx], weights[idx, top, None])
        if world_size > 1:
            # Sum partial expert outputs across TP ranks.
            dist.all_reduce(y)
        y += self.shared_experts(x)
        return y.type_as(x).view(shape)
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
class Block(nn.Module):
    """Transformer block with Hyper-Connections (HC) mixing.
    Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
    hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
    hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.norm_eps = args.norm_eps
        self.attn = Attention(layer_id, args)
        self.ffn = MoE(layer_id, args)
        self.attn_norm = RMSNorm(args.dim, self.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
        self.hc_eps = args.hc_eps
        # mix_hc rows = hc pre-weights + hc*hc combination matrix + hc post-weights.
        mix_hc = (2 + hc_mult) * hc_mult
        hc_dim = hc_mult * args.dim
        # HC mixing parameters are kept in fp32 (mixing math runs in fp32).
        with set_dtype(torch.float32):
            self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_attn_scale = nn.Parameter(torch.empty(3))
            self.hc_ffn_scale = nn.Parameter(torch.empty(3))

    def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
        # Returns (reduced hidden state [b,s,d], post-weights, combination matrix).
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        # RMS-normalized dynamic mixing coefficients.
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        # Sinkhorn-normalized split into pre/post/combination weights (project helper).
        pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
        # Weighted sum collapses the hc copies into a single stream.
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype), post, comb

    def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
        # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
        # Broadcast the sublayer output back to hc copies and mix in the residual copies.
        y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
        return y.type_as(x)

    def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
        # Attention sublayer with HC reduce/expand around it.
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        x = self.attn_norm(x)
        x = self.attn(x, start_pos)
        x = self.hc_post(x, residual, post, comb)

        # FFN (MoE) sublayer with its own HC parameters.
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.ffn_norm(x)
        x = self.ffn(x, input_ids)
        x = self.hc_post(x, residual, post, comb)
        return x
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
class ParallelHead(nn.Module):
    """Vocab-parallel LM head: HC-reduce the hidden copies, normalize, project to
    this rank's vocab shard, then all-gather the shards into full logits."""

    def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.norm_eps = norm_eps
        self.hc_eps = hc_eps
        # Each rank holds an equal slice of the vocabulary.
        self.part_vocab_size = (vocab_size // world_size)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))

    def get_logits(self, x):
        # Only the last position's logits are needed for generation.
        return F.linear(x[:, -1].float(), self.weight)

    def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
        # x: [b,s,hc,d]
        reduced = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(reduced))
        if world_size > 1:
            # Reassemble the full vocabulary from per-rank shards.
            shards = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(shards, logits)
            logits = torch.cat(shards, dim=-1)
        return logits

    def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        # Collapse the hc copies into one stream using sigmoid gates
        # (no Sinkhorn at the head, unlike Block.hc_pre).
        shape, dtype = x.size(), x.dtype
        flat = x.flatten(2).float()
        inv_rms = torch.rsqrt(flat.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(flat, hc_fn) * inv_rms
        gates = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
        out = torch.sum(gates.unsqueeze(-1) * flat.view(shape), dim=2)
        return out.to(dtype)
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class MTPBlock(Block):
    """Multi-token-prediction block: fuses the previous hidden state with the
    embedding of the next input token, runs one Block, and emits logits through
    the shared head. `embed` and `head` are wired in by Transformer.__init__."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)
        # Projections for fusing token embeddings (e) with hidden states (h).
        self.e_proj = Linear(args.dim, args.dim)
        self.h_proj = Linear(args.dim, args.dim)
        self.enorm = RMSNorm(args.dim, args.norm_eps)
        self.hnorm = RMSNorm(args.dim, args.norm_eps)
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        # This block has its own HC head-reduction parameters (fp32).
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
        # Shared with the main Transformer; assigned externally after construction.
        self.embed: ParallelEmbedding = None
        self.head: ParallelHead = None

    @torch.inference_mode()
    def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
        # x: [b,s,hc,d]
        assert self.embed is not None and self.head is not None
        e = self.embed(input_ids)
        e = self.enorm(e)
        x = self.hnorm(x)
        # Broadcast the token embedding across the hc copies, then fuse.
        x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
        x = super().forward(x, start_pos, input_ids)
        logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
class Transformer(nn.Module):
    """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
    Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""

    def __init__(self, args: ModelArgs):
        # Module-level globals are configured here so every submodule constructed
        # below sees consistent parallelism and quantization settings.
        global world_size, rank, default_dtype, scale_fmt, scale_dtype
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
        scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.norm_eps = args.norm_eps
        self.hc_eps = args.hc_eps
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim, self.norm_eps)
        self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
        self.mtp = torch.nn.ModuleList()
        for layer_id in range(args.n_mtp_layers):
            self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
            # MTP blocks share the main embedding table and output head.
            self.mtp[-1].embed = self.embed
            self.mtp[-1].head = self.head
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        # Final HC head-reduction parameters (fp32), passed to ParallelHead.forward.
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))

    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
        # input_ids: [b,s]; returns last-position logits [b, vocab_size].
        h = self.embed(input_ids)
        # Expand to hc_mult copies for Hyper-Connections
        h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
        for layer in self.layers:
            h = layer(h, start_pos, input_ids)
        logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
if __name__ == "__main__":
    # Smoke test: build a tiny random model on one GPU and check output shapes.
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)
    # n_hash_layers=0 disables hash routing so all Gates use score-based topk.
    args = ModelArgs(n_hash_layers=0)
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)

    # Prefill of 128 tokens, then 22 single-token decode steps.
    print(model(x).size())
    for i in range(128, 150):
        print(i, model(x[:, 0:1], i).size())

    # Exercise the multi-token-prediction block on random hidden states.
    h = torch.randn(2, 128, args.hc_mult, args.dim)
    mtp = model.mtp[0]
    print(mtp(h, 0, x).size())
    print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
|
inference/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.10.0
|
| 2 |
+
transformers>=5.0.0
|
| 3 |
+
safetensors>=0.7.0
|
| 4 |
+
fast_hadamard_transform
|
| 5 |
+
tilelang==0.1.8
|