"""Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4
into the MP-sharded local format consumed by `model.py`/`generate.py`.
Packing convention (auto-round → auto_gptq):
- qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0
- qzeros : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1
- scales : fp16 [in_features // group_size, out_features]
Sharding rules per linear:
- ColumnParallel (shard the output dim; `dim=0` in `mapping`):
  qweight, qzeros and scales are all sliced along dim 1; out_features must be
  divisible by world_size, and each shard's out_features by 8 (qzeros packing).
- RowParallel (shard the input dim; `dim=1` in `mapping`):
  qweight is sliced along dim 0 (in_features // 8 must be divisible by world_size);
  qzeros and scales are sliced along dim 0 (in_features // group_size must be
  divisible by world_size).
Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias,
gate.tid2eid, etc.) follow the same rules as the original `convert.py`.
"""
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
GROUP_SIZE = 128
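
# NOTE: reference-only dequantisation sketch for the packing described in the module
# docstring. It is NOT used by the conversion itself; it assumes the LSB-first nibble
# layout and no +1 zero-point offset (some auto_gptq variants store zeros as z-1),
# so treat it as a sanity-check aid rather than the exact kernel math.
def dequant_reference(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
                      group_size: int = GROUP_SIZE) -> torch.Tensor:
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    # qweight: [in/8, out] -> [in/8, 8, out] -> [in, out]; nibble k is input row i*8+k.
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    w = w.reshape(-1, qweight.size(1))
    # qzeros: [groups, out/8] -> [groups, out/8, 8] -> [groups, out]; nibble k is output col j*8+k.
    z = (qzeros.unsqueeze(-1) >> shifts.view(1, 1, -1)) & 0xF
    z = z.reshape(qzeros.size(0), -1)
    g = torch.arange(w.size(0)) // group_size  # group index of each input row
    return (w.float() - z.float()[g]) * scales.float()[g]
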
# Same name remapping as the original convert.py
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    # Already-translated names (used by the inference checkpoints we already have)
    "embed": ("embed", 0),
    "wq_a": ("wq_a", None),
    "wq_b": ("wq_b", 0),
    "wkv": ("wkv", None),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "w1": ("w1", 0),
    "w2": ("w2", 1),
    "w3": ("w3", 0),
    "head": ("head", 0),
    "weights_proj": ("weights_proj", 0),
    # special non-weight keys
    "attn_sink": ("attn_sink", 0),
    "ape": ("ape", None),
    # NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a
    # plain nn.Parameter that is replicated on every rank.
}
# Suffixes that mark the three pieces of a packed W4A16 linear.
QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales")

def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
                shard_dim: int, mp: int):
    """Yield (qweight_i, qzeros_i, scales_i) for i in range(mp).

    shard_dim is the *logical* dim of the dequantised weight:
    0 == output (column parallel), 1 == input (row parallel).
    """
    out = qweight.size(1)
    in_packed = qweight.size(0)  # in_features // 8
    n_groups = scales.size(0)    # in_features // group_size
    if shard_dim == 0:  # ColumnParallel: shard along OUTPUT
        assert out % mp == 0, f"out={out} not divisible by mp={mp}"
        # qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0
        assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)"
        sh_out = out // mp
        sh_qz_cols = qzeros.size(1) // mp  # == out / 8 / mp
        for i in range(mp):
            yield (
                qweight.narrow(1, i * sh_out, sh_out).contiguous(),
                qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(),
                scales.narrow(1, i * sh_out, sh_out).contiguous(),
            )
    elif shard_dim == 1:  # RowParallel: shard along INPUT
        # qweight packs 8 inputs per int32 in dim 0; scales/qzeros are per-group on dim 0
        assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}"
        assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}"
        sh_in_packed = in_packed // mp
        sh_groups = n_groups // mp
        for i in range(mp):
            yield (
                qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(),
                qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(),
                scales.narrow(0, i * sh_groups, sh_groups).contiguous(),
            )
    else:
        # Replicate: every rank gets the full tensors.
        for _ in range(mp):
            yield qweight, qzeros, scales
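
# Shape example (hypothetical sizes): a ColumnParallel linear with in_features=7168,
# out_features=4096, group_size=128 arrives as qweight [896, 4096], qzeros [56, 512],
# scales [56, 4096]; with mp=4 each rank receives qweight [896, 1024], qzeros [56, 128],
# scales [56, 1024]. Sharded RowParallel instead, each rank would get qweight [224, 4096],
# qzeros [14, 512], scales [14, 4096].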

def get_layer_key(name: str):
    """Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping."""
    parts = name.split(".")
    if name.endswith(QUANT_SUFFIXES):
        return parts[-2]  # ...x.qweight -> x
    if name.endswith(".bias") and "gate" in name:
        return "gate"  # ffn.gate.bias
    if name.endswith(".tid2eid"):
        return "gate"
    if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale",
                                "hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale",
                                "hc_head_fn", "hc_head_base", "hc_head_scale",
                                "attn_sink", "ape")):
        return parts[-1]
    return parts[-2]
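
# Illustrative inputs (names are examples only):
#   "layers.3.attn.q_b_proj.qweight" -> "q_b_proj"   (quantised linear fragment)
#   "layers.0.ffn.gate.bias"         -> "gate"       (replicated routing gate)
#   "layers.0.ffn.gate.tid2eid"      -> "gate"
#   "layers.5.attn_sink"             -> "attn_sink"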

def main(hf_ckpt_path, save_path, n_experts, mp):
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]
    # Group all fragments belonging to the same logical linear so we can shard
    # qweight/qzeros/scales together.
    pending: dict[str, dict[str, torch.Tensor]] = {}

    def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim):
        """Distribute a quantised linear (3 tensors) across `mp` shards."""
        qweight = parts["qweight"]
        qzeros = parts["qzeros"]
        scales = parts["scales"].to(torch.bfloat16)  # store bf16 instead of fp16
        # Expert-local pruning: only the rank that owns this expert keeps the tensors.
        if "experts" in base_name and "shared_experts" not in base_name:
            idx = int(base_name.split(".experts.")[1].split(".")[0])
            owner = idx // n_local_experts
            state_dicts[owner][base_name + ".qweight"] = qweight
            state_dicts[owner][base_name + ".qzeros"] = qzeros
            state_dicts[owner][base_name + ".scales"] = scales
            return
        if shard_dim is None:
            # Replicate across all ranks.
            for i in range(mp):
                state_dicts[i][base_name + ".qweight"] = qweight
                state_dicts[i][base_name + ".qzeros"] = qzeros
                state_dicts[i][base_name + ".scales"] = scales
        else:
            for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)):
                state_dicts[i][base_name + ".qweight"] = qw
                state_dicts[i][base_name + ".qzeros"] = qz
                state_dicts[i][base_name + ".scales"] = sc

    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    for file_path in tqdm(files, desc="files"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for orig_name in f.keys():
                # ----- name remapping (mirrors original convert.py) -----
                name = orig_name
                if name.startswith("model."):
                    name = name[len("model."):]
                if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = get_layer_key(name)
                if key in mapping:
                    new_key, dim = mapping[key]
                    name = name.replace(key, new_key)
                else:
                    dim = None
                tensor = f.get_tensor(orig_name)
                # ----- handle the three-piece quantised linear -----
                # `shared_experts` are plain (non-parallel) Linears in the model;
                # never shard them even though `w1/w2/w3` are in the mapping.
                if "shared_experts" in name:
                    dim = None
                if orig_name.endswith(QUANT_SUFFIXES):
                    base = name.rsplit(".", 1)[0]
                    suf = name.rsplit(".", 1)[1]  # qweight|qzeros|scales
                    pending.setdefault(base, {"_dim": dim})[suf] = tensor
                    pending[base]["_dim"] = dim
                    parts = pending[base]
                    if all(s in parts for s in ("qweight", "qzeros", "scales")):
                        emit_linear(base, parts, parts["_dim"])
                        del pending[base]
                    continue
                # ----- non-quantised tensor -----
                if "experts" in name and "shared_experts" not in name:
                    idx = int(name.split(".experts.")[1].split(".")[0])
                    owner = idx // n_local_experts
                    state_dicts[owner][name] = tensor
                    continue
                if dim is None:
                    for i in range(mp):
                        state_dicts[i][name] = tensor
                else:
                    assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}"
                    sh = tensor.size(dim) // mp
                    for i in range(mp):
                        state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous()

    if pending:
        raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}")

    os.makedirs(save_path, exist_ok=True)
    for i in trange(mp, desc="write shards"):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
    for fn in ["tokenizer.json", "tokenizer_config.json"]:
        src = os.path.join(hf_ckpt_path, fn)
        dst = os.path.join(save_path, fn)
        if os.path.exists(src):
            shutil.copyfile(src, dst)
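
# Example invocation (script name, paths and sizes are placeholders):
#   python convert_w4a16.py --hf-ckpt-path /path/to/hf_w4a16_ckpt \
#       --save-path /path/to/mp_shards --n-experts 256 --model-parallel 8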

if __name__ == "__main__":
    p = ArgumentParser()
    p.add_argument("--hf-ckpt-path", required=True)
    p.add_argument("--save-path", required=True)
    p.add_argument("--n-experts", type=int, required=True)
    p.add_argument("--model-parallel", type=int, required=True)
    a = p.parse_args()
    assert a.n_experts % a.model_parallel == 0, "--n-experts must be divisible by --model-parallel"
    main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)