# SimCLR Model Checkpoint
This repository contains self-supervised weights for a SimCLR model. The weights have been converted from PyTorch's `.pth` format to `.safetensors` for safer and faster loading.
## Usage
To use this model, you will need the `SimCLR` class definition from the accompanying files.
### 1. Install Dependencies

```bash
pip install torch torchvision safetensors huggingface_hub
```
### 2. Load the Model
```python
import json

import torch
import torchvision.models as models
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

from models.simclr import SimCLR  # ensure your model code is on the path

# 1. Download files
repo_id = "homeboi/luthra_simclr_im1k_r50"
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")

# 2. Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# 3. Initialize the encoder
if config["encoder_type"] == "resnet50":
    encoder = models.resnet50(weights=None)  # random init; SimCLR weights are loaded below
elif config["encoder_type"] == "vit_b":
    encoder = models.VisionTransformer(
        image_size=config["image_size"],
        patch_size=config["patch_size"],
        hidden_dim=config["token_hidden_dim"],
        mlp_dim=config["mlp_dim"],
        num_layers=12,
        num_heads=12,
    )

# 4. Initialize SimCLR
model = SimCLR(
    model=encoder,
    dataset=config["dataset"],
    width_multiplier=config["width_multiplier"],
    hidden_dim=config["hidden_dim"],
    projection_dim=config["projection_dim"],
    image_size=config["image_size"],
    patch_size=config["patch_size"],
    stride=config["stride"],
    token_hidden_dim=config["token_hidden_dim"],
    mlp_dim=config["mlp_dim"],
)

# 5. Load the safetensors weights
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully from safetensors!")
```
## License
This project is licensed under the MIT License.