import json
import os
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
from safetensors import safe_open
# Maps one component of an HF checkpoint parameter name to a pair of
# (converted component name, shard dimension).  The shard dimension is the
# tensor axis split across model-parallel ranks, or None for tensors that
# are replicated on every rank (layer norms, biases, etc.).
# NOTE(review): several entries map a name to itself (e.g. "embed", "head");
# presumably these cover checkpoints already using the converted naming —
# confirm against the upstream convert script.
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}
def _tensor_header(f, name: str):
    """
    Read a tensor's shape and dtype from the safetensors file header.

    Uses ``get_slice`` so only metadata is touched — the tensor payload
    itself is never loaded into memory.
    """
    header_slice = f.get_slice(name)
    shape = header_slice.get_shape()
    dtype = header_slice.get_dtype()
    return shape, dtype
def collect_save_keys(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
) -> list[list[str]]:
    """
    Return, for each of the ``mp`` parallel shards, the sorted list of key
    names that ``save_file`` would write (same naming as the original
    convert script, without loading tensor payloads).

    Args:
        hf_ckpt_path: Directory containing the HF ``*.safetensors`` files.
        n_experts: Total number of routed experts in the checkpoint.
        mp: Model-parallel degree; assumed to divide ``n_experts``.

    Raises:
        FileNotFoundError: If no ``*.safetensors`` files are found.
        AssertionError: If a sharded dimension is not divisible by ``mp``.
    """
    n_local_experts = n_experts // mp
    per_shard: list[set[str]] = [set() for _ in range(mp)]
    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"no *.safetensors under {hf_ckpt_path!r}")
    for file_path in tqdm(files, desc="keys"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for raw_name in f.keys():
                name = raw_name
                if name.startswith("model."):
                    name = name[len("model.") :]
                # MTP embedding / output-head tensors are dropped by the
                # converter, so they never become save keys.
                if name.startswith("mtp.") and (
                    "emb" in name or name.endswith("head.weight")
                ):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                # These parameters are stored without a trailing ".weight",
                # so the component to map is the last one, not the 2nd-last.
                if any(
                    x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]
                ):
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                # Unmapped components keep their name and are unsharded.
                new_key, dim = mapping.get(key, (key, None))
                name = name.replace(key, new_key)
                if "experts" in name and "shared_experts" not in name:
                    # Routed experts are partitioned: expert `idx` lives on
                    # exactly one shard.  Out-of-range experts (idx >=
                    # n_experts) are skipped, matching the original scan.
                    idx = int(name.split(".")[-3])
                    shard = idx // n_local_experts
                    if shard < mp:
                        per_shard[shard].add(name)
                    continue
                if dim is not None:
                    # Header-only read, and only when the shape is actually
                    # needed; checked once instead of once per shard.
                    shape, _dtype = _tensor_header(f, raw_name)
                    assert (
                        shape[dim] % mp == 0
                    ), f"Dimension {dim} must be divisible by {mp} for {name!r}"
                # Non-expert tensors appear on every shard.
                for i in range(mp):
                    per_shard[i].add(name)
    return [_final_save_keys(s) for s in per_shard]
def _final_save_keys(keys: set[str]) -> list[str]:
"""
After the original second pass, the only removed keys are the wo_a.scale
pairs merged into wo_a.weight; other rewrites keep the same key names.
"""
s = set(keys)
for k in list(s):
if k.endswith("wo_a.weight"):
s.discard(k.replace("weight", "scale"))
return sorted(s)
def main(
    hf_ckpt_path: str,
    n_experts: int,
    mp: int,
    as_json: bool,
):
    """
    Print the per-shard save-key lists.

    With ``as_json`` set, emits a single JSON object keyed by shard name;
    otherwise prints one plain-text section per shard (header line followed
    by one key per line).
    """
    shard_keys = collect_save_keys(hf_ckpt_path, n_experts, mp)
    if as_json:
        payload = {
            f"model{idx}-mp{mp}": keys for idx, keys in enumerate(shard_keys)
        }
        print(json.dumps(payload, indent=2, ensure_ascii=False))
    else:
        for idx, keys in enumerate(shard_keys):
            print(f"=== model{idx}-mp{mp} ({len(keys)} keys) ===")
            for key in keys:
                print(key)
if __name__ == "__main__":
    parser = ArgumentParser(
        description="List target safetensors key names (no tensor load/save).",
    )
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help='print one JSON object: {"model0-mpK": [...], ...}',
    )
    args = parser.parse_args()
    # Validate via parser.error (not assert): asserts are stripped under
    # `python -O`, and parser.error exits with a usage message instead of a
    # traceback.  Check positivity first so the modulo below cannot raise
    # ZeroDivisionError for --model-parallel 0.
    if args.model_parallel < 1:
        parser.error("--model-parallel must be a positive integer")
    if args.n_experts % args.model_parallel != 0:
        parser.error("--n-experts must be divisible by --model-parallel")
    main(
        args.hf_ckpt_path,
        args.n_experts,
        args.model_parallel,
        args.as_json,
    )