"""Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4
into the MP-sharded local format consumed by `model.py`/`generate.py`.
Packing convention (auto-round → auto_gptq):
- qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0
- qzeros : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1
- scales : fp16 [in_features // group_size, out_features]
Sharding rules per linear:
- ColumnParallel (shard output dim, original `dim=0` in `mapping`):
    qweight, qzeros and scales are all split along dim 1
    (out_features must be divisible by 8 * world_size so the qzeros packing survives the split).
- RowParallel (shard input dim, original `dim=1` in `mapping`):
    qweight, qzeros and scales are all split along dim 0
    (in_features must be divisible by 8 * world_size for qweight and by
    group_size * world_size for qzeros/scales).
Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias,
gate.tid2eid, etc.) follow the same rules as the original `convert.py`.
"""
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
GROUP_SIZE = 128
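
# Reference-only sketch of the LSB-first nibble layout described in the module
# docstring (assumes the usual auto_gptq convention). It is never called by the
# conversion below, which only re-shards the packed tensors without dequantising.
def _unpack_int4_lsb_first(packed: torch.Tensor) -> torch.Tensor:
    """Expand a 4-bit LSB-first packed int32 tensor 8x along dim 0 (illustration only)."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=packed.device)  # [0, 4, ..., 28]
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF   # [rows, cols, 8] of 4-bit values
    # Element (8*i + k, j) of the unpacked matrix lives in nibble k of packed[i, j].
    return nibbles.permute(0, 2, 1).reshape(packed.size(0) * 8, packed.size(1))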
# Same name remapping as the original convert.py
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"lm_head": ("head", 0),
# Already-translated names (used by the inference checkpoints we already have)
"embed": ("embed", 0),
"wq_a": ("wq_a", None),
"wq_b": ("wq_b", 0),
"wkv": ("wkv", None),
"wo_a": ("wo_a", 0),
"wo_b": ("wo_b", 1),
"w1": ("w1", 0),
"w2": ("w2", 1),
"w3": ("w3", 0),
"head": ("head", 0),
"weights_proj": ("weights_proj", 0),
# special non-weight keys
"attn_sink": ("attn_sink", 0),
"ape": ("ape", None),
# NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a
# plain nn.Parameter that is replicated on every rank.
}
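# Example of the remapping performed in `main` below:
#   "model.layers.3.mlp.experts.7.gate_proj.qweight"
#     -> strip "model.", "mlp" -> "ffn", key "gate_proj" -> ("w1", 0)
#     -> "layers.3.ffn.experts.7.w1.qweight"; being an expert tensor, it is then
#        kept only on the rank that owns expert 7 rather than being sharded.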
# Suffixes that mark the three pieces of a packed W4A16 linear.
QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales")
def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
shard_dim: int, mp: int):
"""Yield (qweight_i, qzeros_i, scales_i) for i in range(mp).
shard_dim is the *logical* dim of the dequantised weight: 0 == output (column parallel),
1 == input (row parallel)."""
out = qweight.size(1)
in_packed = qweight.size(0) # in_features // 8
n_groups = scales.size(0) # in_features // group_size
if shard_dim == 0: # ColumnParallel: shard along OUTPUT
assert out % mp == 0, f"out={out} not divisible by mp={mp}"
# qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0
assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)"
sh_out = out // mp
sh_qz_cols = qzeros.size(1) // mp # == out / 8 / mp
for i in range(mp):
yield (
qweight.narrow(1, i * sh_out, sh_out).contiguous(),
qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(),
scales.narrow(1, i * sh_out, sh_out).contiguous(),
)
elif shard_dim == 1: # RowParallel: shard along INPUT
# qweight packs 8 inputs per int32 in dim 0, scales/qzeros are per-group on dim 0
assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}"
assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}"
sh_in_packed = in_packed // mp
sh_groups = n_groups // mp
for i in range(mp):
yield (
qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(),
qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(),
scales.narrow(0, i * sh_groups, sh_groups).contiguous(),
)
else:
# Replicate
for _ in range(mp):
yield qweight, qzeros, scales
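# Shape sanity check for shard_quant (hypothetical sizes): a ColumnParallel linear with
# in_features=7168, out_features=4096, group_size=128 and mp=8 arrives as
#   qweight [896, 4096], qzeros [56, 512], scales [56, 4096]
# and each of the 8 shards comes out as
#   qweight [896, 512],  qzeros [56, 64],  scales [56, 512].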
def get_layer_key(name: str):
"""Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping."""
parts = name.split(".")
if name.endswith(QUANT_SUFFIXES):
return parts[-2] # ...x.qweight -> x
if name.endswith(".bias") and "gate" in name:
return "gate" # ffn.gate.bias
if name.endswith(".tid2eid"):
return "gate"
if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale",
"hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale",
"hc_head_fn", "hc_head_base", "hc_head_scale",
"attn_sink", "ape")):
return parts[-1]
return parts[-2]
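# E.g. (names shown after the string replaces in `main`):
#   get_layer_key("layers.0.attn.q_b_proj.qweight") -> "q_b_proj"
#   get_layer_key("layers.0.ffn.gate.bias")         -> "gate"
#   get_layer_key("layers.0.attn_sink")             -> "attn_sink"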
def main(hf_ckpt_path, save_path, n_experts, mp):
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
    # Group all fragments belonging to the same logical linear so that
    # qweight/qzeros/scales can be sharded together; each entry also carries
    # "_dim", the logical shard dim (0, 1 or None) recorded when a fragment is seen.
    pending: dict[str, dict] = {}
def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim):
"""Distribute a quantised linear (3 tensors) across `mp` shards."""
qweight = parts["qweight"]
qzeros = parts["qzeros"]
scales = parts["scales"].to(torch.bfloat16) # store bf16 instead of fp16
# Expert-local pruning: only the rank that owns this expert keeps the tensors.
if "experts" in base_name and "shared_experts" not in base_name:
idx = int(base_name.split(".experts.")[1].split(".")[0])
owner = idx // n_local_experts
state_dicts[owner][base_name + ".qweight"] = qweight
state_dicts[owner][base_name + ".qzeros"] = qzeros
state_dicts[owner][base_name + ".scales"] = scales
return
if shard_dim is None:
# Replicate across all ranks
for i in range(mp):
state_dicts[i][base_name + ".qweight"] = qweight
state_dicts[i][base_name + ".qzeros"] = qzeros
state_dicts[i][base_name + ".scales"] = scales
else:
for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)):
state_dicts[i][base_name + ".qweight"] = qw
state_dicts[i][base_name + ".qzeros"] = qz
state_dicts[i][base_name + ".scales"] = sc
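    # E.g. (hypothetical sizes) with n_experts=256 and mp=8, n_local_experts is 32 and
    # every tensor of expert 70 lands only in state_dicts[70 // 32] == state_dicts[2].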
files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
for file_path in tqdm(files, desc="files"):
with safe_open(file_path, framework="pt", device="cpu") as f:
for orig_name in f.keys():
# ----- name remapping (mirrors original convert.py) -----
name = orig_name
if name.startswith("model."):
name = name[len("model."):]
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
continue
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = get_layer_key(name)
if key in mapping:
new_key, dim = mapping[key]
name = name.replace(key, new_key)
else:
dim = None
tensor = f.get_tensor(orig_name)
# ----- handle the three-piece quantised linear -----
# `shared_experts` are plain (non-parallel) Linears in the model;
# never shard them even though `w1/w2/w3` are in the mapping.
if "shared_experts" in name:
dim = None
if orig_name.endswith(QUANT_SUFFIXES):
base = name.rsplit(".", 1)[0]
suf = name.rsplit(".", 1)[1] # qweight|qzeros|scales
pending.setdefault(base, {"_dim": dim})[suf] = tensor
pending[base]["_dim"] = dim
parts = pending[base]
if all(s in parts for s in ("qweight", "qzeros", "scales")):
emit_linear(base, parts, parts["_dim"])
del pending[base]
continue
# ----- non-quantised tensor -----
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".experts.")[1].split(".")[0])
owner = idx // n_local_experts
state_dicts[owner][name] = tensor
continue
if dim is None:
for i in range(mp):
state_dicts[i][name] = tensor
else:
assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}"
sh = tensor.size(dim) // mp
for i in range(mp):
state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous()
if pending:
raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}")
os.makedirs(save_path, exist_ok=True)
for i in trange(mp, desc="write shards"):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for fn in ["tokenizer.json", "tokenizer_config.json"]:
src = os.path.join(hf_ckpt_path, fn)
dst = os.path.join(save_path, fn)
if os.path.exists(src):
shutil.copyfile(src, dst)
if __name__ == "__main__":
p = ArgumentParser()
p.add_argument("--hf-ckpt-path", required=True)
p.add_argument("--save-path", required=True)
p.add_argument("--n-experts", type=int, required=True)
p.add_argument("--model-parallel", type=int, required=True)
a = p.parse_args()
assert a.n_experts % a.model_parallel == 0
main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)
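# Example invocation (hypothetical paths/sizes; the script filename is a placeholder):
#   python convert_w4a16.py --hf-ckpt-path /path/to/hf_ckpt --save-path /path/to/mp_ckpt \
#       --n-experts 256 --model-parallel 8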