AxoNet CLIP Stage 2

A CLIP-style model for multimodal neuron morphology understanding, enabling text-to-image retrieval and zero-shot classification.

Model Description

AxoNet CLIP extends the Stage 1 VAE with contrastive learning:

  • Image Encoder: Frozen VAE encoder + learnable projection head
  • Text Encoder: Frozen SciBERT + learnable projection head
  • Joint Embedding Space: 512-dimensional, L2-normalized
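Because both modalities are projected into a shared L2-normalized space, cosine similarity between an image and a text embedding reduces to a plain dot product. A minimal illustration (the random vectors below are placeholders, not real model outputs):

```python
import torch

# Random stand-ins for one image embedding and one text embedding
img = torch.randn(512)
txt = torch.randn(512)

# L2-normalize onto the unit hypersphere, as the joint space requires
img = img / img.norm()
txt = txt / txt.norm()

# For unit vectors, the dot product IS the cosine similarity
sim = img @ txt
cos = torch.nn.functional.cosine_similarity(img, txt, dim=0)
assert torch.allclose(sim, cos, atol=1e-5)
```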

Architecture

Image: [Mask] -> [VAE Encoder] -> [Projection] -> [512-dim embedding]
                      |
                   (frozen)

Text:  [Description] -> [SciBERT] -> [Projection] -> [512-dim embedding]
                            |
                         (frozen)
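The exact projection-head architecture is not specified here; a common minimal design is a single linear layer followed by L2 normalization. A hypothetical sketch (layer choice and the 768-dim input, e.g. SciBERT's hidden size, are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Hypothetical projection head: linear map into the 512-dim joint space."""
    def __init__(self, in_dim: int, out_dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project, then L2-normalize so dot products act as cosine similarities
        return F.normalize(self.proj(x), dim=-1)

head = ProjectionHead(in_dim=768)   # assumed encoder feature width
z = head(torch.randn(4, 768))       # batch of 4 pooled encoder features
assert z.shape == (4, 512)
```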

Performance

Metric                  Value
Image-to-Text R@1       51.2%
Image-to-Text R@5       59.0%
Text-to-Image R@1       81.0%
Text-to-Image R@5       93.0%
Cell Type Zero-shot     51.5%
Brain Region Zero-shot  63.7%

Training

Usage

Text-to-Image Retrieval

import torch
from huggingface_hub import hf_hub_download

# Download models
clip_path = hf_hub_download("broadinstitute/axonet-clip-stage2", "full_checkpoint/best.ckpt")
vae_path = hf_hub_download("broadinstitute/axonet-vae-stage1", "pytorch_model.bin")

# Load (requires axonet package)
from axonet.models.d3_swc_vae import load_model
from axonet.models.clip_modules import SegVAE2D_CLIP
from axonet.models.text_encoders import TransformerTextEncoder, ProjectedTextEncoder

# ... (see examples/retrieval.py for full code)

# Encode query
query = "pyramidal neuron from mouse hippocampus"
text_embed = text_encoder([query])

# Compute similarities with image database
similarities = (text_embed @ image_embeds.T).squeeze(0)  # shape: (num_images,)
top_matches = similarities.argsort(descending=True)[:10]
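Over a precomputed database, retrieval reduces to one matrix multiply plus a top-k. A self-contained sketch with random placeholder embeddings standing in for real encoder outputs:

```python
import torch

num_images, dim = 100, 512

# Placeholder database of unit-normalized image embeddings
image_embeds = torch.nn.functional.normalize(torch.randn(num_images, dim), dim=-1)
# Placeholder query embedding (in practice, text_encoder([query]))
text_embed = torch.nn.functional.normalize(torch.randn(1, dim), dim=-1)

# Cosine similarities of the query against every image in the database
similarities = (text_embed @ image_embeds.T).squeeze(0)  # shape: (num_images,)

# Indices of the 10 best-matching images, highest similarity first
top_matches = similarities.topk(10).indices
assert top_matches.shape == (10,)
```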

Zero-shot Classification

# Define categories (each label already names the cell type, so no extra "neuron" suffix)
cell_types = ["pyramidal neuron", "interneuron", "granule cell", "Purkinje cell"]
prompts = [f"a {cls}" for cls in cell_types]

# Encode prompts
text_embeds = text_encoder(prompts)

# Classify image
logits = image_embed @ text_embeds.T
probs = torch.softmax(logits / 0.07, dim=-1)
predicted_class = cell_types[probs.argmax()]
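The 0.07 divisor above is a temperature: dividing the logits by a small constant before the softmax sharpens the distribution, concentrating probability on the best-matching class. A quick numeric illustration with made-up similarity logits:

```python
import torch

logits = torch.tensor([0.30, 0.25, 0.10, 0.05])  # made-up cosine similarities

sharp = torch.softmax(logits / 0.07, dim=-1)  # temperature-scaled
flat = torch.softmax(logits, dim=-1)          # no temperature

# Both are valid distributions, but the scaled one is far more peaked
assert torch.allclose(sharp.sum(), torch.tensor(1.0), atol=1e-5)
assert sharp.max() > flat.max()
```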

Example Queries

Queries that work well with this model:

  • "pyramidal neuron from mouse cortex"
  • "Purkinje cell from cerebellum"
  • "interneuron with dense axonal arbor"
  • "bipolar neuron from retina"
  • "large motor neuron from spinal cord"
  • "neuron from human temporal cortex"

Files

File                       Description
pytorch_model.bin          PyTorch state dict
model.safetensors          Safetensors format
config.json                Model configuration
full_checkpoint/best.ckpt  Full Lightning checkpoint

Dependencies

Citation

@misc{axonet2025,
  author = {Hall, Giles},
  title = {AxoNet: Multimodal Neuron Morphology Embeddings via 2D Projections},
  year = {2025},
  publisher = {HuggingFace},
  howpublished = {\url{https://huggingface.co/broadinstitute/axonet-clip-stage2}}
}

License

MIT License
