| import torch |
| import os |
|
|
# Location of the consolidated ``pytorch_model.bin`` checkpoint to inspect.
path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref_bak/checkpoint-step-1050-epoch-13/pytorch_model.bin"

print(f"\n=== Loading {path} ===")

# map_location="cpu" keeps every tensor on the CPU, so no GPU is needed
# just to look at the checkpoint.
# NOTE(review): torch.load unpickles the file; only use it on trusted files.
state = torch.load(path, map_location="cpu")

# Interactive stop: drops into pdb so the loaded object can be inspected.
import pdb
pdb.set_trace()

# len() counts the top-level entries of the loaded object
# (presumably one per tensor -- confirm against the checkpoint format).
print(f"\nTotal parameters: {len(state)}")
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
| |
# Rank-0 model-states file (reported below as the "DeepSpeed full model").
mp_path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/pytorch_model/mp_rank_00_model_states.pt"
mp_sd = torch.load(mp_path, map_location="cpu")

# When the loaded object has a "module" entry, the weights live under it;
# otherwise the loaded object is used as the state dict directly.
mp_sd = mp_sd["module"] if "module" in mp_sd else mp_sd

print(f"[INFO] Loaded DeepSpeed full model: {len(mp_sd)} params")

# Safetensors file saved next to the model-states file in the same
# checkpoint directory (reported below as the trainable parameters).
safe_path = "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/ckpts/Wan2.2_5B-Multi_view-normal_rope_384_640-3ref/checkpoint-step-2-epoch-1/weights.safetensors"
safe_sd = load_file(safe_path)
print(f"[INFO] Loaded safetensors trainable params: {len(safe_sd)} params")
|
|
| |
# Key sets of both state dicts, used for the set algebra further down.
mp_keys = set(mp_sd.keys())
safe_keys = set(safe_sd.keys())

# Report every key of `state` that has no "pipe."-prefixed counterpart
# in the model-states dict, in `state`'s own iteration order.
keys_missing_in_mp = [key for key in state if "pipe." + key not in mp_keys]
for key in keys_missing_in_mp:
    print(key)

# Interactive stop: drops into pdb before the key/shape/value comparison.
import pdb
pdb.set_trace()
|
|
# Partition the key space: present in both dicts, or in exactly one of them.
common_keys = mp_keys & safe_keys
only_mp = mp_keys - safe_keys
only_safe = safe_keys - mp_keys

print("\n===== KEY COMPARISON =====")
print(f"Common keys: {len(common_keys)}")
print(f"Only in mp_rank_00: {len(only_mp)}")
print(f"Only in weights.safetensors: {len(only_safe)}")

# A set of strings iterates in a hash-dependent order that can change from
# run to run, so slicing list(set) shows an arbitrary sample. Sorting first
# makes the 10-key preview reproducible and comparable between runs.
if only_mp:
    print("\n-- Keys ONLY in mp_rank_00_model_states.pt (show first 10):")
    for k in sorted(only_mp)[:10]:
        print(" ", k)

if only_safe:
    print("\n-- Keys ONLY in weights.safetensors (show first 10):")
    for k in sorted(only_safe)[:10]:
        print(" ", k)
|
|
| |
print("\n===== SHAPE CHECK (common keys) =====")

# Compare the shape of EVERY common key. Reading `.shape` is cheap, and
# checking only a sample would let the "All shapes match" message below be
# printed while unchecked keys still disagree. Keys are sorted so the report
# order is stable between runs.
shape_mismatch = [
    (k, mp_sd[k].shape, safe_sd[k].shape)
    for k in sorted(common_keys)
    if mp_sd[k].shape != safe_sd[k].shape
]

if shape_mismatch:
    print("Shape mismatch found:")
    # Cap the listing at 20 entries to keep the output readable.
    for k, s1, s2 in shape_mismatch[:20]:
        print(f" {k}: mp={s1}, safe={s2}")
    if len(shape_mismatch) > 20:
        print(f" ... and {len(shape_mismatch) - 20} more")
elif common_keys:
    print("All shapes match ✓")
else:
    # Nothing was compared, so do not claim that everything matches.
    print("No common keys to compare")
|
|
| |
print("\n===== VALUE DIFFERENCE CHECK (first 20 keys) =====")

# Sorted so that "first 20" is the same 20 keys on every run
# (plain set iteration order is hash-dependent).
for k in sorted(common_keys)[:20]:
    mp_tensor = mp_sd[k]
    safe_tensor = safe_sd[k]

    # Subtracting tensors of different shapes either raises or silently
    # broadcasts; neither gives a meaningful per-element difference.
    if mp_tensor.shape != safe_tensor.shape:
        print(
            f"{k}: skipped (shape mismatch: "
            f"mp={tuple(mp_tensor.shape)}, safe={tuple(safe_tensor.shape)})"
        )
        continue

    # Compute in at least float32: mean() rejects integer/bool tensors, and
    # averaging in half/bfloat16 loses precision. float64 inputs stay float64.
    compute_dtype = torch.promote_types(
        torch.promote_types(mp_tensor.dtype, safe_tensor.dtype),
        torch.float32,
    )
    diff = torch.abs(
        mp_tensor.to(compute_dtype) - safe_tensor.to(compute_dtype)
    ).mean().item()
    print(f"{k}: mean(abs(diff)) = {diff:.6f}")