"""Compare parameter keys, shapes and values across saved checkpoints.

The script loads three files:

* a plain ``pytorch_model.bin`` state dict,
* a DeepSpeed ``mp_rank_00_model_states.pt`` file,
* a ``weights.safetensors`` file,

and then prints
  1. keys of the first file that are missing (with a ``pipe.`` prefix) from
     the DeepSpeed state dict,
  2. the key-set overlap between the DeepSpeed and safetensors state dicts,
  3. shape mismatches on the keys they share,
  4. the mean absolute value difference on the first 20 shared keys.

Two interactive breakpoints are kept for manual inspection; run with the
environment variable ``PYTHONBREAKPOINT=0`` to skip them.
"""
import argparse
import os  # kept from the original script; currently unused

import torch

# Defaults are the paths the script was originally hard-coded with; override
# them on the command line instead of editing the file.
DEFAULT_STATE_PATH = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref_bak/checkpoint-step-1050-epoch-13/pytorch_model.bin"
DEFAULT_MP_PATH = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/pytorch_model/mp_rank_00_model_states.pt"
DEFAULT_SAFE_PATH = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/weights.safetensors"


def missing_prefixed_keys(keys, target, prefix="pipe."):
    """Return the entries of *keys* whose prefixed form is absent from *target*.

    Args:
        keys: Iterable of key strings to look up.
        target: Container supporting ``in`` (e.g. a state dict).
        prefix: String prepended to every key before the lookup.

    Returns:
        A list of the un-prefixed keys that were not found, in input order.
    """
    return [key for key in keys if prefix + key not in target]


def find_shape_mismatches(sd_a, sd_b, keys):
    """Return ``(key, shape_a, shape_b)`` for every key whose shapes differ.

    Args:
        sd_a: First mapping of key -> tensor.
        sd_b: Second mapping of key -> tensor.
        keys: Iterable of keys present in both mappings.

    Returns:
        A list of tuples, in the iteration order of *keys*.
    """
    return [
        (key, sd_a[key].shape, sd_b[key].shape)
        for key in keys
        if sd_a[key].shape != sd_b[key].shape
    ]


def mean_abs_diff(a, b):
    """Return ``mean(|a - b|)`` as a Python float, or ``None`` if shapes differ.

    Both tensors are cast to float32 first: ``mean()`` is not defined for
    integer tensors, and differencing low-precision dtypes directly can hide
    small discrepancies. Mismatched shapes are reported as ``None`` instead of
    raising or silently broadcasting.
    """
    if a.shape != b.shape:
        return None
    return (a.float() - b.float()).abs().mean().item()


def main(argv=None):
    """Run the checkpoint comparison and print the report.

    Args:
        argv: Optional list of command-line arguments; ``None`` means
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Compare keys, shapes and values of saved checkpoints."
    )
    parser.add_argument("--state-path", default=DEFAULT_STATE_PATH,
                        help="path to the plain pytorch_model.bin state dict")
    parser.add_argument("--mp-path", default=DEFAULT_MP_PATH,
                        help="path to the DeepSpeed mp_rank_00_model_states.pt")
    parser.add_argument("--safe-path", default=DEFAULT_SAFE_PATH,
                        help="path to the weights.safetensors file")
    args = parser.parse_args(argv)

    # Imported here so the helpers above stay importable without safetensors.
    from safetensors.torch import load_file

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only point this script at checkpoints you trust.
    path = args.state_path
    print(f"\n=== Loading {path} ===")
    state = torch.load(path, map_location="cpu")
    breakpoint()  # inspect `state`; disable with PYTHONBREAKPOINT=0
    print(f"\nTotal parameters: {len(state)}")

    # ====== 1. Load the full model parameters saved by DeepSpeed ======
    mp_sd = torch.load(args.mp_path, map_location="cpu")

    # The DeepSpeed state dict lives either under mp_sd["module"] or directly
    # in mp_sd; in this project it is usually under "module".
    if "module" in mp_sd:
        mp_sd = mp_sd["module"]

    print(f"[INFO] Loaded DeepSpeed full model: {len(mp_sd)} params")

    # ====== 2. Load the separately saved trainable weights ======
    safe_sd = load_file(args.safe_path)

    print(f"[INFO] Loaded safetensors trainable params: {len(safe_sd)} params")

    # ====== 3. Compare the key sets ======
    mp_keys = set(mp_sd)
    safe_keys = set(safe_sd)

    # Keys of the plain checkpoint that have no "pipe."-prefixed counterpart
    # in the DeepSpeed state dict.
    for key in missing_prefixed_keys(state, mp_sd):
        print(key)
    breakpoint()  # inspect the loaded dicts; disable with PYTHONBREAKPOINT=0

    # Sorted so that the "first N" samples below are the same on every run
    # (plain set iteration order changes between interpreter invocations).
    common_keys = sorted(mp_keys & safe_keys)
    only_mp = sorted(mp_keys - safe_keys)
    only_safe = sorted(safe_keys - mp_keys)

    print("\n===== KEY COMPARISON =====")
    print(f"Common keys: {len(common_keys)}")
    print(f"Only in mp_rank_00: {len(only_mp)}")
    print(f"Only in weights.safetensors: {len(only_safe)}")

    if only_mp:
        print("\n-- Keys ONLY in mp_rank_00_model_states.pt (show first 10):")
        for k in only_mp[:10]:
            print(" ", k)

    if only_safe:
        print("\n-- Keys ONLY in weights.safetensors (show first 10):")
        for k in only_safe[:10]:
            print(" ", k)

    # ====== 4. Shape differences on the shared keys ======
    # Every shared key is checked (comparing shapes is cheap), so the
    # "All shapes match" message is only printed when it is actually true.
    print("\n===== SHAPE CHECK (common keys) =====")
    shape_mismatch = find_shape_mismatches(mp_sd, safe_sd, common_keys)

    if shape_mismatch:
        print("Shape mismatch found:")
        for k, s1, s2 in shape_mismatch[:20]:
            print(f" {k}: mp={s1}, safe={s2}")
    else:
        print("All shapes match ✓")

    # ====== 5. Optional: value differences (potentially large, so only the
    # first 20 shared keys are checked) ======
    print("\n===== VALUE DIFFERENCE CHECK (first 20 keys) =====")
    for k in common_keys[:20]:
        diff = mean_abs_diff(mp_sd[k], safe_sd[k])
        if diff is None:
            print(f"{k}: skipped (shape mismatch)")
        else:
            print(f"{k}: mean(abs(diff)) = {diff:.6f}")


if __name__ == "__main__":
    main()