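"""List the safetensors key names each model-parallel shard of the converted
checkpoint would contain, without loading or saving any tensor payloads.

Example invocation (script name and paths are illustrative):

    python list_save_keys.py --hf-ckpt-path /path/to/hf_ckpt \
        --n-experts 256 --model-parallel 8 --json
"""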
import json
import os
from argparse import ArgumentParser
from glob import glob

from tqdm import tqdm
from safetensors import safe_open


# HF module/parameter name -> (converted key name, model-parallel shard dim).
# A shard dim of None means the tensor is replicated rather than split.
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}
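
# Example of the rename for one key (HF-style name, for illustration):
#   "model.layers.0.self_attn.q_proj.weight"
#     -> strip "model.", self_attn -> attn:  "layers.0.attn.q_proj.weight"
#     -> key "q_proj" maps to ("wq", 0):     "layers.0.attn.wq.weight"
# In the original convert a dim of 0/1 selects the split axis; in this
# keys-only script it only gates the divisibility assert below.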


def _tensor_header(f, name: str):
    """Shape + dtype from file header (no full tensor read)."""
    sl = f.get_slice(name)
    return sl.get_shape(), sl.get_dtype()


def collect_save_keys(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
) -> list[list[str]]:
    """
    Returns, for each parallel shard, the sorted list of key names that
    `save_file` would write (same naming as the original convert, without
    loading tensor payloads).
    """
    n_local_experts = n_experts // mp
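    # e.g. with n_experts=64 and mp=4, n_local_experts is 16, so shard i
    # keeps the contiguous block of expert indices [i * 16, (i + 1) * 16)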
    per_shard: list[set[str]] = [set() for _ in range(mp)]

    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"no *.safetensors under {hf_ckpt_path!r}")

    for file_path in tqdm(files, desc="keys"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for raw_name in f.keys():
                name = raw_name.removeprefix("model.")
                if name.startswith("mtp.") and (
                    "emb" in name or name.endswith("head.weight")
                ):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                if any(
                    x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]
                ):  # parameter names stored without a trailing ".weight"
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                if key in mapping:
                    new_key, dim = mapping[key]
                else:
                    new_key, dim = key, None
                name = name.replace(key, new_key)

                shape, _dtype = _tensor_header(f, raw_name)

                for i in range(mp):
                    if "experts" in name and "shared_experts" not in name:
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        assert (
                            shape[dim] % mp == 0
                        ), f"Dimension {dim} must be divisible by {mp} for {name!r}"
                    per_shard[i].add(name)

    return [_final_save_keys(s) for s in per_shard]


def _final_save_keys(keys: set[str]) -> list[str]:
    """
    After the original second pass, the only removed keys are the wo_a.scale
    pairs merged into wo_a.weight; other rewrites keep the same key names.
    """
    s = set(keys)
    for k in list(s):
        if k.endswith("wo_a.weight"):
            s.discard(k.replace("weight", "scale"))
    return sorted(s)
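
# Sketch of the merge rule above (key names invented for illustration):
#   _final_save_keys({"layers.0.attn.wo_a.weight",
#                     "layers.0.attn.wo_a.scale",
#                     "layers.0.attn.wo_b.weight"})
#   == ["layers.0.attn.wo_a.weight", "layers.0.attn.wo_b.weight"]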


def main(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
    as_json: bool,
):
    per_shard = collect_save_keys(hf_ckpt_path, n_experts, mp)

    if as_json:
        print(
            json.dumps(
                {f"model{i}-mp{mp}": per_shard[i] for i in range(mp)},
                indent=2,
                ensure_ascii=False,
            )
        )
    else:
        for i, keys in enumerate(per_shard):
            print(f"=== model{i}-mp{mp} ({len(keys)} keys) ===")
            for k in keys:
                print(k)


if __name__ == "__main__":
    parser = ArgumentParser(
        description="List target safetensors key names (no tensor load/save).",
    )
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help='print one JSON object: {"model0-mpK": [...], ...}',
    )
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, (
        "Number of experts must be divisible by model parallelism"
    )
    main(
        args.hf_ckpt_path,
        args.n_experts,
        args.model_parallel,
        args.as_json,
    )