import json
import os
from argparse import ArgumentParser
from glob import glob

from safetensors import safe_open
from tqdm import tqdm

# Maps an HF checkpoint key component to (target key component, shard dim).
# shard dim is the tensor dimension split across model-parallel ranks;
# None means the tensor is replicated on every rank.
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}


def _tensor_header(f, name: str):
    """Return (shape, dtype) for *name* from the safetensors header.

    Uses ``get_slice`` so only header metadata is read — the tensor
    payload is never loaded.
    """
    sl = f.get_slice(name)
    return sl.get_shape(), sl.get_dtype()


def collect_save_keys(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
) -> list[list[str]]:
    """Return, per model-parallel shard, the sorted key names that the
    original convert script's ``save_file`` would write.

    Only safetensors headers are read; no tensor payloads are loaded.

    Args:
        hf_ckpt_path: Directory containing the HF ``*.safetensors`` files.
        n_experts: Total number of routed experts in the checkpoint.
        mp: Model-parallel degree; must evenly divide ``n_experts``.

    Raises:
        ValueError: If ``mp`` is not positive, does not divide
            ``n_experts``, or a sharded dimension is not divisible by it.
        FileNotFoundError: If no ``*.safetensors`` files are found.
    """
    # Explicit validation instead of relying on the caller's check
    # (a plain assert would be stripped under `python -O`).
    if mp <= 0 or n_experts % mp != 0:
        raise ValueError(
            f"n_experts ({n_experts}) must be divisible by mp ({mp})"
        )
    n_local_experts = n_experts // mp
    per_shard: list[set[str]] = [set() for _ in range(mp)]

    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"no *.safetensors under {hf_ckpt_path!r}")

    for file_path in tqdm(files, desc="keys"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for raw_name in f.keys():
                name = raw_name
                if name.startswith("model."):
                    name = name[len("model.") :]
                # MTP branch: its embedding / lm-head copies are skipped,
                # mirroring the original convert script.
                if name.startswith("mtp.") and (
                    "emb" in name or name.endswith("head.weight")
                ):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                # A few parameters are stored without a trailing ".weight",
                # so the mapping key is the last component, not the
                # second-to-last.
                if any(
                    x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]
                ):
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                if key in mapping:
                    new_key, dim = mapping[key]
                else:
                    new_key, dim = key, None
                # NOTE: replaces the first occurrence of `key` in the
                # dotted path — same behavior as the original convert.
                name = name.replace(key, new_key)

                if "experts" in name and "shared_experts" not in name:
                    # Routed expert: owned by exactly one shard.
                    idx = int(name.split(".")[-3])
                    owner = idx // n_local_experts
                    if 0 <= owner < mp:
                        per_shard[owner].add(name)
                    continue

                if dim is not None:
                    # Sharded tensor: the split dimension must divide
                    # evenly. Checked once (it does not depend on the
                    # shard index) and only when the shape is needed.
                    shape, _dtype = _tensor_header(f, raw_name)
                    if shape[dim] % mp != 0:
                        raise ValueError(
                            f"Dimension {dim} must be divisible by {mp} "
                            f"for {name!r}"
                        )
                # Replicated or evenly-sharded: every shard gets the key.
                for i in range(mp):
                    per_shard[i].add(name)

    return [_final_save_keys(s) for s in per_shard]


def _final_save_keys(keys: set[str]) -> list[str]:
    """Drop keys removed by the original second pass and sort the rest.

    The only keys removed there are the ``wo_a.scale`` tensors that get
    merged into their ``wo_a.weight`` partner; all other rewrites keep
    the same key names.
    """
    s = set(keys)
    for k in list(s):
        if k.endswith("wo_a.weight"):
            # Strip exactly the trailing "weight" (the old
            # `k.replace("weight", "scale")` rewrote the FIRST match,
            # which is wrong if "weight" appears earlier in the path).
            s.discard(k.removesuffix("weight") + "scale")
    return sorted(s)


def main(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
    as_json: bool,
):
    """Print the per-shard key lists, either as text or one JSON object."""
    per_shard = collect_save_keys(hf_ckpt_path, n_experts, mp)
    if as_json:
        print(
            json.dumps(
                {f"model{i}-mp{mp}": per_shard[i] for i in range(mp)},
                indent=2,
                ensure_ascii=False,
            )
        )
    else:
        for i, keys in enumerate(per_shard):
            print(f"=== model{i}-mp{mp} ({len(keys)} keys) ===")
            for k in keys:
                print(k)


if __name__ == "__main__":
    parser = ArgumentParser(
        description="List target safetensors key names (no tensor load/save).",
    )
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help='print one JSON object: {"model0-mpK": [...], ...}',
    )
    args = parser.parse_args()
    # parser.error exits with usage + code 2; the old bare assert would
    # be silently skipped under `python -O`.
    if args.model_parallel <= 0 or args.n_experts % args.model_parallel != 0:
        parser.error("Number of experts must be divisible by model parallelism")
    main(
        args.hf_ckpt_path,
        args.n_experts,
        args.model_parallel,
        args.as_json,
    )