SegEarth-OV / convert_to_safetensors.py
Dingyi111's picture
Duplicate from BiliSakura/SegEarth-OV
fabc606
#!/usr/bin/env python3
"""
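Convert CLIP and SAM3 checkpoints to safetensors format.
Run from repo root: python convert_to_safetensors.py [--clip] [--sam3] [--all]
Note: after a successful conversion the original pytorch_model.bin / sam3.pt
files are deleted.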
"""
import argparse
from pathlib import Path
import torch
from safetensors.torch import save_file
def convert_clip(source_dir: Path, output_dir: Path = None) -> Path:
    """Convert a CLIP checkpoint to ``model.safetensors``.

    Weights come from ``source_dir/pytorch_model.bin`` when that file exists.
    Otherwise they are loaded with ``transformers.CLIPModel.from_pretrained``,
    from ``source_dir`` if the directory exists, else from the
    ``openai/clip-vit-base-patch16`` Hub id. float16/bfloat16 tensors are
    upcast to float32 before saving.

    Args:
        source_dir: Directory holding the original CLIP checkpoint.
        output_dir: Directory for ``model.safetensors``; defaults to
            ``source_dir``. Created if it does not exist.

    Returns:
        Path of the safetensors file.

    Side effects:
        Does nothing if ``model.safetensors`` already exists. After a
        successful save, ``source_dir/pytorch_model.bin`` is deleted.
    """
    output_dir = output_dir or source_dir
    bin_path = source_dir / "pytorch_model.bin"
    out_path = output_dir / "model.safetensors"
    if out_path.exists():
        print(f"CLIP: {out_path} already exists, skip")
        return out_path

    if bin_path.exists():
        print(f"CLIP: Loading from {bin_path}...")
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
    else:
        print("CLIP: Loading from HuggingFace...")
        from transformers import CLIPModel

        model = CLIPModel.from_pretrained(
            str(source_dir) if source_dir.exists() else "openai/clip-vit-base-patch16"
        )
        state_dict = model.state_dict()

    # Upcast half precision to fp32; safetensors refuses non-contiguous tensors.
    state_dict = {
        k: (v.float() if v.dtype in (torch.float16, torch.bfloat16) else v).contiguous()
        for k, v in state_dict.items()
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and rename, so an interrupted save never leaves a
    # truncated model.safetensors that the "already exists" check would skip.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        save_file(state_dict, str(tmp_path))
        tmp_path.replace(out_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    print(f"CLIP: Saved to {out_path}")

    if bin_path.exists():
        bin_path.unlink()
        print(f"CLIP: Removed {bin_path}")
    return out_path
def _extract_sam3_state_dict(ckpt: dict) -> dict:
"""Extract SAM3 image model state dict from checkpoint (same logic as sam3._load_checkpoint)."""
if "model" in ckpt and isinstance(ckpt["model"], dict):
ckpt = ckpt["model"]
sam3_image_ckpt = {
k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
}
return sam3_image_ckpt
def convert_sam3(pt_path: Path, output_path: Path = None) -> Path:
    """Convert a SAM3 ``sam3.pt`` checkpoint to safetensors (image model only).

    Only the detector (image model) weights are kept, with their
    ``"detector."`` prefix stripped. float16/bfloat16 tensors are upcast to
    float32.

    Args:
        pt_path: Path to the ``sam3.pt`` checkpoint.
        output_path: Destination file; defaults to
            ``pt_path.parent / "model.safetensors"``. An existing file is
            overwritten.

    Returns:
        Path of the written safetensors file.

    Raises:
        ValueError: If the checkpoint contains no detector weights. Nothing is
            written in that case, so callers that delete ``sam3.pt`` afterwards
            do not lose data.
    """
    output_path = output_path or pt_path.parent / "model.safetensors"
    print(f"SAM3: Loading from {pt_path}...")
    ckpt = torch.load(pt_path, map_location="cpu", weights_only=True)
    state_dict = _extract_sam3_state_dict(ckpt)
    # Drop references to the non-image weights before allocating upcast copies.
    del ckpt
    if not state_dict:
        raise ValueError(f"SAM3: no 'detector' weights found in {pt_path}")

    # Upcast half precision to fp32; safetensors refuses non-contiguous tensors.
    state_dict = {
        k: (v.float() if v.dtype in (torch.float16, torch.bfloat16) else v).contiguous()
        for k, v in state_dict.items()
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Atomic write: an interrupted save must not leave a truncated output file.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        save_file(state_dict, str(tmp_path))
        tmp_path.replace(output_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    print(f"SAM3: Saved to {output_path}")
    return output_path
def copy_sam3_safetensors(source: Path, dest_dir: Path) -> "Path | None":
    """Copy a Hugging Face SAM3 ``model.safetensors`` into ``dest_dir``.

    The file is copied unchanged. Per the original note, the HF file uses
    ``detector_model.*`` keys, and the pipeline remaps them on load.

    Args:
        source: Path to the HF ``model.safetensors``.
        dest_dir: Target directory; created if it does not exist.

    Returns:
        ``dest_dir / "model.safetensors"``, or ``None`` if ``source`` does not
        exist.
    """
    if not source.exists():
        return None
    import shutil

    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / "model.safetensors"
    shutil.copy2(source, dest)
    print(f"SAM3: Copied {source} -> {dest}")
    return dest
def main():
    """Command-line entry point: convert CLIP and/or SAM3 weights under the repo.

    With no selection flags, or with ``--all``, both models are processed.
    ``--clip`` / ``--sam3`` restrict the run to the selected model(s). For
    SAM3, the Hugging Face ``model.safetensors`` given by ``--hf-sam3`` is
    copied if it exists; otherwise ``sam3.pt`` is converted. Either way,
    ``sam3.pt`` is deleted afterwards.
    """
    parser = argparse.ArgumentParser(description="Convert CLIP and SAM3 checkpoints to safetensors")
    parser.add_argument("--clip", action="store_true", help="Convert CLIP only")
    parser.add_argument("--sam3", action="store_true", help="Convert SAM3 only")
    parser.add_argument("--all", action="store_true", help="Convert all (default when no --clip/--sam3)")
    parser.add_argument(
        "--hf-sam3",
        type=Path,
        default=Path("/data/projects/models/hf_models/facebook/sam3/model.safetensors"),
        help="Hugging Face SAM3 model.safetensors to copy instead of converting sam3.pt",
    )
    args = parser.parse_args()

    repo = Path(__file__).resolve().parent
    # --all wins over --clip/--sam3; selecting nothing also means "everything".
    do_both = args.all or not (args.clip or args.sam3)

    if args.clip or do_both:
        clip_dir = repo / "OV" / "weights" / "backbone" / "clip-vit-base-patch16"
        if clip_dir.exists():
            convert_clip(clip_dir)
        else:
            print(f"CLIP: {clip_dir} not found, skip")

    if args.sam3 or do_both:
        sam3_dir = repo / "OV-3" / "weights" / "backbone" / "sam3"
        hf_safetensors = args.hf_sam3
        sam3_pt = sam3_dir / "sam3.pt"
        if hf_safetensors.exists():
            copy_sam3_safetensors(hf_safetensors, sam3_dir)
            if sam3_pt.exists():
                sam3_pt.unlink()
                print(f"SAM3: Removed {sam3_pt}")
        elif sam3_pt.exists():
            # convert_sam3 raises before returning if nothing could be written,
            # so the checkpoint is only removed after a successful conversion.
            convert_sam3(sam3_pt)
            sam3_pt.unlink()
            print(f"SAM3: Removed {sam3_pt}")
        else:
            print(f"SAM3: Neither {hf_safetensors} nor {sam3_pt} found, skip")

    print("Done.")


if __name__ == "__main__":
    main()