#!/usr/bin/env python3
"""Build the Transformer from inference/model.py out of a HuggingFace-style
config.json and export its state_dict. Key names match the Metadata column in
the repo's model.safetensors.index.json / info.html (e.g. embed.weight,
layers.0.attn.wq_a.weight, layers.0.ffn.shared_experts.w1.weight), with no
transformers "model." prefix.

Sharded saving: with --output-dir, tensors are written to the
model-XXXX-of-YYYY.safetensors shards named by the weight_map of --index-json
(default: model.safetensors.index.json in the repo root), and a matching
model.safetensors.index.json is generated in the same directory (key names and
shard file names identical to the reference index).

Usage examples:
    python export_state_dict_from_config.py --config config.json --output /tmp/out.safetensors
    python export_state_dict_from_config.py --config config.json --output-dir /tmp/shards

Allocating RoPE buffers for the full max_position_embeddings in every layer
takes a lot of memory, so max_seq_len is capped at 65536 by default; override
it with --max-seq-len. The state_dict excludes persistent=False buffers, which
keeps the key set aligned with the official shards.
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

REPO_ROOT = Path(__file__).resolve().parent
INFERENCE_DIR = REPO_ROOT / "inference"
sys.path.insert(0, str(INFERENCE_DIR))

from model import ModelArgs, Transformer  # noqa: E402


def without_mtp_tied_aliases(sd: dict[str, Any]) -> dict[str, Any]:
    """Drop MTP keys that alias the main embed/head parameters; safetensors
    cannot save the same storage twice."""
    out: dict[str, Any] = {}
    for k, v in sd.items():
        parts = k.split(".")
        if len(parts) >= 4 and parts[0] == "mtp" and parts[2] in {"embed", "head"}:
            continue
        out[k] = v
    return out


def _load_index_weight_map(index_path: Path) -> dict[str, str]:
    with open(index_path, encoding="utf-8") as f:
        idx = json.load(f)
    wm = idx.get("weight_map")
    if not isinstance(wm, dict):
        raise SystemExit(f"{index_path} has no weight_map")
    return dict(wm)


def save_sharded_index_style(
    sd_inference: dict[str, Any],
    weight_map: dict[str, str],
    out_dir: Path,
) -> None:
    """Write .safetensors shards named by weight_map and emit a
    model.safetensors.index.json next to them."""
    from safetensors.torch import save_file

    out_dir.mkdir(parents=True, exist_ok=True)
    shard_to_tensors: dict[str, dict[str, Any]] = defaultdict(dict)
    new_weight_map: dict[str, str] = {}
    for inf_key, shard_file in weight_map.items():
        if inf_key not in sd_inference:
            continue
        shard_to_tensors[shard_file][inf_key] = sd_inference[inf_key]
        new_weight_map[inf_key] = shard_file
    if not new_weight_map:
        raise SystemExit("weight_map has no keys in common with the current state_dict; nothing to shard")
    total_size = 0
    for shard_file, part in shard_to_tensors.items():
        for t in part.values():
            total_size += int(t.numel()) * int(t.element_size())
        save_file(part, str(out_dir / shard_file))
    index_out = {
        "metadata": {"total_size": total_size},
        "weight_map": new_weight_map,
    }
    with open(out_dir / "model.safetensors.index.json", "w", encoding="utf-8") as f:
        json.dump(index_out, f, indent=2)
        f.write("\n")
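

# For reference, the emitted model.safetensors.index.json mirrors the structure of the
# reference index; the shard name and count below are illustrative, not taken from an
# actual checkpoint:
#
#   {
#     "metadata": {"total_size": 1234567890},
#     "weight_map": {
#       "embed.weight": "model-0001-of-0064.safetensors",
#       "layers.0.attn.wq_a.weight": "model-0001-of-0064.safetensors",
#       ...
#     }
#   }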


def _is_hf_deepseek_config(cfg: dict[str, Any]) -> bool:
    return cfg.get("model_type") == "deepseek_v4" or "num_hidden_layers" in cfg


def hf_config_to_model_args(
    cfg: dict[str, Any],
    *,
    max_batch_size: int,
    max_seq_len: int,
) -> ModelArgs:
    """Convert the HuggingFace config.json from the repo root into inference.ModelArgs."""
    rope = cfg.get("rope_scaling") or {}
    qc = cfg.get("quantization_config") or {}
    use_fp8 = qc.get("quant_method") == "fp8" or str(qc.get("fmt", "")).lower() in (
        "e4m3",
        "e4m3fn",
        "fp8",
    )
    n_layers = int(cfg["num_hidden_layers"])
    ratios = list(cfg.get("compress_ratios") or ())
    compress_ratios = tuple(int(x) for x in ratios)
    return ModelArgs(
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        dtype="fp8" if use_fp8 else "bf16",
        scale_fmt=qc.get("scale_fmt") or cfg.get("scale_fmt") or "ue8m0",
        expert_dtype=cfg.get("expert_dtype"),
        scale_dtype=cfg.get("scale_dtype") or ("fp8" if use_fp8 else "fp32"),
        vocab_size=int(cfg["vocab_size"]),
        dim=int(cfg["hidden_size"]),
        moe_inter_dim=int(cfg["moe_intermediate_size"]),
        n_layers=n_layers,
        n_hash_layers=int(cfg.get("num_hash_layers", 0)),
        n_mtp_layers=int(cfg.get("num_nextn_predict_layers", 0)),
        n_heads=int(cfg["num_attention_heads"]),
        n_routed_experts=int(cfg["n_routed_experts"]),
        n_shared_experts=int(cfg.get("n_shared_experts", 1)),
        n_activated_experts=int(cfg["num_experts_per_tok"]),
        score_func=str(cfg.get("scoring_func", cfg.get("score_func", "sqrtsoftplus"))),
        route_scale=float(cfg.get("routed_scaling_factor", cfg.get("route_scale", 1.0))),
        swiglu_limit=float(cfg.get("swiglu_limit", 0.0)),
        q_lora_rank=int(cfg["q_lora_rank"]),
        head_dim=int(cfg["head_dim"]),
        rope_head_dim=int(cfg["qk_rope_head_dim"]),
        norm_eps=float(cfg.get("rms_norm_eps", cfg.get("norm_eps", 1e-6))),
        o_groups=int(cfg["o_groups"]),
        o_lora_rank=int(cfg["o_lora_rank"]),
        window_size=int(cfg["sliding_window"]),
        compress_ratios=compress_ratios,
        compress_rope_theta=float(cfg.get("compress_rope_theta", 160000.0)),
        original_seq_len=int(
            rope.get(
                "original_max_position_embeddings",
                cfg.get("original_seq_len", 0),
            )
        ),
        rope_theta=float(cfg.get("rope_theta", 10000.0)),
        rope_factor=float(rope.get("factor", cfg.get("rope_factor", 1.0))),
        beta_fast=int(rope.get("beta_fast", cfg.get("beta_fast", 32))),
        beta_slow=int(rope.get("beta_slow", cfg.get("beta_slow", 1))),
        index_n_heads=int(cfg["index_n_heads"]),
        index_head_dim=int(cfg["index_head_dim"]),
        index_topk=int(cfg["index_topk"]),
        hc_mult=int(cfg["hc_mult"]),
        hc_sinkhorn_iters=int(cfg["hc_sinkhorn_iters"]),
        hc_eps=float(cfg.get("hc_eps", 1e-6)),
    )


def load_model_args(
    config_path: Path,
    *,
    max_batch_size: int,
    max_seq_len: int | None,
    cap_seq_len: int,
) -> ModelArgs:
    with open(config_path, encoding="utf-8") as f:
        raw = json.load(f)
    if _is_hf_deepseek_config(raw):
        cfg_max = int(raw.get("max_position_embeddings", cap_seq_len))
        mseq = max_seq_len if max_seq_len is not None else min(cfg_max, cap_seq_len)
        return hf_config_to_model_args(
            raw, max_batch_size=max_batch_size, max_seq_len=mseq
        )
    return ModelArgs(**raw)


def optional_prefix_keys(
    sd: dict[str, Any], prefix: str | None
) -> dict[str, Any]:
    if not prefix:
        return sd
    p = prefix.rstrip(".") + "."
    return {p + k: v for k, v in sd.items()}
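

# Illustrative example of the --prefix behavior (not executed):
#   optional_prefix_keys({"embed.weight": t}, "model.") -> {"model.embed.weight": t}
# With the default empty prefix the state_dict is returned unchanged, matching the index.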


def validate_against_index(
    keys: set[str], index_path: Path
) -> tuple[set[str], set[str]]:
    with open(index_path, encoding="utf-8") as f:
        idx = json.load(f)
    ref = set(idx.get("weight_map", {}).keys())
    missing = ref - keys
    extra = keys - ref
    return missing, extra


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument(
        "--config",
        type=Path,
        default=REPO_ROOT / "config.json",
        help="config.json (HF or inference format)",
    )
    out_group = ap.add_mutually_exclusive_group(required=True)
    out_group.add_argument(
        "--output",
        type=Path,
        help="write a single .safetensors or .pt / .pth file",
    )
    out_group.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="shard into this directory following the index weight_map (safetensors + model.safetensors.index.json only)",
    )
    ap.add_argument("--device", type=str, default="cpu", help="cpu / cuda / meta")
    ap.add_argument("--max-batch-size", type=int, default=4)
    ap.add_argument(
        "--max-seq-len",
        type=int,
        default=None,
        help="override ModelArgs.max_seq_len; default min(max_position_embeddings, --cap-seq-len)",
    )
    ap.add_argument(
        "--cap-seq-len",
        type=int,
        default=65536,
        help="default max_seq_len cap for HF configs, avoids allocating huge per-layer RoPE buffers",
    )
    ap.add_argument(
        "--prefix",
        type=str,
        default="",
        help='prefix prepended to output key names (e.g. "model."); default is no prefix, matching the index',
    )
    ap.add_argument(
        "--index-json",
        type=Path,
        default=None,
        help="index used for validation or shard layout; with --output-dir defaults to model.safetensors.index.json in the repo root",
    )
    ap.add_argument(
        "--strict-index",
        action="store_true",
        help="with --index-json, require the key sets to match exactly",
    )
    args = ap.parse_args()

    import torch

    torch.set_default_dtype(torch.bfloat16)
    margs = load_model_args(
        args.config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        cap_seq_len=args.cap_seq_len,
    )
    # Dump the resolved ModelArgs for reference.
    with open(REPO_ROOT / "ds_config.json", "w", encoding="utf-8") as f:
        json.dump(margs.__dict__, f, indent=2)

    dev = torch.device(args.device)
    with torch.device(dev):
        from transformers import set_seed

        set_seed(42)
        model = Transformer(margs)
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters: {n_params}")
        for k, v in model.named_modules():
            if k.count(".") <= 3:
                n_params_k = sum(p.numel() for p in v.parameters())
                print(k, f"{n_params_k} {n_params_k / n_params:.2%}")
    if dev.type != "meta":
        model = model.to(dev)
    # Keep norm and head weights in bf16 regardless of the checkpoint dtype.
    for k, v in model.named_parameters():
        if "norm.weight" in k or k == "head.weight":
            v.data = v.data.to(torch.bfloat16)
    sd_inference = without_mtp_tied_aliases(model.state_dict())

    index_path = args.index_json
    if args.output_dir is not None and index_path is None:
        index_path = REPO_ROOT / "model.safetensors.index.json"
    if index_path is not None:
        missing, extra = validate_against_index(set(sd_inference.keys()), index_path)
        if missing or extra:
            msg = f"index comparison: missing={len(missing)} extra={len(extra)}"
            if missing and len(missing) <= 20:
                msg += f"\n missing sample: {sorted(missing)[:20]}"
            elif missing:
                msg += f"\n missing sample: {sorted(missing)[:5]} ..."
            if extra and len(extra) <= 20:
                msg += f"\n extra sample: {sorted(extra)[:20]}"
            elif extra:
                msg += f"\n extra sample: {sorted(extra)[:5]} ..."
            if args.strict_index:
                raise SystemExit(msg)
            print(msg, file=sys.stderr)

    if args.output_dir is not None:
        if dev.type == "meta":
            raise SystemExit("cannot write safetensors from device=meta; use cpu/cuda instead")
        wm = _load_index_weight_map(index_path)  # type: ignore[arg-type]
        save_sharded_index_style(
            sd_inference,
            wm,
            args.output_dir,
        )
        written = sum(1 for k in wm if k in sd_inference)
        n_shards = len({wm[k] for k in wm if k in sd_inference})
        print(
            f"Wrote {written} tensors in {n_shards} shard files under {args.output_dir}"
        )
        return

    sd = optional_prefix_keys(sd_inference, args.prefix or None)
    out = args.output
    out.parent.mkdir(parents=True, exist_ok=True)
    suffix = out.suffix.lower()
    for k, v in sd.items():
        print(k, v.shape, v.dtype)
    if suffix == ".safetensors":
        from safetensors.torch import save_file

        # meta tensors cannot be serialized to safetensors
        if dev.type == "meta":
            raise SystemExit(
                "cannot write safetensors from device=meta; use cpu/cuda or write a .pt file"
            )
        save_file(sd, str(out))
    else:
        torch.save(sd, out)
    print(f"Wrote {len(sd)} tensors to {out}")


if __name__ == "__main__":
    main()
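

# Round-trip sketch (not executed; paths reuse the /tmp/shards example from the module
# docstring): after exporting with --output-dir, a single tensor can be read back via
# the generated index and safetensors:
#
#   from safetensors.torch import load_file
#   with open("/tmp/shards/model.safetensors.index.json", encoding="utf-8") as f:
#       weight_map = json.load(f)["weight_map"]
#   shard = load_file("/tmp/shards/" + weight_map["embed.weight"])
#   print(shard["embed.weight"].shape)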