haofuly's picture
Add files using upload-large-folder tool
45ac12e verified
raw
history blame
4.55 kB
import dataclasses
import logging
from pathlib import Path
from typing import Any
import torch
import tyro
from openpi.training import config as _config
@dataclasses.dataclass
class CkptSpec:
    """Location of one checkpoint on disk (populated from the CLI via tyro)."""

    # Checkpoint directory; expected to contain a `model.safetensors` file
    # (see load_model below).
    dir: str
@dataclasses.dataclass
class Args:
    """CLI arguments for computing a parameter-wise checkpoint diff (A minus B)."""

    # Name of the training config used to build both models
    # (passed to _config.get_config).
    config: str
    # Checkpoint A (the minuend).
    a: CkptSpec
    # Checkpoint B (the subtrahend).
    b: CkptSpec
    # Output path for the saved diff state dict.
    out: str = "checkpoints/diff/a_minus_b.pth"
    # If True, zero out action-expert parameter prefixes in the diff so only
    # the VLM portion remains (see ZERO_PREFIXES in main).
    only_vlm: bool = False
    # If True, require identical key sets in both state dicts; otherwise
    # subtract only the intersection of keys.
    strict_keys: bool = False
    # Precision used for the floating-point subtraction: "fp32" | "fp16" | "bf16".
    dtype: str = "fp32"
    # Device tensors are moved to while subtracting.
    device: str = "cpu"
def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]:
"""
Try best to get a torch state_dict from a Policy or Module-like object.
"""
# Case 1: policy itself has state_dict()
if hasattr(obj, "state_dict") and callable(obj.state_dict):
sd = obj.state_dict()
if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
return sd
# Case 2: common attributes that hold torch.nn.Module
for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]:
if hasattr(obj, attr):
m = getattr(obj, attr)
if hasattr(m, "state_dict") and callable(m.state_dict):
sd = m.state_dict()
if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
return sd
raise RuntimeError(
"Cannot extract state_dict. "
"Please inspect Policy object and update attribute list in _extract_state_dict()."
)
def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor:
if dtype == "fp32":
return t.float()
if dtype == "fp16":
return t.half()
if dtype == "bf16":
return t.bfloat16()
raise ValueError(f"Unknown dtype: {dtype}")
def load_model(config_name: str, spec: CkptSpec):
    """Build the model for `config_name` and load weights from `spec.dir`.

    Raises:
        FileNotFoundError: if `model.safetensors` is absent from the directory.
    """
    cfg = _config.get_config(config_name)
    ckpt_file = Path(spec.dir) / "model.safetensors"
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}")
    return cfg.model.load_pytorch(cfg, str(ckpt_file))
def main(args: Args) -> None:
    """Compute a parameter-wise diff (A - B) between two checkpoints and save it.

    Both models are loaded with the same config. Floating-point parameters
    are subtracted (B from A) in the precision requested by --dtype;
    non-floating-point tensors are copied from A unchanged. With --only-vlm,
    action-expert parameter prefixes are zeroed so the saved diff contains
    only the VLM portion. The result is written to `args.out` via torch.save.

    Raises:
        RuntimeError: on key-set mismatch (strict mode) or per-key shape mismatch.
    """
    logging.info("Loading A model from %s with config %s", args.a.dir, args.config)
    model_a = load_model(args.config, args.a)
    logging.info("Loading B model from %s with config %s", args.b.dir, args.config)
    model_b = load_model(args.config, args.b)

    sd_a = _extract_state_dict(model_a)
    sd_b = _extract_state_dict(model_b)
    keys_a = set(sd_a.keys())
    keys_b = set(sd_b.keys())

    if args.strict_keys:
        if keys_a != keys_b:
            # Show at most 20 keys from each side to keep the error readable.
            only_a = sorted(keys_a - keys_b)[:20]
            only_b = sorted(keys_b - keys_a)[:20]
            raise RuntimeError(
                f"State dict keys mismatch.\n"
                f"Only in A (show up to 20): {only_a}\n"
                f"Only in B (show up to 20): {only_b}\n"
                f"Set --strict-keys False to subtract intersection only."
            )
        keys = sorted(keys_a)
    else:
        keys = sorted(keys_a & keys_b)
        logging.warning("Non-strict mode: subtracting only intersection keys: %d", len(keys))

    device = torch.device(args.device)
    diff: dict[str, torch.Tensor] = {}
    if args.only_vlm:
        # Prefixes of the action-expert / action-head parameters; matching
        # entries are zeroed so the diff keeps only the VLM (paligemma) part.
        ZERO_PREFIXES = [
            "paligemma_with_expert.gemma_expert.",
            "action_in_proj.",
            "action_out_proj.",
            "action_time_mlp_in",
            # BUG FIX: was "action_time_mlp_oout" (typo). The misspelled
            # prefix never matched, so action_time_mlp_out weights leaked
            # into the supposedly VLM-only diff.
            "action_time_mlp_out",
        ]
    else:
        ZERO_PREFIXES = []

    for k in keys:
        ta = sd_a[k].to(device)
        tb = sd_b[k].to(device)
        if ta.shape != tb.shape:
            raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")
        if any(k.startswith(p) for p in ZERO_PREFIXES):
            out = torch.zeros_like(ta)
        elif ta.is_floating_point():
            # Subtract in the requested precision so rounding behavior is
            # controlled by --dtype rather than the stored parameter dtype.
            out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
        else:
            # Integer/bool tensors (e.g. buffers) cannot be meaningfully
            # subtracted; take A's value as-is.
            out = ta
        diff[k] = out.detach().cpu()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
    logging.info("Saved diff checkpoint to: %s", str(out_path))
if __name__ == "__main__":
    # force=True replaces any logging handlers already installed by imports.
    logging.basicConfig(level=logging.INFO, force=True)
    # tyro builds an Args instance from the command line (dataclass-driven CLI).
    main(tyro.cli(Args))