#!/usr/bin/env python3
"""
从 HuggingFace 风格的 config.json 构建 inference/model.py 中的 Transformer,
导出 state_dict,键名与仓库内 model.safetensors.index.json / info.html 中的 Metadata 列一致
(例如 embed.weight、layers.0.attn.wq_a.weight、layers.0.ffn.shared_experts.w1.weight,
无 transformers 的 model. 前缀)。
分片保存:指定 --output-dir 时,按 --index-json(默认仓库根目录 model.safetensors.index.json)
的 weight_map 将张量写入对应 model-XXXX-of-YYYY.safetensors,并生成同目录下的
model.safetensors.index.json(键名与参考索引一致,分片文件名与参考索引一致)。
用法示例:
python export_state_dict_from_config.py --config config.json --output /tmp/out.safetensors
python export_state_dict_from_config.py --config config.json --output-dir /tmp/shards
完整 max_position_embeddings 会在每层分配 RoPE 缓冲,内存很大;默认将 max_seq_len 限制为
65536,可用 --max-seq-len 覆盖。state_dict 不包含 persistent=False 的 buffer,与官方分片键集合对齐。
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent
INFERENCE_DIR = REPO_ROOT / "inference"
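# Make inference/model.py importable as a top-level "model" module.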
sys.path.insert(0, str(INFERENCE_DIR))
from model import ModelArgs, Transformer # noqa: E402
def without_mtp_tied_aliases(sd: dict[str, Any]) -> dict[str, Any]:
"""MTP 的 embed/head 与主模块共享参数,safetensors 不能重复保存这些别名。"""
out: dict[str, Any] = {}
for k, v in sd.items():
parts = k.split(".")
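        # Tied aliases look like "mtp.<i>.embed.*" or "mtp.<i>.head.*"; skip them.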
if len(parts) >= 4 and parts[0] == "mtp" and parts[2] in {"embed", "head"}:
continue
out[k] = v
return out
def _load_index_weight_map(index_path: Path) -> dict[str, str]:
with open(index_path, encoding="utf-8") as f:
idx = json.load(f)
wm = idx.get("weight_map")
if not isinstance(wm, dict):
raise SystemExit(f"{index_path} 缺少 weight_map")
return dict(wm)
def save_sharded_index_style(
sd_inference: dict[str, Any],
weight_map: dict[str, str],
out_dir: Path,
) -> None:
"""按 weight_map 中的分片文件名写入多个 .safetensors,并写 model.safetensors.index.json。"""
from safetensors.torch import save_file
out_dir.mkdir(parents=True, exist_ok=True)
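    # Group tensors by target shard file, keeping only keys present in this state_dict.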
shard_to_tensors: dict[str, dict[str, Any]] = defaultdict(dict)
new_weight_map: dict[str, str] = {}
for inf_key, shard_file in weight_map.items():
if inf_key not in sd_inference:
continue
shard_to_tensors[shard_file][inf_key] = sd_inference[inf_key]
new_weight_map[inf_key] = shard_file
if not new_weight_map:
raise SystemExit("weight_map 与当前 state_dict 无交集,无法分片写出")
total_size = 0
for shard_file, part in shard_to_tensors.items():
for t in part.values():
total_size += int(t.numel()) * int(t.element_size())
save_file(part, str(out_dir / shard_file))
index_out = {
"metadata": {"total_size": total_size},
"weight_map": new_weight_map,
}
with open(out_dir / "model.safetensors.index.json", "w", encoding="utf-8") as f:
json.dump(index_out, f, indent=2)
f.write("\n")
def _is_hf_deepseek_config(cfg: dict[str, Any]) -> bool:
return cfg.get("model_type") == "deepseek_v4" or "num_hidden_layers" in cfg
def hf_config_to_model_args(
cfg: dict[str, Any],
*,
max_batch_size: int,
max_seq_len: int,
) -> ModelArgs:
"""将仓库根目录下 HuggingFace 的 config.json 转为 inference.ModelArgs。"""
rope = cfg.get("rope_scaling") or {}
qc = cfg.get("quantization_config") or {}
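    # Treat the checkpoint as fp8 if the quantization config says so, either via
    # quant_method or an e4m3-style fmt field.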
use_fp8 = qc.get("quant_method") == "fp8" or str(qc.get("fmt", "")).lower() in (
"e4m3",
"e4m3fn",
"fp8",
)
n_layers = int(cfg["num_hidden_layers"])
ratios = list(cfg.get("compress_ratios") or ())
compress_ratios = tuple(int(x) for x in ratios)
return ModelArgs(
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
dtype="fp8" if use_fp8 else "bf16",
scale_fmt=qc.get("scale_fmt") or cfg.get("scale_fmt") or "ue8m0",
expert_dtype=cfg.get("expert_dtype"),
scale_dtype=cfg.get("scale_dtype") or ("fp8" if use_fp8 else "fp32"),
vocab_size=int(cfg["vocab_size"]),
dim=int(cfg["hidden_size"]),
moe_inter_dim=int(cfg["moe_intermediate_size"]),
n_layers=n_layers,
n_hash_layers=int(cfg.get("num_hash_layers", 0)),
n_mtp_layers=int(cfg.get("num_nextn_predict_layers", 0)),
n_heads=int(cfg["num_attention_heads"]),
n_routed_experts=int(cfg["n_routed_experts"]),
n_shared_experts=int(cfg.get("n_shared_experts", 1)),
n_activated_experts=int(cfg["num_experts_per_tok"]),
        score_func=str(cfg.get("scoring_func", cfg.get("score_func", "sqrtsoftplus"))),
        route_scale=float(cfg.get("routed_scaling_factor", cfg.get("route_scale", 1.0))),
swiglu_limit=float(cfg.get("swiglu_limit", 0.0)),
q_lora_rank=int(cfg["q_lora_rank"]),
head_dim=int(cfg["head_dim"]),
rope_head_dim=int(cfg["qk_rope_head_dim"]),
norm_eps=float(cfg.get("rms_norm_eps", cfg.get("norm_eps", 1e-6))),
o_groups=int(cfg["o_groups"]),
o_lora_rank=int(cfg["o_lora_rank"]),
window_size=int(cfg["sliding_window"]),
compress_ratios=compress_ratios,
compress_rope_theta=float(cfg.get("compress_rope_theta", 160000.0)),
original_seq_len=int(
rope.get(
"original_max_position_embeddings",
cfg.get("original_seq_len", 0),
)
),
rope_theta=float(cfg.get("rope_theta", 10000.0)),
rope_factor=float(rope.get("factor", cfg.get("rope_factor", 1.0))),
beta_fast=int(rope.get("beta_fast", cfg.get("beta_fast", 32))),
beta_slow=int(rope.get("beta_slow", cfg.get("beta_slow", 1))),
index_n_heads=int(cfg["index_n_heads"]),
index_head_dim=int(cfg["index_head_dim"]),
index_topk=int(cfg["index_topk"]),
hc_mult=int(cfg["hc_mult"]),
hc_sinkhorn_iters=int(cfg["hc_sinkhorn_iters"]),
hc_eps=float(cfg.get("hc_eps", 1e-6)),
)
def load_model_args(
config_path: Path,
*,
max_batch_size: int,
max_seq_len: int | None,
cap_seq_len: int,
) -> ModelArgs:
with open(config_path, encoding="utf-8") as f:
raw = json.load(f)
if _is_hf_deepseek_config(raw):
cfg_max = int(raw.get("max_position_embeddings", cap_seq_len))
        mseq = max_seq_len if max_seq_len is not None else min(cfg_max, cap_seq_len)
return hf_config_to_model_args(
raw, max_batch_size=max_batch_size, max_seq_len=mseq
)
return ModelArgs(**raw)
def optional_prefix_keys(
sd: dict[str, Any], prefix: str | None
) -> dict[str, Any]:
if not prefix:
return sd
p = prefix.rstrip(".") + "."
return {p + k: v for k, v in sd.items()}
def validate_against_index(
keys: set[str], index_path: Path
) -> tuple[set[str], set[str]]:
with open(index_path, encoding="utf-8") as f:
idx = json.load(f)
ref = set(idx.get("weight_map", {}).keys())
missing = ref - keys
extra = keys - ref
return missing, extra
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--config",
type=Path,
default=REPO_ROOT / "config.json",
help="config.json(HF 或 inference 格式)",
)
out_group = ap.add_mutually_exclusive_group(required=True)
out_group.add_argument(
"--output",
type=Path,
help="输出单个 .safetensors 或 .pt / .pth",
)
out_group.add_argument(
"--output-dir",
type=Path,
default=None,
help="按 index 的 weight_map 分片写入该目录(仅 safetensors + model.safetensors.index.json)",
)
ap.add_argument("--device", type=str, default="cpu",
help="cpu / cuda / meta")
ap.add_argument("--max-batch-size", type=int, default=4)
ap.add_argument(
"--max-seq-len",
type=int,
default=None,
help="覆盖 ModelArgs.max_seq_len;默认 min(max_position_embeddings, --cap-seq-len)",
)
ap.add_argument(
"--cap-seq-len",
type=int,
default=65536,
help="HF 配置下默认 max_seq_len 上限,避免每层分配过大 RoPE 缓冲",
)
ap.add_argument(
"--prefix",
type=str,
default="",
        help='prepend this prefix to output key names (e.g. "model."); default is no prefix, matching the index',
)
ap.add_argument(
"--index-json",
type=Path,
default=None,
help="校验或分片布局:默认在 --output-dir 时为仓库根 model.safetensors.index.json",
)
ap.add_argument("--strict-index", action="store_true",
help="与 --index-json 联用,要求键完全一致")
args = ap.parse_args()
import torch
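    # Create floating-point parameters in bf16 by default.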
torch.set_default_dtype(torch.bfloat16)
margs = load_model_args(
args.config,
max_batch_size=args.max_batch_size,
max_seq_len=args.max_seq_len,
cap_seq_len=args.cap_seq_len,
)
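    # Dump the resolved ModelArgs to ds_config.json in the repo root for inspection.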
with open(REPO_ROOT / "ds_config.json", "w", encoding="utf-8") as f:
json.dump(margs.__dict__, f, indent=2)
dev = torch.device(args.device)
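    # Build the model under the target device context; device=meta creates
    # shape-only tensors without allocating real storage.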
with torch.device(dev):
from transformers import set_seed
set_seed(42)
model = Transformer(margs)
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params}")
for k, v in model.named_modules():
if k.count('.') <= 3:
n_params_k = sum(p.numel() for p in v.parameters())
print(k, f"{n_params_k} {n_params_k / n_params:.2%}")
if dev.type != "meta":
model = model.to(dev)
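    # Keep norm weights and the output head in bf16 regardless of the model dtype.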
for k, v in model.named_parameters():
if 'norm.weight' in k or k == "head.weight":
v.data = v.data.to(torch.bfloat16)
sd_inference = without_mtp_tied_aliases(model.state_dict())
index_path = args.index_json
if args.output_dir is not None and index_path is None:
index_path = REPO_ROOT / "model.safetensors.index.json"
if index_path is not None:
missing, extra = validate_against_index(
set(sd_inference.keys()), index_path)
if missing or extra:
msg = f"index 对比: missing={len(missing)} extra={len(extra)}"
if missing and len(missing) <= 20:
msg += f"\n missing 样例: {sorted(missing)[:20]}"
elif missing:
msg += f"\n missing 样例: {sorted(missing)[:5]} ..."
if extra and len(extra) <= 20:
msg += f"\n extra 样例: {sorted(extra)[:20]}"
elif extra:
msg += f"\n extra 样例: {sorted(extra)[:5]} ..."
if args.strict_index:
raise SystemExit(msg)
print(msg, file=sys.stderr)
if args.output_dir is not None:
if dev.type == "meta":
raise SystemExit("device=meta 时无法写 safetensors,请改用 cpu/cuda")
wm = _load_index_weight_map(index_path) # type: ignore[arg-type]
save_sharded_index_style(
sd_inference,
wm,
args.output_dir,
)
written = sum(1 for k in wm if k in sd_inference)
n_shards = len({wm[k] for k in wm if k in sd_inference})
print(
f"Wrote {written} tensors in {n_shards} shard files under {args.output_dir}"
)
return
sd = optional_prefix_keys(sd_inference, args.prefix or None)
out = args.output
out.parent.mkdir(parents=True, exist_ok=True)
suffix = out.suffix.lower()
for k, v in sd.items():
print(k, v.shape, v.dtype)
if suffix == ".safetensors":
from safetensors.torch import save_file
        # meta tensors cannot be written to safetensors
        if dev.type == "meta":
            raise SystemExit(
                "cannot write safetensors with device=meta; use cpu/cuda or a .pt output instead")
save_file(sd, str(out))
    else:
        torch.save(sd, out)
print(f"Wrote {len(sd)} tensors to {out}")
if __name__ == "__main__":
main()