#!/usr/bin/env python3
"""
Convert CLIP and SAM3 checkpoints to safetensors format.

Paths are resolved relative to this script's directory (the repo root), so
the current working directory does not matter:

    python convert_to_safetensors.py            # CLIP and SAM3 (default)
    python convert_to_safetensors.py --clip     # CLIP only
    python convert_to_safetensors.py --sam3     # SAM3 only

By default the original ``pytorch_model.bin`` / ``sam3.pt`` files are deleted
after a successful conversion; pass ``--keep-source`` to keep them.
"""
import argparse
import shutil
from pathlib import Path
from typing import Optional, Tuple

import torch

# Default location of the Hugging Face SAM3 export (``detector_model.*`` keys);
# overridable on the command line with --hf-sam3.
DEFAULT_HF_SAM3 = Path("/data/projects/models/hf_models/facebook/sam3/model.safetensors")

# Half-precision dtypes that are upcast to float32 before saving.
_HALF_DTYPES = (torch.float16, torch.bfloat16)


def _prepare_for_safetensors(state_dict: dict) -> dict:
    """Return a new dict of tensors that ``safetensors.torch.save_file`` accepts.

    - float16 / bfloat16 tensors are upcast to float32; other dtypes are kept.
    - Every tensor is made contiguous (save_file rejects non-contiguous ones).
    - Non-tensor entries are dropped with a notice, because safetensors can
      only store tensors.
    """
    prepared = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            print(f"  skipping non-tensor entry {key!r} ({type(value).__name__})")
            continue
        if value.dtype in _HALF_DTYPES:
            value = value.float()
        # .contiguous() returns the same tensor when it is already contiguous.
        prepared[key] = value.contiguous()
    return prepared


def _save_safetensors(state_dict: dict, out_path: Path) -> None:
    """Write *state_dict* to *out_path*, creating parent directories as needed."""
    # Imported lazily (like ``transformers`` in convert_clip) so this module
    # can be imported, and ``--help`` shown, without safetensors installed.
    from safetensors.torch import save_file

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # transformers' safetensors loader inspects the "format" metadata key
    # (older releases fail on files saved without it); "pt" = PyTorch archive.
    save_file(_prepare_for_safetensors(state_dict), str(out_path), metadata={"format": "pt"})


def convert_clip(source_dir: Path, output_dir: Optional[Path] = None, *, remove_bin: bool = True) -> Path:
    """Convert CLIP ``pytorch_model.bin`` to ``model.safetensors``.

    Weights are read from ``source_dir/pytorch_model.bin`` when it exists.
    Otherwise they come from ``CLIPModel.from_pretrained`` applied to
    *source_dir*, or to the hub id ``openai/clip-vit-base-patch16`` when that
    directory does not exist.

    Args:
        source_dir: CLIP model directory.
        output_dir: Directory for ``model.safetensors``; defaults to *source_dir*.
        remove_bin: Delete ``pytorch_model.bin`` after a successful save.

    Returns:
        Path of the safetensors file. If that file already exists, nothing is
        converted and nothing is deleted.
    """
    output_dir = output_dir or source_dir
    bin_path = source_dir / "pytorch_model.bin"
    out_path = output_dir / "model.safetensors"

    if out_path.exists():
        print(f"CLIP: {out_path} already exists, skip")
        return out_path

    if bin_path.exists():
        print(f"CLIP: Loading from {bin_path}...")
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
    else:
        print("CLIP: Loading from HuggingFace...")
        from transformers import CLIPModel

        name_or_path = str(source_dir) if source_dir.exists() else "openai/clip-vit-base-patch16"
        state_dict = CLIPModel.from_pretrained(name_or_path).state_dict()

    _save_safetensors(state_dict, out_path)
    print(f"CLIP: Saved to {out_path}")

    # Only reached when the save succeeded (save errors propagate).
    if remove_bin and bin_path.exists():
        bin_path.unlink()
        print(f"CLIP: Removed {bin_path}")
    return out_path


def _extract_sam3_state_dict(ckpt: dict) -> dict:
    """Extract SAM3 image model state dict from checkpoint (same logic as sam3._load_checkpoint).

    Unwraps ``ckpt["model"]`` when it is a dict, keeps every key containing
    "detector" and strips "detector." from it. NOTE(review): this is a
    substring match and ``str.replace`` removes *every* "detector."
    occurrence; it is kept identical to the loader on purpose, so change both
    together or not at all.
    """
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]
    sam3_image_ckpt = {
        k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
    }
    return sam3_image_ckpt


