"""Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4 into the MP-sharded local format consumed by `model.py`/`generate.py`. Packing convention (auto-round → auto_gptq): - qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0 - qzeros : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1 - scales : fp16 [in_features // group_size, out_features] Sharding rules per linear: - ColumnParallel (shard output dim, original `dim=0` in `mapping`): qweight along dim 1; qzeros along dim 1 (must be divisible by 8 first, then by world_size); scales along dim 1. - RowParallel (shard input dim, original `dim=1` in `mapping`): qweight along dim 0 (must be divisible by 8 first, then by world_size); qzeros along dim 0 (must be divisible by group_size first, then by world_size); scales along dim 0. Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias, gate.tid2eid, etc.) follow the same rules as the original `convert.py`. """ import os import shutil from argparse import ArgumentParser from glob import glob from tqdm import tqdm, trange import torch from safetensors.torch import safe_open, save_file GROUP_SIZE = 128 # Same name remapping as the original convert.py mapping = { "embed_tokens": ("embed", 0), "input_layernorm": ("attn_norm", None), "post_attention_layernorm": ("ffn_norm", None), "q_proj": ("wq", 0), "q_a_proj": ("wq_a", None), "q_a_layernorm": ("q_norm", None), "q_b_proj": ("wq_b", 0), "kv_a_proj_with_mqa": ("wkv_a", None), "kv_a_layernorm": ("kv_norm", None), "kv_b_proj": ("wkv_b", 0), "o_proj": ("wo", 1), "gate_proj": ("w1", 0), "down_proj": ("w2", 1), "up_proj": ("w3", 0), "lm_head": ("head", 0), # Already-translated names (used by the inference checkpoints we already have) "embed": ("embed", 0), "wq_a": ("wq_a", None), "wq_b": ("wq_b", 0), "wkv": ("wkv", None), "wo_a": ("wo_a", 0), "wo_b": ("wo_b", 1), "w1": ("w1", 0), "w2": ("w2", 1), "w3": ("w3", 0), "head": ("head", 0), "weights_proj": ("weights_proj", 0), # special non-weight keys "attn_sink": ("attn_sink", 0), "ape": ("ape", None), # NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a # plain nn.Parameter that is replicated on every rank. } # Suffixes that mark the three pieces of a packed W4A16 linear. QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales") def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, shard_dim: int, mp: int): """Yield (qweight_i, qzeros_i, scales_i) for i in range(mp). shard_dim is the *logical* dim of the dequantised weight: 0 == output (column parallel), 1 == input (row parallel).""" out = qweight.size(1) in_packed = qweight.size(0) # in_features // 8 n_groups = scales.size(0) # in_features // group_size if shard_dim == 0: # ColumnParallel: shard along OUTPUT assert out % mp == 0, f"out={out} not divisible by mp={mp}" # qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0 assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)" sh_out = out // mp sh_qz_cols = qzeros.size(1) // mp # == out / 8 / mp for i in range(mp): yield ( qweight.narrow(1, i * sh_out, sh_out).contiguous(), qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(), scales.narrow(1, i * sh_out, sh_out).contiguous(), ) elif shard_dim == 1: # RowParallel: shard along INPUT # qweight packs 8 inputs per int32 in dim 0, scales/qzeros are per-group on dim 0 assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}" assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}" sh_in_packed = in_packed // mp sh_groups = n_groups // mp for i in range(mp): yield ( qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(), qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(), scales.narrow(0, i * sh_groups, sh_groups).contiguous(), ) else: # Replicate for _ in range(mp): yield qweight, qzeros, scales def get_layer_key(name: str): """Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping.""" parts = name.split(".") if name.endswith(QUANT_SUFFIXES): return parts[-2] # ...x.qweight -> x if name.endswith(".bias") and "gate" in name: return "gate" # ffn.gate.bias if name.endswith(".tid2eid"): return "gate" if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale", "hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale", "hc_head_fn", "hc_head_base", "hc_head_scale", "attn_sink", "ape")): return parts[-1] return parts[-2] def main(hf_ckpt_path, save_path, n_experts, mp): torch.set_num_threads(8) n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] # Group all fragments belonging to the same logical linear so we can shard # qweight/qzeros/scales together. pending: dict[str, dict[str, torch.Tensor]] = {} def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim): """Distribute a quantised linear (3 tensors) across `mp` shards.""" qweight = parts["qweight"] qzeros = parts["qzeros"] scales = parts["scales"].to(torch.bfloat16) # store bf16 instead of fp16 # Expert-local pruning: only the rank that owns this expert keeps the tensors. if "experts" in base_name and "shared_experts" not in base_name: idx = int(base_name.split(".experts.")[1].split(".")[0]) owner = idx // n_local_experts state_dicts[owner][base_name + ".qweight"] = qweight state_dicts[owner][base_name + ".qzeros"] = qzeros state_dicts[owner][base_name + ".scales"] = scales return if shard_dim is None: # Replicate across all ranks for i in range(mp): state_dicts[i][base_name + ".qweight"] = qweight state_dicts[i][base_name + ".qzeros"] = qzeros state_dicts[i][base_name + ".scales"] = scales else: for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)): state_dicts[i][base_name + ".qweight"] = qw state_dicts[i][base_name + ".qzeros"] = qz state_dicts[i][base_name + ".scales"] = sc files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors"))) for file_path in tqdm(files, desc="files"): with safe_open(file_path, framework="pt", device="cpu") as f: for orig_name in f.keys(): # ----- name remapping (mirrors original convert.py) ----- name = orig_name if name.startswith("model."): name = name[len("model."):] if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")): continue name = name.replace("self_attn", "attn") name = name.replace("mlp", "ffn") name = name.replace("weight_scale_inv", "scale") name = name.replace("e_score_correction_bias", "bias") key = get_layer_key(name) if key in mapping: new_key, dim = mapping[key] name = name.replace(key, new_key) else: dim = None tensor = f.get_tensor(orig_name) # ----- handle the three-piece quantised linear ----- # `shared_experts` are plain (non-parallel) Linears in the model; # never shard them even though `w1/w2/w3` are in the mapping. if "shared_experts" in name: dim = None if orig_name.endswith(QUANT_SUFFIXES): base = name.rsplit(".", 1)[0] suf = name.rsplit(".", 1)[1] # qweight|qzeros|scales pending.setdefault(base, {"_dim": dim})[suf] = tensor pending[base]["_dim"] = dim parts = pending[base] if all(s in parts for s in ("qweight", "qzeros", "scales")): emit_linear(base, parts, parts["_dim"]) del pending[base] continue # ----- non-quantised tensor ----- if "experts" in name and "shared_experts" not in name: idx = int(name.split(".experts.")[1].split(".")[0]) owner = idx // n_local_experts state_dicts[owner][name] = tensor continue if dim is None: for i in range(mp): state_dicts[i][name] = tensor else: assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}" sh = tensor.size(dim) // mp for i in range(mp): state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous() if pending: raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}") os.makedirs(save_path, exist_ok=True) for i in trange(mp, desc="write shards"): save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) for fn in ["tokenizer.json", "tokenizer_config.json"]: src = os.path.join(hf_ckpt_path, fn) dst = os.path.join(save_path, fn) if os.path.exists(src): shutil.copyfile(src, dst) if __name__ == "__main__": p = ArgumentParser() p.add_argument("--hf-ckpt-path", required=True) p.add_argument("--save-path", required=True) p.add_argument("--n-experts", type=int, required=True) p.add_argument("--model-parallel", type=int, required=True) a = p.parse_args() assert a.n_experts % a.model_parallel == 0 main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)