CLIP Insect Sex Classifier

A custom CLIP-based binary classifier for insect sex detection (female vs. male), trained on an imagefolder-style dataset with augmentation.

Model Summary

  • Base model: openai/clip-vit-base-patch32
  • Head: learnable prompt vectors + cosine-similarity classifier
  • Labels: female, male
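
The head described above can be illustrated with dummy tensors (shapes and scaling only; the real prompt vectors and logit scale are learned and shipped in the checkpoint, and ViT-B/32 projects to 512 dimensions):

```python
import torch

n_classes, proj_dim, batch = 2, 512, 4

# Dummy stand-ins for the checkpoint's learned parameters
prompt_vecs = torch.randn(n_classes, proj_dim)
logit_scale = torch.tensor(100.0).log()

# Dummy image embeddings, as the vision tower would produce
img_embeds = torch.randn(batch, proj_dim)

# Cosine similarity between every image and every class prompt,
# scaled by exp(logit_scale), yields a (batch, n_classes) logit matrix
logits = torch.cosine_similarity(
    img_embeds.unsqueeze(1), prompt_vecs.unsqueeze(0), dim=-1
) * logit_scale.exp()

print(logits.shape)  # torch.Size([4, 2])
```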

Intended Use

  • Research and prototyping for insect sex classification from images.
  • Educational use and pipeline integration.

Out of Scope

  • Safety-critical decisions.
  • Use on species/domains not represented in your data without additional validation.

How To Use

1) Install dependencies

pip install torch torchvision transformers huggingface_hub safetensors pillow numpy

2) Inference example (custom class required)

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch import nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import CLIPVisionModelWithProjection

class CLIPBinaryClassifier(nn.Module):
    def __init__(self, clip_name: str, n_classes: int):
        super().__init__()
        # CLIP vision tower with its projection head
        self.vision = CLIPVisionModelWithProjection.from_pretrained(clip_name)
        # Learnable temperature, initialized to CLIP's default of log(100)
        self.logit_scale = nn.Parameter(torch.tensor(np.log(100.0)))
        # One learnable "prompt" vector per class; values come from the checkpoint
        self.prompt_vecs = nn.Parameter(
            torch.empty(n_classes, self.vision.config.projection_dim)
        )

    def forward(self, pixel_values):
        img_embeds = self.vision(pixel_values).image_embeds
        txt_embeds = self.prompt_vecs
        # Cosine similarity between each image and each class prompt,
        # scaled by exp(logit_scale), gives a (batch, n_classes) logit matrix
        logits = torch.cosine_similarity(
            img_embeds.unsqueeze(1), txt_embeds.unsqueeze(0), dim=-1
        ) * self.logit_scale.exp()
        return {"logits": logits}

repo_id = "yashm/clip-insect-sex"
model = CLIPBinaryClassifier("openai/clip-vit-base-patch32", n_classes=2)

# Download the fine-tuned weights from the Hub and load them into the wrapper
state_dict = load_file(
    hf_hub_download(repo_id=repo_id, filename="model.safetensors", repo_type="model")
)
model.load_state_dict(state_dict, strict=True)
model.eval()

label_names = ["female", "male"]
img_size = 224
# Keep inference preprocessing identical to what was used at training time.
# Note that no mean/std normalization is applied here.
preprocess = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

image = Image.open("path/to/image.jpg").convert("RGB")
x = preprocess(image).unsqueeze(0)

with torch.no_grad():
    logits = model(x)["logits"]
    probs = logits.softmax(-1).squeeze(0).cpu().numpy()

pred_idx = int(np.argmax(probs))
print("Prediction:", label_names[pred_idx])
print("Probabilities:", dict(zip(label_names, probs.round(4))))

Training Data

  • Source folders: data/ plus augmented variants in augmented_data/.
  • Species context: Gryllus bimaculatus.
  • Dataset reuse requires permission from the dataset owner.

Evaluation

  • Validation accuracy: 100.0% on the held-out test split.
  • Split: train/test = 192/48 (seed=42, test_size=0.2).
  • Caveat: with only 48 test images, this estimate carries wide uncertainty; validate on your own data before deployment.
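
The splitting code itself is not published; assuming a standard seeded 80/20 split (e.g. scikit-learn's train_test_split with seed=42, test_size=0.2), the arithmetic on 240 images checks out:

```python
import random

# 240 images total, seed=42, test_size=0.2 (numbers from the card);
# the shuffle here is illustrative, not the project's actual splitter
indices = list(range(240))
random.Random(42).shuffle(indices)

n_test = int(240 * 0.2)  # -> 48 test images
test_idx, train_idx = indices[:n_test], indices[n_test:]
print(len(train_idx), len(test_idx))  # 192 48
```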

Limitations

  • Performance depends heavily on image quality, camera angle, and lighting.
  • Label noise and class imbalance can significantly affect confidence calibration.
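
If calibration matters downstream, temperature scaling fit on a held-out set is a common post-hoc fix (not part of this model card); a minimal sketch:

```python
import torch

def temperature_scale(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Divide logits by temperature T before softmax:
    T > 1 softens (reduces overconfidence), T < 1 sharpens."""
    return (logits / T).softmax(dim=-1)

logits = torch.tensor([[4.0, 1.0]])    # overconfident raw logits
p_raw = temperature_scale(logits, 1.0)  # ~[0.953, 0.047]
p_cal = temperature_scale(logits, 2.0)  # ~[0.818, 0.182]
```

In practice T is chosen by minimizing negative log-likelihood on a validation split, leaving the argmax prediction unchanged.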

License

This model is released under an "other" license. Confirm compatibility with the base model (openai/clip-vit-base-patch32) and with your intended redistribution.

Citation

If this model supports your work, cite your project/repository here.

Additional Details

  • Dataset size: 240 images
  • Weights format: safetensors (F32), 87.9M parameters