"""
Build the Transformer from inference/model.py out of a HuggingFace-style
config.json and export its state_dict. Key names match the Metadata column of
the repo's model.safetensors.index.json / info.html (e.g. embed.weight,
layers.0.attn.wq_a.weight, layers.0.ffn.shared_experts.w1.weight; no
transformers-style "model." prefix).

Sharded saving: with --output-dir, tensors are written into the matching
model-XXXX-of-YYYY.safetensors shards according to the weight_map of
--index-json (default: model.safetensors.index.json in the repo root), and a
model.safetensors.index.json is generated in the same directory (key names and
shard file names identical to the reference index).

Usage examples:
    python export_state_dict_from_config.py --config config.json --output /tmp/out.safetensors
    python export_state_dict_from_config.py --config config.json --output-dir /tmp/shards

Allocating RoPE buffers for the full max_position_embeddings in every layer
takes a lot of memory, so max_seq_len is capped at 65536 by default; override
with --max-seq-len. The state_dict excludes persistent=False buffers, keeping
the key set aligned with the official shards.
"""
|
|
from __future__ import annotations
|
|
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
|
|
REPO_ROOT = Path(__file__).resolve().parent
INFERENCE_DIR = REPO_ROOT / "inference"
# Make inference/model.py importable as a top-level "model" module.
sys.path.insert(0, str(INFERENCE_DIR))
|
|
from model import ModelArgs, Transformer
|
|
|
|
def without_mtp_tied_aliases(sd: dict[str, Any]) -> dict[str, Any]:
    """Drop the MTP embed/head entries: they share parameters with the main
    module, and safetensors refuses to save such duplicated aliases."""
    out: dict[str, Any] = {}
    for k, v in sd.items():
        parts = k.split(".")
        if len(parts) >= 4 and parts[0] == "mtp" and parts[2] in {"embed", "head"}:
            continue
        out[k] = v
    return out
|
|
|
|
def _load_index_weight_map(index_path: Path) -> dict[str, str]:
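    # Standard HF sharded-checkpoint index layout, roughly:
    #   {"metadata": {"total_size": ...},
    #    "weight_map": {"embed.weight": "model-XXXX-of-YYYY.safetensors", ...}}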
    with open(index_path, encoding="utf-8") as f:
        idx = json.load(f)
    wm = idx.get("weight_map")
    if not isinstance(wm, dict):
        raise SystemExit(f"{index_path} is missing weight_map")
    return dict(wm)
|
|
|
|
def save_sharded_index_style(
    sd_inference: dict[str, Any],
    weight_map: dict[str, str],
    out_dir: Path,
) -> None:
    """Write the shards named by weight_map as .safetensors files and emit a
    matching model.safetensors.index.json."""
    from safetensors.torch import save_file

    out_dir.mkdir(parents=True, exist_ok=True)
    shard_to_tensors: dict[str, dict[str, Any]] = defaultdict(dict)
    new_weight_map: dict[str, str] = {}
    for inf_key, shard_file in weight_map.items():
        if inf_key not in sd_inference:
            continue
        shard_to_tensors[shard_file][inf_key] = sd_inference[inf_key]
        new_weight_map[inf_key] = shard_file
|
|
    if not new_weight_map:
        raise SystemExit("weight_map shares no keys with the current state_dict; nothing to shard")
|
|
    total_size = 0
    for shard_file, part in shard_to_tensors.items():
        for t in part.values():
            total_size += int(t.numel()) * int(t.element_size())
        save_file(part, str(out_dir / shard_file))
|
|
    index_out = {
        "metadata": {"total_size": total_size},
        "weight_map": new_weight_map,
    }
    with open(out_dir / "model.safetensors.index.json", "w", encoding="utf-8") as f:
        json.dump(index_out, f, indent=2)
        f.write("\n")
|
|
|
|
def _is_hf_deepseek_config(cfg: dict[str, Any]) -> bool:
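    # Heuristic: HF configs carry num_hidden_layers, while an inference-style
    # ModelArgs JSON uses n_layers, so its presence alone marks the HF format.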
| return cfg.get("model_type") == "deepseek_v4" or "num_hidden_layers" in cfg |
|
|
|
|
def hf_config_to_model_args(
    cfg: dict[str, Any],
    *,
    max_batch_size: int,
    max_seq_len: int,
) -> ModelArgs:
    """Convert the HuggingFace config.json from the repo root into inference.ModelArgs."""
    rope = cfg.get("rope_scaling") or {}
    qc = cfg.get("quantization_config") or {}
|
|
    use_fp8 = qc.get("quant_method") == "fp8" or str(qc.get("fmt", "")).lower() in (
        "e4m3",
        "e4m3fn",
        "fp8",
    )
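    # Recognized quantization_config shapes: {"quant_method": "fp8", ...} or
    # {"fmt": "e4m3" / "e4m3fn" / "fp8", ...}; anything else exports as bf16.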
|
|
    n_layers = int(cfg["num_hidden_layers"])
    ratios = list(cfg.get("compress_ratios") or ())
    compress_ratios = tuple(int(x) for x in ratios)
|
|
    return ModelArgs(
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        dtype="fp8" if use_fp8 else "bf16",
        scale_fmt=qc.get("scale_fmt") or cfg.get("scale_fmt") or "ue8m0",
        expert_dtype=cfg.get("expert_dtype"),
        scale_dtype=cfg.get("scale_dtype") or ("fp8" if use_fp8 else "fp32"),
        vocab_size=int(cfg["vocab_size"]),
        dim=int(cfg["hidden_size"]),
        moe_inter_dim=int(cfg["moe_intermediate_size"]),
        n_layers=n_layers,
        n_hash_layers=int(cfg.get("num_hash_layers", 0)),
        n_mtp_layers=int(cfg.get("num_nextn_predict_layers", 0)),
        n_heads=int(cfg["num_attention_heads"]),
        n_routed_experts=int(cfg["n_routed_experts"]),
        n_shared_experts=int(cfg.get("n_shared_experts", 1)),
        n_activated_experts=int(cfg["num_experts_per_tok"]),
        score_func=str(cfg.get("scoring_func", cfg.get("score_func", "sqrtsoftplus"))),
        route_scale=float(cfg.get("routed_scaling_factor", cfg.get("route_scale", 1.0))),
        swiglu_limit=float(cfg.get("swiglu_limit", 0.0)),
        q_lora_rank=int(cfg["q_lora_rank"]),
        head_dim=int(cfg["head_dim"]),
        rope_head_dim=int(cfg["qk_rope_head_dim"]),
        norm_eps=float(cfg.get("rms_norm_eps", cfg.get("norm_eps", 1e-6))),
        o_groups=int(cfg["o_groups"]),
        o_lora_rank=int(cfg["o_lora_rank"]),
        window_size=int(cfg["sliding_window"]),
        compress_ratios=compress_ratios,
        compress_rope_theta=float(cfg.get("compress_rope_theta", 160000.0)),
        original_seq_len=int(
            rope.get(
                "original_max_position_embeddings",
                cfg.get("original_seq_len", 0),
            )
        ),
        rope_theta=float(cfg.get("rope_theta", 10000.0)),
        rope_factor=float(rope.get("factor", cfg.get("rope_factor", 1.0))),
        beta_fast=int(rope.get("beta_fast", cfg.get("beta_fast", 32))),
        beta_slow=int(rope.get("beta_slow", cfg.get("beta_slow", 1))),
        index_n_heads=int(cfg["index_n_heads"]),
        index_head_dim=int(cfg["index_head_dim"]),
        index_topk=int(cfg["index_topk"]),
        hc_mult=int(cfg["hc_mult"]),
        hc_sinkhorn_iters=int(cfg["hc_sinkhorn_iters"]),
        hc_eps=float(cfg.get("hc_eps", 1e-6)),
    )
|
|
|
|
def load_model_args(
    config_path: Path,
    *,
    max_batch_size: int,
    max_seq_len: int | None,
    cap_seq_len: int,
) -> ModelArgs:
    with open(config_path, encoding="utf-8") as f:
        raw = json.load(f)
|
|
    if _is_hf_deepseek_config(raw):
        cfg_max = int(raw.get("max_position_embeddings", cap_seq_len))
        mseq = max_seq_len if max_seq_len is not None else min(cfg_max, cap_seq_len)
        return hf_config_to_model_args(
            raw, max_batch_size=max_batch_size, max_seq_len=mseq
        )
    return ModelArgs(**raw)
|
|
|
|
def optional_prefix_keys(
    sd: dict[str, Any], prefix: str | None
) -> dict[str, Any]:
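    """Prepend prefix (normalized to one trailing ".") to every key, e.g.
    "model." maps "embed.weight" to "model.embed.weight"; no-op when empty."""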
    if not prefix:
        return sd
    p = prefix.rstrip(".") + "."
    return {p + k: v for k, v in sd.items()}
|
|
|
|
def validate_against_index(
    keys: set[str], index_path: Path
) -> tuple[set[str], set[str]]:
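    """Return (missing, extra): keys the reference index expects but ``keys``
    lacks, and keys in ``keys`` that the index does not know about."""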
    with open(index_path, encoding="utf-8") as f:
        idx = json.load(f)
    ref = set(idx.get("weight_map", {}).keys())
    missing = ref - keys
    extra = keys - ref
    return missing, extra
|
|
|
|
def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument(
        "--config",
        type=Path,
        default=REPO_ROOT / "config.json",
        help="config.json (HF or inference format)",
    )
    out_group = ap.add_mutually_exclusive_group(required=True)
    out_group.add_argument(
        "--output",
        type=Path,
        help="write a single .safetensors or .pt / .pth file",
    )
    out_group.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="shard into this directory per the index weight_map (safetensors shards plus model.safetensors.index.json)",
    )
| ap.add_argument("--device", type=str, default="cpu", |
| help="cpu / cuda / meta") |
| ap.add_argument("--max-batch-size", type=int, default=4) |
| ap.add_argument( |
| "--max-seq-len", |
| type=int, |
| default=None, |
| help="覆盖 ModelArgs.max_seq_len;默认 min(max_position_embeddings, --cap-seq-len)", |
| ) |
| ap.add_argument( |
| "--cap-seq-len", |
| type=int, |
| default=65536, |
| help="HF 配置下默认 max_seq_len 上限,避免每层分配过大 RoPE 缓冲", |
| ) |
    ap.add_argument(
        "--prefix",
        type=str,
        default="",
        help='prepend this prefix to output keys (e.g. "model."); default matches the index, i.e. no prefix',
    )
    ap.add_argument(
        "--index-json",
        type=Path,
        default=None,
        help="index for validation or shard layout; with --output-dir it defaults to the repo-root model.safetensors.index.json",
    )
    ap.add_argument(
        "--strict-index",
        action="store_true",
        help="with --index-json, require the key sets to match exactly",
    )
    args = ap.parse_args()
|
|
    import torch

    torch.set_default_dtype(torch.bfloat16)
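    # From here on, parameters created without an explicit dtype come out bf16.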
|
|
    margs = load_model_args(
        args.config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        cap_seq_len=args.cap_seq_len,
    )
    # Dump the resolved ModelArgs next to this script for reference.
    with open(REPO_ROOT / "ds_config.json", "w", encoding="utf-8") as f:
        json.dump(margs.__dict__, f, indent=2)
|
|
    dev = torch.device(args.device)
    with torch.device(dev):
        from transformers import set_seed

        set_seed(42)
        model = Transformer(margs)
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters: {n_params}")
        # Per-module parameter breakdown for shallow modules (nesting depth <= 3).
        for k, v in model.named_modules():
            if k.count(".") <= 3:
                n_params_k = sum(p.numel() for p in v.parameters())
                print(k, f"{n_params_k} {n_params_k / n_params:.2%}")
|
|
| if dev.type != "meta": |
| model = model.to(dev) |
| for k, v in model.named_parameters(): |
| if 'norm.weight' in k or k == "head.weight": |
| v.data = v.data.to(torch.bfloat16) |
|
|
    sd_inference = without_mtp_tied_aliases(model.state_dict())
    index_path = args.index_json
    if args.output_dir is not None and index_path is None:
        index_path = REPO_ROOT / "model.safetensors.index.json"
|
|
    if index_path is not None:
        missing, extra = validate_against_index(set(sd_inference.keys()), index_path)
        if missing or extra:
            msg = f"index comparison: missing={len(missing)} extra={len(extra)}"
            if missing and len(missing) <= 20:
                msg += f"\n  missing sample: {sorted(missing)[:20]}"
            elif missing:
                msg += f"\n  missing sample: {sorted(missing)[:5]} ..."
            if extra and len(extra) <= 20:
                msg += f"\n  extra sample: {sorted(extra)[:20]}"
            elif extra:
                msg += f"\n  extra sample: {sorted(extra)[:5]} ..."
            if args.strict_index:
                raise SystemExit(msg)
            print(msg, file=sys.stderr)
|
|
    if args.output_dir is not None:
        if dev.type == "meta":
            raise SystemExit("cannot write safetensors with device=meta; use cpu/cuda instead")
        wm = _load_index_weight_map(index_path)
        save_sharded_index_style(
            sd_inference,
            wm,
            args.output_dir,
        )
        written = sum(1 for k in wm if k in sd_inference)
        n_shards = len({wm[k] for k in wm if k in sd_inference})
        print(
            f"Wrote {written} tensors in {n_shards} shard files under {args.output_dir}"
        )
        return
|
|
    sd = optional_prefix_keys(sd_inference, args.prefix or None)
|
|
    out = args.output
    out.parent.mkdir(parents=True, exist_ok=True)
    suffix = out.suffix.lower()
    for k, v in sd.items():
        print(k, v.shape, v.dtype)
    if suffix == ".safetensors":
        from safetensors.torch import save_file

        if dev.type == "meta":
            raise SystemExit(
                "cannot write safetensors with device=meta; use cpu/cuda or a .pt output"
            )
        save_file(sd, str(out))
    else:
        torch.save(sd, out)
    print(f"Wrote {len(sd)} tensors to {out}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|