| import dataclasses
|
| import logging
|
| from pathlib import Path
|
| from typing import Any
|
|
|
| import torch
|
| import tyro
|
|
|
| from openpi.training import config as _config
|
|
|
|
|
@dataclasses.dataclass
class CkptSpec:
    """Location of a single checkpoint taking part in the A - B comparison."""

    dir: str  # checkpoint directory; load_model() expects a model.safetensors file inside it
|
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line arguments for computing a weight-difference checkpoint (A - B).

    Parsed from the command line with ``tyro.cli(Args)``.
    """

    # Name of the training config; the same config is used to load both checkpoints.
    config: str
    # Checkpoint whose weights are the minuend (A in A - B).
    a: CkptSpec
    # Checkpoint whose weights are the subtrahend (B in A - B).
    b: CkptSpec
    # Destination file for the saved diff; parent directories are created if missing.
    out: str = "checkpoints/diff/a_minus_b.pth"
    # If True, entries whose keys match main()'s non-VLM prefixes are written as zeros.
    only_vlm: bool = False
    # If True, A and B must have identical state-dict keys; otherwise only the intersection is diffed.
    strict_keys: bool = False
    # Floating-point precision of the diff: one of "fp32", "fp16", "bf16".
    dtype: str = "fp32"
    # Torch device string on which the subtraction is performed (e.g. "cpu", "cuda:0").
    device: str = "cpu"

    def __post_init__(self) -> None:
        """Validate cheap-to-check options before any (expensive) model loading happens.

        Raises:
            ValueError: If ``dtype`` is not one of the supported precision names.
            RuntimeError: If ``device`` is not a valid torch device string.
        """
        allowed_dtypes = ("fp32", "fp16", "bf16")
        if self.dtype not in allowed_dtypes:
            raise ValueError(f"Unknown dtype: {self.dtype!r}; expected one of {allowed_dtypes}")
        # torch.device() only parses the string (it does not check availability), so this is cheap.
        torch.device(self.device)
|
|
|
|
|
| def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]:
|
| """
|
| Try best to get a torch state_dict from a Policy or Module-like object.
|
| """
|
|
|
| if hasattr(obj, "state_dict") and callable(obj.state_dict):
|
| sd = obj.state_dict()
|
| if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
|
| return sd
|
|
|
|
|
| for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]:
|
| if hasattr(obj, attr):
|
| m = getattr(obj, attr)
|
| if hasattr(m, "state_dict") and callable(m.state_dict):
|
| sd = m.state_dict()
|
| if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
|
| return sd
|
|
|
| raise RuntimeError(
|
| "Cannot extract state_dict. "
|
| "Please inspect Policy object and update attribute list in _extract_state_dict()."
|
| )
|
|
|
|
|
| def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor:
|
| if dtype == "fp32":
|
| return t.float()
|
| if dtype == "fp16":
|
| return t.half()
|
| if dtype == "bf16":
|
| return t.bfloat16()
|
| raise ValueError(f"Unknown dtype: {dtype}")
|
|
|
|
|
def load_model(config_name: str, spec: CkptSpec):
    """Load a PyTorch model for ``spec`` using the named training config.

    The training config is resolved first; the weights are then read from
    ``<spec.dir>/model.safetensors`` via ``cfg.model.load_pytorch``.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is missing from ``spec.dir``.
    """
    train_cfg = _config.get_config(config_name)
    weights_file = Path(spec.dir).joinpath("model.safetensors")
    if weights_file.exists():
        return train_cfg.model.load_pytorch(train_cfg, str(weights_file))
    raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}")
|
|
|
|
|
# State-dict key prefixes whose diff entries are replaced by zeros when --only-vlm is set
# (the expert / action-head parameters, i.e. everything this script treats as non-VLM).
# NOTE(review): some model variants may have further non-VLM modules (e.g. `state_proj.`,
# `time_mlp_in.` / `time_mlp_out.`) — TODO confirm against the model definition and extend.
_NON_VLM_PREFIXES: tuple[str, ...] = (
    "paligemma_with_expert.gemma_expert.",
    "action_in_proj.",
    "action_out_proj.",
    "action_time_mlp_in.",
    "action_time_mlp_out.",
)


def _select_keys(keys_a: set[str], keys_b: set[str], strict: bool) -> list[str]:
    """Return the sorted state-dict keys to diff.

    With identical key sets all keys are used. Otherwise strict mode raises, and
    non-strict mode uses the intersection and logs how many keys were skipped.

    Raises:
        RuntimeError: If ``strict`` is set and the key sets differ.
    """
    if keys_a == keys_b:
        return sorted(keys_a)

    only_a = sorted(keys_a - keys_b)
    only_b = sorted(keys_b - keys_a)
    if strict:
        raise RuntimeError(
            "State dict keys mismatch.\n"
            f"Only in A (show up to 20): {only_a[:20]}\n"
            f"Only in B (show up to 20): {only_b[:20]}\n"
            "Re-run without --strict-keys to subtract intersection only."
        )

    keys = sorted(keys_a & keys_b)
    logging.warning(
        "Non-strict mode: subtracting only intersection keys: %d (skipped %d only in A, %d only in B)",
        len(keys),
        len(only_a),
        len(only_b),
    )
    return keys


def main(args: Args) -> None:
    """Compute ``A - B`` over the state dicts of two checkpoints and save it.

    Both checkpoints are loaded with the same training config (``args.config``).
    For every selected key:

    * if ``args.only_vlm`` is set and the key matches ``_NON_VLM_PREFIXES``, the
      entry is a zero tensor (cast to ``args.dtype`` when floating point);
    * otherwise floating-point entries become ``cast(A) - cast(B)`` in ``args.dtype``;
    * non-floating entries are copied from A unchanged.

    The result is written with ``torch.save`` as
    ``{"state_dict": diff, "a": asdict(args.a), "b": asdict(args.b)}``.

    Raises:
        RuntimeError: On a key mismatch in strict mode, or on a shape mismatch.
    """
    logging.info("Loading A model from %s with config %s", args.a.dir, args.config)
    model_a = load_model(args.config, args.a)
    logging.info("Loading B model from %s with config %s", args.b.dir, args.config)
    model_b = load_model(args.config, args.b)

    sd_a = _extract_state_dict(model_a)
    sd_b = _extract_state_dict(model_b)
    keys = _select_keys(set(sd_a), set(sd_b), strict=args.strict_keys)

    # str.startswith(()) is always False, so an empty tuple disables zeroing.
    zero_prefixes = _NON_VLM_PREFIXES if args.only_vlm else ()
    device = torch.device(args.device)
    diff: dict[str, torch.Tensor] = {}

    for k in keys:
        ta = sd_a[k].to(device)
        tb = sd_b[k].to(device)
        if ta.shape != tb.shape:
            raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")

        if k.startswith(zero_prefixes):
            out = torch.zeros_like(ta)
            # Use the same precision as the subtracted entries so the diff is uniformly typed.
            if out.is_floating_point():
                out = _cast_tensor(out, args.dtype)
        elif ta.is_floating_point():
            out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
        else:
            # Non-floating entries (e.g. integer buffers) are not subtracted; A's value is kept.
            out = ta

        diff[k] = out.detach().cpu()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
    logging.info("Saved diff checkpoint to: %s", str(out_path))
|
|
|
|
|
if __name__ == "__main__":
    # force=True replaces any handlers already installed on the root logger, so the
    # INFO-level progress messages from main() are always emitted.
    logging.basicConfig(force=True, level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
|
|
|