| |
| """ |
| Convert CLIP and SAM3 checkpoints to safetensors format. |
| Run from repo root: python convert_to_safetensors.py |
| """ |
import argparse
import shutil
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import save_file
|
|
|
|
def convert_clip(source_dir: Path, output_dir: Optional[Path] = None) -> Path:
    """Convert a CLIP checkpoint to safetensors format.

    Loads ``pytorch_model.bin`` from *source_dir* when present; otherwise
    falls back to ``transformers`` (using *source_dir* as a local HF directory
    if it exists, else the ``openai/clip-vit-base-patch16`` hub id).
    Half-precision tensors are upcast to float32 before saving, and the
    source ``.bin`` is removed after a successful conversion.

    Args:
        source_dir: Directory expected to contain ``pytorch_model.bin``.
        output_dir: Where to write ``model.safetensors`` (defaults to
            *source_dir*).

    Returns:
        Path to the resulting ``model.safetensors`` file.
    """
    output_dir = output_dir or source_dir
    bin_path = source_dir / "pytorch_model.bin"
    out_path = output_dir / "model.safetensors"

    # Idempotent: skip (and keep the .bin untouched) when already converted.
    if out_path.exists():
        print(f"CLIP: {out_path} already exists, skip")
        return out_path

    if bin_path.exists():
        print(f"CLIP: Loading from {bin_path}...")
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
    else:
        print("CLIP: Loading from HuggingFace...")
        from transformers import CLIPModel  # lazy: only needed for the fallback path
        model = CLIPModel.from_pretrained(
            str(source_dir) if source_dir.exists() else "openai/clip-vit-base-patch16"
        )
        state_dict = model.state_dict()

    # Upcast half-precision tensors so the saved checkpoint is uniformly fp32.
    state_dict = {
        k: v.float() if v.dtype in (torch.float16, torch.bfloat16) else v
        for k, v in state_dict.items()
    }
    save_file(state_dict, str(out_path))
    print(f"CLIP: Saved to {out_path}")
    if bin_path.exists():
        bin_path.unlink()
        print(f"CLIP: Removed {bin_path}")
    return out_path
|
|
|
|
| def _extract_sam3_state_dict(ckpt: dict) -> dict: |
| """Extract SAM3 image model state dict from checkpoint (same logic as sam3._load_checkpoint).""" |
| if "model" in ckpt and isinstance(ckpt["model"], dict): |
| ckpt = ckpt["model"] |
| sam3_image_ckpt = { |
| k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k |
| } |
| return sam3_image_ckpt |
|
|
|
|
def convert_sam3(pt_path: Path, output_path: Optional[Path] = None) -> Path:
    """Convert a SAM3 ``sam3.pt`` checkpoint to safetensors (image model only).

    Args:
        pt_path: Path to the source ``sam3.pt`` checkpoint.
        output_path: Destination file (defaults to ``model.safetensors``
            next to *pt_path*).

    Returns:
        Path to the written safetensors file.
    """
    output_path = output_path or pt_path.parent / "model.safetensors"

    print(f"SAM3: Loading from {pt_path}...")
    ckpt = torch.load(pt_path, map_location="cpu", weights_only=True)
    state_dict = _extract_sam3_state_dict(ckpt)

    # Upcast half-precision tensors so the saved checkpoint is uniformly fp32.
    state_dict = {
        k: v.float() if v.dtype in (torch.float16, torch.bfloat16) else v
        for k, v in state_dict.items()
    }
    save_file(state_dict, str(output_path))
    print(f"SAM3: Saved to {output_path}")
    return output_path
|
|
|
|
def copy_sam3_safetensors(source: Path, dest_dir: Path) -> Optional[Path]:
    """Copy an HF ``model.safetensors`` (detector_model.* keys) into SegEarth-OV.

    The pipeline remaps the key names at load time, so a plain file copy is
    sufficient here.

    Args:
        source: Existing safetensors file to copy.
        dest_dir: Target directory; created if missing.

    Returns:
        The destination path, or ``None`` when *source* does not exist.
    """
    if not source.exists():
        # Match the script's other skip messages instead of failing silently.
        print(f"SAM3: {source} not found, skip copy")
        return None
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / "model.safetensors"
    shutil.copy2(source, dest)
    print(f"SAM3: Copied {source} -> {dest}")
    return dest
|
|
|
|
def main() -> None:
    """CLI entry point: convert whichever checkpoints are selected and present."""
    parser = argparse.ArgumentParser(description="Convert CLIP and SAM3 checkpoints to safetensors")
    parser.add_argument("--clip", action="store_true", help="Convert CLIP only")
    parser.add_argument("--sam3", action="store_true", help="Convert SAM3 only")
    parser.add_argument("--all", action="store_true", help="Convert all (default when no --clip/--sam3)")
    args = parser.parse_args()

    repo = Path(__file__).parent
    # --all forces both conversions; passing no selector at all does the same.
    do_both = args.all or (not args.clip and not args.sam3)

    if args.clip or do_both:
        clip_dir = repo / "OV" / "weights" / "backbone" / "clip-vit-base-patch16"
        if clip_dir.exists():
            convert_clip(clip_dir)
        else:
            print(f"CLIP: {clip_dir} not found, skip")

    if args.sam3 or do_both:
        sam3_dir = repo / "OV-3" / "weights" / "backbone" / "sam3"
        # NOTE(review): machine-specific absolute path — consider making this a CLI flag.
        hf_safetensors = Path("/data/projects/models/hf_models/facebook/sam3/model.safetensors")
        sam3_pt = sam3_dir / "sam3.pt"
        if hf_safetensors.exists():
            # Prefer the ready-made HF safetensors; the pipeline remaps keys on load.
            copy_sam3_safetensors(hf_safetensors, sam3_dir)
            if sam3_pt.exists():
                sam3_pt.unlink()
                print(f"SAM3: Removed {sam3_pt}")
        elif sam3_pt.exists():
            convert_sam3(sam3_pt)
            sam3_pt.unlink()
            print(f"SAM3: Removed {sam3_pt}")
        else:
            print(f"SAM3: Neither {hf_safetensors} nor {sam3_pt} found, skip")

    print("Done.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|