def convert_sam3(pt_path: Path, output_path: Optional[Path] = None) -> Path:
    """Convert SAM3 ``sam3.pt`` to ``model.safetensors`` (image model weights only).

    Args:
        pt_path: The SAM3 ``.pt`` checkpoint.
        output_path: Destination file; defaults to ``model.safetensors`` next
            to *pt_path*.

    Returns:
        The path written.

    Raises:
        ValueError: if the checkpoint contains no "detector" keys. Nothing is
            written in that case, so the caller will not delete the source
            in favour of an empty file.
    """
    output_path = output_path or pt_path.parent / "model.safetensors"
    print(f"SAM3: Loading from {pt_path}...")
    ckpt = torch.load(pt_path, map_location="cpu", weights_only=True)
    state_dict = _extract_sam3_state_dict(ckpt)
    if not state_dict:
        raise ValueError(f"SAM3: no 'detector' keys found in {pt_path}; nothing to convert")
    _save_safetensors(state_dict, output_path)
    print(f"SAM3: Saved to {output_path}")
    return output_path


def copy_sam3_safetensors(source: Path, dest_dir: Path) -> Optional[Path]:
    """Copy HF model.safetensors (detector_model.* keys) to SegEarth-OV. Pipeline maps keys on load.

    *dest_dir* is created if missing. If the destination already is *source*
    (same resolved path), the copy is skipped.

    Returns:
        ``dest_dir / "model.safetensors"``, or None when *source* does not exist.
    """
    if not source.exists():
        return None
    dest = dest_dir / "model.safetensors"
    dest_dir.mkdir(parents=True, exist_ok=True)
    if dest.exists() and dest.resolve() == source.resolve():
        # shutil.copy2 would raise SameFileError here.
        print(f"SAM3: {dest} is already the HF file, skip copy")
        return dest
    shutil.copy2(source, dest)
    print(f"SAM3: Copied {source} -> {dest}")
    return dest


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description="Convert CLIP and SAM3 checkpoints to safetensors")
    parser.add_argument("--clip", action="store_true", help="Convert CLIP only")
    parser.add_argument("--sam3", action="store_true", help="Convert SAM3 only")
    parser.add_argument("--all", action="store_true", help="Convert all (default when no --clip/--sam3)")
    parser.add_argument(
        "--hf-sam3",
        type=Path,
        default=DEFAULT_HF_SAM3,
        help="HF SAM3 model.safetensors to copy; preferred over converting sam3.pt",
    )
    parser.add_argument(
        "--keep-source",
        action="store_true",
        help="Do not delete pytorch_model.bin / sam3.pt after converting",
    )
    return parser.parse_args(argv)


def _selected_targets(args: argparse.Namespace) -> Tuple[bool, bool]:
    """Return ``(do_clip, do_sam3)``; --all, or neither --clip nor --sam3, selects both."""
    do_both = args.all or not (args.clip or args.sam3)
    return args.clip or do_both, args.sam3 or do_both


def main(argv=None) -> None:
    """Run the selected conversions on the repo-relative weight directories."""
    args = _parse_args(argv)
    repo = Path(__file__).parent
    do_clip, do_sam3 = _selected_targets(args)
    remove_source = not args.keep_source

    if do_clip:
        clip_dir = repo / "OV" / "weights" / "backbone" / "clip-vit-base-patch16"
        if clip_dir.exists():
            convert_clip(clip_dir, remove_bin=remove_source)
        else:
            print(f"CLIP: {clip_dir} not found, skip")

    if do_sam3:
        sam3_dir = repo / "OV-3" / "weights" / "backbone" / "sam3"
        hf_safetensors = args.hf_sam3
        sam3_pt = sam3_dir / "sam3.pt"
        # Prefer the ready-made HF safetensors; fall back to converting sam3.pt.
        if hf_safetensors.exists():
            converted = copy_sam3_safetensors(hf_safetensors, sam3_dir) is not None
        elif sam3_pt.exists():
            convert_sam3(sam3_pt)  # raises instead of writing an empty file
            converted = True
        else:
            converted = False
            print(f"SAM3: Neither {hf_safetensors} nor {sam3_pt} found, skip")
        if converted and remove_source and sam3_pt.exists():
            sam3_pt.unlink()
            print(f"SAM3: Removed {sam3_pt}")

    print("Done.")


if __name__ == "__main__":
    main()