"""Compute a parameter-wise diff (A - B) between two checkpoints of the same config.

Loads both checkpoints, subtracts floating-point tensors key-by-key, and saves
the result as a torch checkpoint. Optionally zeroes out non-VLM (action-expert)
parameters so the diff contains only the VLM portion.
"""

import dataclasses
import logging
from pathlib import Path
from typing import Any

import torch
import tyro

from openpi.training import config as _config


@dataclasses.dataclass
class CkptSpec:
    """Location of a single checkpoint directory (must contain model.safetensors)."""

    dir: str


@dataclasses.dataclass
class Args:
    """CLI arguments (parsed by tyro)."""

    config: str                                   # openpi training config name, shared by A and B
    a: CkptSpec                                   # minuend checkpoint
    b: CkptSpec                                   # subtrahend checkpoint
    out: str = "checkpoints/diff/a_minus_b.pth"   # output path for the diff checkpoint
    only_vlm: bool = False                        # zero out action-expert params, keep only VLM diff
    strict_keys: bool = False                     # require identical key sets in A and B
    dtype: str = "fp32"                           # dtype for the subtraction: fp32 | fp16 | bf16
    device: str = "cpu"                           # device to perform the subtraction on


def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]:
    """
    Try best to get a torch state_dict from a Policy or Module-like object.

    Checks the object itself first, then a list of common attribute names that
    typically hold an inner ``torch.nn.Module``. A candidate is accepted only if
    it returns a dict whose values are all tensors.

    Raises:
        RuntimeError: if no suitable state_dict can be found.
    """
    # Case 1: policy itself has state_dict()
    if hasattr(obj, "state_dict") and callable(obj.state_dict):
        sd = obj.state_dict()
        if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
            return sd

    # Case 2: common attributes that hold torch.nn.Module
    for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]:
        if hasattr(obj, attr):
            m = getattr(obj, attr)
            if hasattr(m, "state_dict") and callable(m.state_dict):
                sd = m.state_dict()
                if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
                    return sd

    raise RuntimeError(
        "Cannot extract state_dict. "
        "Please inspect Policy object and update attribute list in _extract_state_dict()."
    )


def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor:
    """Cast tensor *t* to the dtype named by *dtype* ("fp32" | "fp16" | "bf16").

    Raises:
        ValueError: if *dtype* is not one of the supported names.
    """
    if dtype == "fp32":
        return t.float()
    if dtype == "fp16":
        return t.half()
    if dtype == "bf16":
        return t.bfloat16()
    raise ValueError(f"Unknown dtype: {dtype}")


def load_model(config_name: str, spec: CkptSpec):
    """Load the pytorch model for *config_name* from the checkpoint in *spec*.

    Raises:
        FileNotFoundError: if the checkpoint directory has no model.safetensors.
    """
    cfg = _config.get_config(config_name)
    weight_path = Path(spec.dir) / "model.safetensors"
    if not weight_path.exists():
        raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}")
    return cfg.model.load_pytorch(cfg, str(weight_path))


def main(args: Args) -> None:
    """Compute diff = A - B over (a subset of) state-dict keys and save it."""
    logging.info("Loading A model from %s with config %s", args.a.dir, args.config)
    model_a = load_model(args.config, args.a)
    logging.info("Loading B model from %s with config %s", args.b.dir, args.config)
    model_b = load_model(args.config, args.b)

    sd_a = _extract_state_dict(model_a)
    sd_b = _extract_state_dict(model_b)

    keys_a = set(sd_a.keys())
    keys_b = set(sd_b.keys())

    if args.strict_keys:
        if keys_a != keys_b:
            only_a = sorted(keys_a - keys_b)[:20]
            only_b = sorted(keys_b - keys_a)[:20]
            raise RuntimeError(
                f"State dict keys mismatch.\n"
                f"Only in A (show up to 20): {only_a}\n"
                f"Only in B (show up to 20): {only_b}\n"
                f"Set --strict-keys False to subtract intersection only."
            )
        keys = sorted(keys_a)
    else:
        # Non-strict: silently drop keys present in only one checkpoint.
        keys = sorted(keys_a & keys_b)
        logging.warning("Non-strict mode: subtracting only intersection keys: %d", len(keys))

    device = torch.device(args.device)
    diff: dict[str, torch.Tensor] = {}

    if args.only_vlm:
        # Prefixes of action-expert parameters to zero out, leaving only the
        # VLM (paligemma backbone) diff.
        # BUGFIX: was "action_time_mlp_oout", which matched nothing, so the
        # output MLP was never zeroed.
        # NOTE(review): "action_time_mlp_in"/"..._out" carry no trailing dot,
        # unlike the other entries — presumably intentional to match both the
        # module and any suffixed variants; confirm against the model's key names.
        ZERO_PREFIXES = [
            "paligemma_with_expert.gemma_expert.",
            "action_in_proj.",
            "action_out_proj.",
            "action_time_mlp_in",
            "action_time_mlp_out",
        ]
    else:
        ZERO_PREFIXES = []

    for k in keys:
        ta = sd_a[k].to(device)
        tb = sd_b[k].to(device)
        if ta.shape != tb.shape:
            raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")

        zero_this = any(k.startswith(p) for p in ZERO_PREFIXES)
        if zero_this:
            out = torch.zeros_like(ta)
        else:
            if ta.is_floating_point():
                out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
            else:
                # Integer/bool buffers (e.g. step counters, masks): subtraction
                # is meaningless, so carry A's value through unchanged.
                out = ta

        diff[k] = out.detach().cpu()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
    logging.info("Saved diff checkpoint to: %s", str(out_path))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))