# multishot / multi_view / mytest.py
# Uploaded by PencilHu via huggingface_hub ("Upload folder using huggingface_hub"), commit 85752bc (verified).
import torch
import os
# Checkpoint to inspect -- change to your path.
# NOTE(review): this points at a different run/step (..._3ref_bak,
# step 1050) than the DeepSpeed/safetensors files compared below
# (..._3ref, step 2); confirm that is intended.
path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref_bak/checkpoint-step-1050-epoch-13/pytorch_model.bin"
print(f"\n=== Loading {path} ===")
# Load onto CPU so no GPU is needed; presumably a flat name -> tensor
# state dict (TODO confirm).
state = torch.load(path, map_location="cpu")
# Pause for interactive inspection of `state`. The builtin breakpoint()
# (PEP 553) drops into pdb exactly like pdb.set_trace() by default, but can
# be disabled for non-interactive runs with PYTHONBREAKPOINT=0 (a hard-coded
# pdb prompt would otherwise abort the script with BdbQuit on EOF).
breakpoint()
# Number of top-level entries (tensors) in the loaded object, not the
# number of scalar weights.
print(f"\nTotal parameters: {len(state)}")
import torch
from safetensors.torch import load_file
# ====== 1. Load the full model parameters saved by DeepSpeed ======
mp_path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/pytorch_model/mp_rank_00_model_states.pt"
# NOTE(review): on torch>=2.6, torch.load defaults to weights_only=True; if
# this DeepSpeed file contains non-tensor pickled objects and fails to load,
# pass weights_only=False (only for trusted files) -- confirm on your version.
mp_sd = torch.load(mp_path, map_location="cpu")
# The DeepSpeed state dict lives either under mp_sd["module"] or is mp_sd
# itself; in this project it is usually under "module".
mp_sd = mp_sd["module"] if "module" in mp_sd else mp_sd
print(f"[INFO] Loaded DeepSpeed full model: {len(mp_sd)} params")
# ====== 2. Load the trainable weights that were saved separately ======
safe_path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/weights.safetensors"
safe_sd = load_file(safe_path)
n_safe_entries = len(safe_sd)
print(f"[INFO] Loaded safetensors trainable params: {n_safe_entries} params")
# ====== 3. Compare key sets ======
mp_keys = set(mp_sd.keys())
safe_keys = set(safe_sd.keys())
# Print every key of `state` that has no "pipe."-prefixed counterpart in the
# DeepSpeed state dict. Presumably the DeepSpeed-wrapped model stores the
# network under a `pipe` attribute, hence the prefix -- TODO confirm.
for key in state:
    if f"pipe.{key}" not in mp_keys:
        print(key)
# Pause for interactive inspection; disable with PYTHONBREAKPOINT=0 for
# non-interactive runs (PEP 553).
breakpoint()
# Partition the two key sets: shared keys, and keys present in only one file.
common_keys = mp_keys & safe_keys
only_mp = mp_keys - safe_keys
only_safe = safe_keys - mp_keys
print("\n===== KEY COMPARISON =====")
print(f"Common keys: {len(common_keys)}")
print(f"Only in mp_rank_00: {len(only_mp)}")
print(f"Only in weights.safetensors: {len(only_safe)}")
# Iteration order of a set of str depends on the per-process hash seed, so
# `list(s)[:10]` would show a different sample on every run. Sort first so
# the "first 10" are reproducible.
if only_mp:
    print("\n-- Keys ONLY in mp_rank_00_model_states.pt (show first 10):")
    for k in sorted(only_mp)[:10]:
        print(" ", k)
if only_safe:
    print("\n-- Keys ONLY in weights.safetensors (show first 10):")
    for k in sorted(only_safe)[:10]:
        print(" ", k)
# ====== 4. Check shape differences on the common keys ======
print("\n===== SHAPE CHECK (common keys) =====")
# Comparing shapes touches no tensor data, so check every common key.
# Checking only an arbitrary subset (set order varies per run) would let the
# "All shapes match" message vouch for keys that were never examined.
shape_mismatch = [
    (k, mp_sd[k].shape, safe_sd[k].shape)
    for k in sorted(common_keys)
    if mp_sd[k].shape != safe_sd[k].shape
]
if shape_mismatch:
    print("Shape mismatch found:")
    # Cap the listing to keep output readable.
    for k, s1, s2 in shape_mismatch[:20]:
        print(f" {k}: mp={s1}, safe={s2}")
else:
    print("All shapes match ✓")
# ====== 5. Optional: compare parameter values (can be expensive, so only
# the first 20 common keys in sorted order are checked by default) ======
print("\n===== VALUE DIFFERENCE CHECK (first 20 keys) =====")
for k in sorted(common_keys)[:20]:
    a, b = mp_sd[k], safe_sd[k]
    if a.shape != b.shape:
        # `a - b` would either raise or silently broadcast into a
        # meaningless number; report the mismatch instead.
        print(f"{k}: shape mismatch mp={a.shape}, safe={b.shape}, skipped")
        continue
    # Upcast before diffing: bool tensors reject `-`, integer tensors reject
    # mean(), and low-precision (bf16/fp16) math may blur small differences.
    diff = (a.double() - b.double()).abs().mean().item()
    print(f"{k}: mean(abs(diff)) = {diff:.6f}")