#!/usr/bin/env python3
"""
从 HuggingFace 风格的 config.json 构建 inference/model.py 中的 Transformer,
导出 state_dict,键名与仓库内 model.safetensors.index.json / info.html 中的 Metadata 列一致
(例如 embed.weight、layers.0.attn.wq_a.weight、layers.0.ffn.shared_experts.w1.weight,
无 transformers 的 model. 前缀)。
分片保存:指定 --output-dir 时,按 --index-json(默认仓库根目录 model.safetensors.index.json)
的 weight_map 将张量写入对应 model-XXXX-of-YYYY.safetensors,并生成同目录下的
model.safetensors.index.json(键名与参考索引一致,分片文件名与参考索引一致)。
用法示例:
python export_state_dict_from_config.py --config config.json --output /tmp/out.safetensors
python export_state_dict_from_config.py --config config.json --output-dir /tmp/shards
完整 max_position_embeddings 会在每层分配 RoPE 缓冲,内存很大;默认将 max_seq_len 限制为
65536,可用 --max-seq-len 覆盖。state_dict 不包含 persistent=False 的 buffer,与官方分片键集合对齐。
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent
INFERENCE_DIR = REPO_ROOT / "inference"
sys.path.insert(0, str(INFERENCE_DIR))
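# ModelArgs / Transformer come from the inference/model.py vendored in this
# repo, not from transformers.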
from model import ModelArgs, Transformer # noqa: E402
def without_mtp_tied_aliases(sd: dict[str, Any]) -> dict[str, Any]:
"""MTP 的 embed/head 与主模块共享参数,safetensors 不能重复保存这些别名。"""
out: dict[str, Any] = {}
for k, v in sd.items():
parts = k.split(".")
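        # Tied aliases look like "mtp.<i>.embed.weight" / "mtp.<i>.head.weight"
        # and share storage with the top-level embed/head tensors.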
if len(parts) >= 4 and parts[0] == "mtp" and parts[2] in {"embed", "head"}:
continue
out[k] = v
return out
def _load_index_weight_map(index_path: Path) -> dict[str, str]:
with open(index_path, encoding="utf-8") as f:
idx = json.load(f)
wm = idx.get("weight_map")
if not isinstance(wm, dict):
raise SystemExit(f"{index_path} 缺少 weight_map")
return dict(wm)
def save_sharded_index_style(
sd_inference: dict[str, Any],
weight_map: dict[str, str],
out_dir: Path,
) -> None:
"""按 weight_map 中的分片文件名写入多个 .safetensors,并写 model.safetensors.index.json。"""
from safetensors.torch import save_file
out_dir.mkdir(parents=True, exist_ok=True)
shard_to_tensors: dict[str, dict[str, Any]] = defaultdict(dict)
new_weight_map: dict[str, str] = {}
for inf_key, shard_file in weight_map.items():
if inf_key not in sd_inference:
continue
shard_to_tensors[shard_file][inf_key] = sd_inference[inf_key]
new_weight_map[inf_key] = shard_file
if not new_weight_map:
raise SystemExit("weight_map 与当前 state_dict 无交集,无法分片写出")
total_size = 0
for shard_file, part in shard_to_tensors.items():
for t in part.values():
total_size += int(t.numel()) * int(t.element_size())
save_file(part, str(out_dir / shard_file))
index_out = {
"metadata": {"total_size": total_size},
"weight_map": new_weight_map,
}
with open(out_dir / "model.safetensors.index.json", "w", encoding="utf-8") as f:
json.dump(index_out, f, indent=2)
f.write("\n")
def _is_hf_deepseek_config(cfg: dict[str, Any]) -> bool:
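    # Heuristic: a HF config carries num_hidden_layers (and usually model_type),
    # while an inference-style ModelArgs dump uses n_layers instead.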
return cfg.get("model_type") == "deepseek_v4" or "num_hidden_layers" in cfg
def hf_config_to_model_args(
cfg: dict[str, Any],
*,
max_batch_size: int,
max_seq_len: int,
) -> ModelArgs:
"""将仓库根目录下 HuggingFace 的 config.json 转为 inference.ModelArgs。"""
rope = cfg.get("rope_scaling") or {}
qc = cfg.get("quantization_config") or {}
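    # Treat the checkpoint as fp8 when quantization_config names the method
    # or uses an e4m3-style fmt string.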
use_fp8 = qc.get("quant_method") == "fp8" or str(qc.get("fmt", "")).lower() in (
"e4m3",
"e4m3fn",
"fp8",
)
n_layers = int(cfg["num_hidden_layers"])
ratios = list(cfg.get("compress_ratios") or ())
compress_ratios = tuple(int(x) for x in ratios)
return ModelArgs(
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
dtype="fp8" if use_fp8 else "bf16",
scale_fmt=qc.get("scale_fmt") or cfg.get("scale_fmt") or "ue8m0",
expert_dtype=cfg.get("expert_dtype"),
scale_dtype=cfg.get("scale_dtype") or ("fp8" if use_fp8 else "fp32"),
vocab_size=int(cfg["vocab_size"]),
dim=int(cfg["hidden_size"]),
moe_inter_dim=int(cfg["moe_intermediate_size"]),
n_layers=n_layers,
n_hash_layers=int(cfg.get("num_hash_layers", 0)),
n_mtp_layers=int(cfg.get("num_nextn_predict_layers", 0)),
n_heads=int(cfg["num_attention_heads"]),
n_routed_experts=int(cfg["n_routed_experts"]),
n_shared_experts=int(cfg.get("n_shared_experts", 1)),
n_activated_experts=int(cfg["num_experts_per_tok"]),
score_func=str(cfg.get("scoring_func", cfg.get(
"score_func", "sqrtsoftplus"))),
route_scale=float(cfg.get("routed_scaling_factor",
cfg.get("route_scale", 1.0))),
swiglu_limit=float(cfg.get("swiglu_limit", 0.0)),
q_lora_rank=int(cfg["q_lora_rank"]),
head_dim=int(cfg["head_dim"]),
rope_head_dim=int(cfg["qk_rope_head_dim"]),
norm_eps=float(cfg.get("rms_norm_eps", cfg.get("norm_eps", 1e-6))),
o_groups=int(cfg["o_groups"]),
o_lora_rank=int(cfg["o_lora_rank"]),
window_size=int(cfg["sliding_window"]),
compress_ratios=compress_ratios,
compress_rope_theta=float(cfg.get("compress_rope_theta", 160000.0)),
original_seq_len=int(
rope.get(
"original_max_position_embeddings",
cfg.get("original_seq_len", 0),
)
),
rope_theta=float(cfg.get("rope_theta", 10000.0)),
rope_factor=float(rope.get("factor", cfg.get("rope_factor", 1.0))),
beta_fast=int(rope.get("beta_fast", cfg.get("beta_fast", 32))),
beta_slow=int(rope.get("beta_slow", cfg.get("beta_slow", 1))),
index_n_heads=int(cfg["index_n_heads"]),
index_head_dim=int(cfg["index_head_dim"]),
index_topk=int(cfg["index_topk"]),
hc_mult=int(cfg["hc_mult"]),
hc_sinkhorn_iters=int(cfg["hc_sinkhorn_iters"]),
hc_eps=float(cfg.get("hc_eps", 1e-6)),
)
def load_model_args(
config_path: Path,
*,
max_batch_size: int,
max_seq_len: int | None,
cap_seq_len: int,
) -> ModelArgs:
with open(config_path, encoding="utf-8") as f:
raw = json.load(f)
if _is_hf_deepseek_config(raw):
cfg_max = int(raw.get("max_position_embeddings", cap_seq_len))
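        # Cap the default sequence length so per-layer RoPE buffers stay modest
        # (see module docstring).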
mseq = max_seq_len if max_seq_len is not None else min(
cfg_max, cap_seq_len)
return hf_config_to_model_args(
raw, max_batch_size=max_batch_size, max_seq_len=mseq
)
return ModelArgs(**raw)
def optional_prefix_keys(
sd: dict[str, Any], prefix: str | None
) -> dict[str, Any]:
if not prefix:
return sd
p = prefix.rstrip(".") + "."
return {p + k: v for k, v in sd.items()}
def validate_against_index(
keys: set[str], index_path: Path
) -> tuple[set[str], set[str]]:
with open(index_path, encoding="utf-8") as f:
idx = json.load(f)
ref = set(idx.get("weight_map", {}).keys())
missing = ref - keys
extra = keys - ref
return missing, extra
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--config",
type=Path,
default=REPO_ROOT / "config.json",
help="config.json(HF 或 inference 格式)",
)
out_group = ap.add_mutually_exclusive_group(required=True)
out_group.add_argument(
"--output",
type=Path,
help="输出单个 .safetensors 或 .pt / .pth",
)
out_group.add_argument(
"--output-dir",
type=Path,
default=None,
help="按 index 的 weight_map 分片写入该目录(仅 safetensors + model.safetensors.index.json)",
)
ap.add_argument("--device", type=str, default="cpu",
help="cpu / cuda / meta")
ap.add_argument("--max-batch-size", type=int, default=4)
ap.add_argument(
"--max-seq-len",
type=int,
default=None,
help="覆盖 ModelArgs.max_seq_len;默认 min(max_position_embeddings, --cap-seq-len)",
)
ap.add_argument(
"--cap-seq-len",
type=int,
default=65536,
help="HF 配置下默认 max_seq_len 上限,避免每层分配过大 RoPE 缓冲",
)
ap.add_argument(
"--prefix",
type=str,
default="",
        help='prepend this prefix to output key names (e.g. "model."); default is no prefix, matching the index',
)
ap.add_argument(
"--index-json",
type=Path,
default=None,
help="校验或分片布局:默认在 --output-dir 时为仓库根 model.safetensors.index.json",
)
ap.add_argument("--strict-index", action="store_true",
help="与 --index-json 联用,要求键完全一致")
args = ap.parse_args()
import torch
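    # Construct parameters in bf16 by default; quantized weights, if any, are
    # governed by ModelArgs.dtype (set to "fp8" for fp8 checkpoints above).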
torch.set_default_dtype(torch.bfloat16)
margs = load_model_args(
args.config,
max_batch_size=args.max_batch_size,
max_seq_len=args.max_seq_len,
cap_seq_len=args.cap_seq_len,
)
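    # Keep a dump of the resolved inference-style args next to the repo for inspection.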
with open(REPO_ROOT / "ds_config.json", "w", encoding="utf-8") as f:
json.dump(margs.__dict__, f, indent=2)
dev = torch.device(args.device)
with torch.device(dev):
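        # Seed before construction so the randomly initialized weights are reproducible.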
from transformers import set_seed
set_seed(42)
model = Transformer(margs)
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params}")
for k, v in model.named_modules():
if k.count('.') <= 3:
n_params_k = sum(p.numel() for p in v.parameters())
print(k, f"{n_params_k} {n_params_k / n_params:.2%}")
if dev.type != "meta":
model = model.to(dev)
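    # Keep RMSNorm and output-head weights in bf16 regardless of the
    # (possibly fp8) dtype of the quantized weights.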
for k, v in model.named_parameters():
if 'norm.weight' in k or k == "head.weight":
v.data = v.data.to(torch.bfloat16)
sd_inference = without_mtp_tied_aliases(model.state_dict())
index_path = args.index_json
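    # With --output-dir and no explicit --index-json, use the reference index
    # in the repo root for the shard layout.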
if args.output_dir is not None and index_path is None:
index_path = REPO_ROOT / "model.safetensors.index.json"
if index_path is not None:
missing, extra = validate_against_index(
set(sd_inference.keys()), index_path)
if missing or extra:
msg = f"index 对比: missing={len(missing)} extra={len(extra)}"
if missing and len(missing) <= 20:
msg += f"\n missing 样例: {sorted(missing)[:20]}"
elif missing:
msg += f"\n missing 样例: {sorted(missing)[:5]} ..."
if extra and len(extra) <= 20:
msg += f"\n extra 样例: {sorted(extra)[:20]}"
elif extra:
msg += f"\n extra 样例: {sorted(extra)[:5]} ..."
if args.strict_index:
raise SystemExit(msg)
print(msg, file=sys.stderr)
if args.output_dir is not None:
if dev.type == "meta":
raise SystemExit("device=meta 时无法写 safetensors,请改用 cpu/cuda")
wm = _load_index_weight_map(index_path) # type: ignore[arg-type]
save_sharded_index_style(
sd_inference,
wm,
args.output_dir,
)
written = sum(1 for k in wm if k in sd_inference)
n_shards = len({wm[k] for k in wm if k in sd_inference})
print(
f"Wrote {written} tensors in {n_shards} shard files under {args.output_dir}"
)
return
sd = optional_prefix_keys(sd_inference, args.prefix or None)
out = args.output
out.parent.mkdir(parents=True, exist_ok=True)
suffix = out.suffix.lower()
for k, v in sd.items():
print(k, v.shape, v.dtype)
if suffix == ".safetensors":
from safetensors.torch import save_file
        # meta tensors carry no data and cannot be written to safetensors
        if dev.type == "meta":
            raise SystemExit(
                "cannot write safetensors from device=meta; use cpu/cuda or a .pt output")
save_file(sd, str(out))
    else:
        torch.save(sd, out)
print(f"Wrote {len(sd)} tensors to {out}")
if __name__ == "__main__":
main()