| import json |
| import os |
| from argparse import ArgumentParser |
| from glob import glob |
|
|
| from tqdm import tqdm |
| from safetensors import safe_open |
|
|
|
|
# HF checkpoint parameter key -> (target key, shard dim) lookup.
# The second element is the tensor dimension that gets split across
# model-parallel shards, or None for tensors that are replicated whole
# (norms, low-rank "a" projections, etc.).
mapping: dict[str, tuple[str, int | None]] = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    # Identity entries: already-converted names map to themselves so the
    # same lookup works for partially converted checkpoints.
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}
|
|
|
|
| def _tensor_header(f, name: str): |
| """Shape + dtype from file header (no full tensor read).""" |
| sl = f.get_slice(name) |
| return sl.get_shape(), sl.get_dtype() |
|
|
|
|
def collect_save_keys(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
) -> list[list[str]]:
    """
    Return, for each of the ``mp`` parallel shards, the sorted list of key
    names that ``save_file`` would write (same naming as the original
    convert script, without loading tensor payloads).

    Args:
        hf_ckpt_path: directory containing the ``*.safetensors`` shards.
        n_experts: total number of routed experts; assumed divisible by mp.
        mp: model-parallel degree.

    Raises:
        FileNotFoundError: if no ``*.safetensors`` file is found.
        AssertionError: if a shard-split dimension is not divisible by mp.
    """
    n_local_experts = n_experts // mp
    per_shard: list[set[str]] = [set() for _ in range(mp)]

    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"no *.safetensors under {hf_ckpt_path!r}")

    for file_path in tqdm(files, desc="keys"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for raw_name in f.keys():
                name = raw_name
                if name.startswith("model."):
                    name = name[len("model.") :]
                # MTP embedding / head weights are dropped entirely.
                if name.startswith("mtp.") and (
                    "emb" in name or name.endswith("head.weight")
                ):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                # A few parameters are keyed on the last path component;
                # everything else on the second-to-last (the module name).
                if any(
                    x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]
                ):
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                new_key, dim = mapping.get(key, (key, None))
                name = name.replace(key, new_key)

                if "experts" in name and "shared_experts" not in name:
                    # Routed experts live on exactly one shard; out-of-range
                    # expert indices are silently dropped (as before).
                    idx = int(name.split(".")[-3])
                    shard_idx = idx // n_local_experts
                    if 0 <= shard_idx < mp:
                        per_shard[shard_idx].add(name)
                    continue

                if dim is not None:
                    # Divisibility is shard-independent: check it once instead
                    # of per shard, and only here — the header read is needed
                    # solely for this check.
                    shape, _dtype = _tensor_header(f, raw_name)
                    assert (
                        shape[dim] % mp == 0
                    ), f"Dimension {dim} must be divisible by {mp} for {name!r}"
                # Replicated (or split-but-valid) tensors go to every shard.
                for shard_keys in per_shard:
                    shard_keys.add(name)

    return [_final_save_keys(s) for s in per_shard]
|
|
|
|
| def _final_save_keys(keys: set[str]) -> list[str]: |
| """ |
| After the original second pass, the only removed keys are the wo_a.scale |
| pairs merged into wo_a.weight; other rewrites keep the same key names. |
| """ |
| s = set(keys) |
| for k in list(s): |
| if k.endswith("wo_a.weight"): |
| s.discard(k.replace("weight", "scale")) |
| return sorted(s) |
|
|
|
|
def main(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
    as_json: bool,
):
    """Collect the per-shard key lists and print them as JSON or plain text."""
    shards = collect_save_keys(hf_ckpt_path, n_experts, mp)

    if not as_json:
        # Plain-text mode: one header line per shard, then one key per line.
        for shard_idx, shard_keys in enumerate(shards):
            print(f"=== model{shard_idx}-mp{mp} ({len(shard_keys)} keys) ===")
            for key in shard_keys:
                print(key)
        return

    payload = {
        f"model{shard_idx}-mp{mp}": shard_keys
        for shard_idx, shard_keys in enumerate(shards)
    }
    print(json.dumps(payload, indent=2, ensure_ascii=False))
|
|
|
|
if __name__ == "__main__":
    parser = ArgumentParser(
        description="List target safetensors key names (no tensor load/save).",
    )
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help='print one JSON object: {"model0-mpK": [...], ...}',
    )
    args = parser.parse_args()
    # Validate via parser.error rather than assert: asserts are stripped
    # under `python -O`, and parser.error exits with the standard argparse
    # usage message instead of a traceback.
    if args.n_experts % args.model_parallel != 0:
        parser.error("Number of experts must be divisible by model parallelism")
    main(
        args.hf_ckpt_path,
        args.n_experts,
        args.model_parallel,
        args.as_json,
    )
|
|