# midmid3 / convert_checkpoint.py
# Committed by markury: "Initial commit" (d171350)
"""Convert a midmid PyTorch checkpoint to safetensors + config.json.
Usage:
python convert_checkpoint.py path/to/best.pt --output-dir ./model_upload
This produces:
model_upload/model.safetensors (weights only, no pickle)
model_upload/config.json (model hyperparameters)
Then upload to HF:
huggingface-cli upload markury/midmid3-19m-0326 ./model_upload
"""
import argparse
import json
from pathlib import Path
import torch
from safetensors.torch import save_file
def main(argv=None):
    """Convert a midmid ``.pt`` checkpoint into ``model.safetensors`` + ``config.json``.

    Reads the checkpoint given on the command line, writes its ``"config"``
    entry as JSON and its ``"model_state_dict"`` entry as a safetensors file
    into the output directory, then prints a parameter count and an upload
    hint.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Raises:
        KeyError: If the checkpoint lacks the ``"config"`` or
            ``"model_state_dict"`` entry. Raised before anything is written.
    """
    parser = argparse.ArgumentParser(description="Convert midmid checkpoint to safetensors")
    parser.add_argument("checkpoint", type=Path, help="Path to .pt checkpoint")
    parser.add_argument("--output-dir", type=Path, default=Path("model_upload"),
                        help="Output directory (default: ./model_upload)")
    args = parser.parse_args(argv)

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code embedded in the file. Only convert checkpoints
    # you produced yourself or otherwise trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Validate the checkpoint layout up front so a malformed file fails with
    # a clear message and without leaving a half-populated output directory.
    missing = [key for key in ("config", "model_state_dict") if key not in ckpt]
    if missing:
        raise KeyError(
            f"Checkpoint {args.checkpoint} is missing required key(s): "
            f"{', '.join(missing)}"
        )
    config = ckpt["config"]
    state_dict = ckpt["model_state_dict"]

    # Serialize once, before touching the filesystem: a non-JSON-serializable
    # config value then raises here instead of leaving a truncated config.json.
    config_text = json.dumps(config, indent=2)

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    config_path = args.output_dir / "config.json"
    config_path.write_text(config_text, encoding="utf-8")
    print(f"Config saved: {config_path}")
    print(f" {config_text}")

    # Save weights as safetensors. save_file rejects non-contiguous tensors
    # and tensors that share storage (e.g. tied weights), so give every entry
    # its own contiguous copy first.
    tensors = {
        name: tensor.detach().clone(memory_format=torch.contiguous_format)
        for name, tensor in state_dict.items()
    }
    safetensors_path = args.output_dir / "model.safetensors"
    save_file(tensors, str(safetensors_path))
    print(f"Weights saved: {safetensors_path}")

    # Summary
    n_params = sum(tensor.numel() for tensor in tensors.values())
    print(f" {n_params:,} parameters ({n_params / 1e6:.1f}M)")
    print("\nUpload to HF with:")
    print(f" huggingface-cli upload markury/midmid3-19m-0326 {args.output_dir}")
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()