SimCLR Model Checkpoint

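This repository contains the weights of a SimCLR model trained with self-supervised contrastive learning. The weights have been converted from PyTorch .pth format to .safetensors. That format loads without unpickling arbitrary Python objects, which makes it safer, and it is typically faster to load.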

Usage

The checkpoint stores weights only. To use it, you need the SimCLR class definition from the accompanying source files. The example below imports that class from models.simclr.
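The accompanying SimCLR class is not reproduced here, and this card does not describe its exact layers. In the SimCLR method generally, a backbone encoder is followed by a small MLP projection head, and the contrastive loss is applied to the projected vectors. The hypothetical MiniSimCLR below illustrates only that general layout. Its parameter names will not match this checkpoint, so it cannot load model.safetensors.

```python
import torch
from torch import nn


class MiniSimCLR(nn.Module):
    """Hypothetical illustration of the SimCLR layout: encoder followed by an MLP projection head."""

    def __init__(self, encoder: nn.Module, feature_dim: int, hidden_dim: int, projection_dim: int):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x: torch.Tensor):
        h = self.encoder(x)    # representation typically reused for downstream tasks
        z = self.projector(h)  # projection fed to the contrastive loss during training
        return h, z
```

With a torchvision ResNet-50 backbone, one would usually replace encoder.fc with nn.Identity(), so the encoder emits its 2048-dimensional pooled features and feature_dim is set to 2048.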

1. Install Dependencies

pip install torch torchvision safetensors huggingface_hub

2. Load the Model

import json

import torch
from torchvision import models as tv_models  # aliased so it is not confused with the local models package
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from models.simclr import SimCLR  # Ensure the accompanying model code is on your Python path

# 1. Download files
repo_id = "homeboi/luthra_simclr_im1k_r50"
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")

# 2. Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# 3. Initialize encoder (random init; the trained weights are loaded in step 5)
encoder_type = config["encoder_type"]
if encoder_type == "resnet50":
    # weights=None requires torchvision >= 0.13 (older versions use pretrained=False)
    encoder = tv_models.resnet50(weights=None)
elif encoder_type == "vit_b":
    # ViT-B: 12 transformer layers with 12 attention heads each
    encoder = tv_models.VisionTransformer(
        image_size=config["image_size"],
        patch_size=config["patch_size"],
        num_layers=12,
        num_heads=12,
        hidden_dim=config["token_hidden_dim"],
        mlp_dim=config["mlp_dim"],
    )
else:
    raise ValueError(f"Unsupported encoder_type: {encoder_type!r}")

# 4. Initialize SimCLR
model = SimCLR(
    model=encoder,
    dataset=config["dataset"],
    width_multiplier=config["width_multiplier"],
    hidden_dim=config["hidden_dim"],
    projection_dim=config["projection_dim"],
    image_size=config["image_size"],
    patch_size=config["patch_size"],
    stride=config["stride"],
    token_hidden_dim=config["token_hidden_dim"],
    mlp_dim=config["mlp_dim"]
)

# 5. Load Safetensors
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()

print("Model loaded successfully from Safetensors!")

License

This project is licensed under the MIT License.
