# AxoNet CLIP Stage 2

A CLIP-style model for multimodal neuron morphology understanding, enabling text-to-image retrieval and zero-shot classification.
## Model Description
AxoNet CLIP extends the Stage 1 VAE with contrastive learning:
- Image Encoder: Frozen VAE encoder + learnable projection head
- Text Encoder: SciBERT + learnable projection head
- Joint Embedding Space: 512-dimensional, L2-normalized
## Architecture

```
Image: [Mask] -> [VAE Encoder] -> [Projection] -> [512-dim embedding]
                      |
                  (frozen)

Text: [Description] -> [SciBERT] -> [Projection] -> [512-dim embedding]
                           |
                       (frozen)
```
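The projection heads are the only trainable parts of each tower. A minimal sketch of what such a head can look like, assuming a simple linear projection; the input feature sizes (`IMG_FEAT_DIM`, `TXT_FEAT_DIM`) are illustrative assumptions, not values confirmed by this card:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IMG_FEAT_DIM = 256  # assumed VAE encoder latent size (hypothetical)
TXT_FEAT_DIM = 768  # SciBERT hidden size
EMBED_DIM = 512     # joint embedding size (from this card)

class ProjectionHead(nn.Module):
    """Learnable linear map into the shared, L2-normalized space."""
    def __init__(self, in_dim: int, out_dim: int = EMBED_DIM):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2-normalize so dot products become cosine similarities
        return F.normalize(self.proj(x), dim=-1)

image_head = ProjectionHead(IMG_FEAT_DIM)
text_head = ProjectionHead(TXT_FEAT_DIM)

# Features from the frozen encoders (random stand-ins here)
img_emb = image_head(torch.randn(4, IMG_FEAT_DIM))
txt_emb = text_head(torch.randn(4, TXT_FEAT_DIM))
print(img_emb.shape, txt_emb.shape)  # both torch.Size([4, 512])
```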
## Performance
| Metric | Value |
|---|---|
| Image-to-Text R@1 | 51.2% |
| Image-to-Text R@5 | 59.0% |
| Text-to-Image R@1 | 81.0% |
| Text-to-Image R@5 | 93.0% |
| Cell Type Zero-shot | 51.5% |
| Brain Region Zero-shot | 63.7% |
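The retrieval rows report Recall@K: the fraction of queries whose ground-truth match appears among the top K retrieved candidates. A minimal sketch of the metric, assuming matched image/text pairs lie on the diagonal of the similarity matrix (a common evaluation convention, not a detail confirmed by this card):

```python
import torch

def recall_at_k(sim: torch.Tensor, k: int) -> float:
    """sim[i, j]: similarity of query i to candidate j; ground truth on the diagonal."""
    topk = sim.topk(k, dim=-1).indices                 # (N, k) candidate indices
    targets = torch.arange(sim.size(0)).unsqueeze(-1)  # (N, 1) correct index per query
    hits = (topk == targets).any(dim=-1)               # did the match land in the top k?
    return hits.float().mean().item()

# Toy example: near-identity similarities -> perfect retrieval
torch.manual_seed(0)
sim = torch.eye(5) + 0.01 * torch.randn(5, 5)
print(recall_at_k(sim, 1))  # -> 1.0
```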
## Training
- Base Model: broadinstitute/axonet-vae-stage1
- Text Encoder: allenai/scibert_scivocab_uncased
- Dataset: broadinstitute/axonet-neuromorpho-dataset
- Epochs: 20
- Batch size: 64
- Temperature: 0.07 (learnable)
- Hardware: 2x NVIDIA A100 80GB
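The learnable temperature (initialized to 0.07) scales the logits of the standard symmetric CLIP/InfoNCE objective. A minimal sketch of that objective, assuming the usual CLIP formulation; the exact AxoNet training loss is an assumption, not confirmed by this card:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClipLoss(nn.Module):
    def __init__(self, init_temperature: float = 0.07):
        super().__init__()
        # Learn log(1/T) so the effective temperature stays positive
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / init_temperature)))

    def forward(self, img_emb: torch.Tensor, txt_emb: torch.Tensor) -> torch.Tensor:
        # Inputs are L2-normalized, so logits are scaled cosine similarities
        logits = self.logit_scale.exp() * img_emb @ txt_emb.T
        targets = torch.arange(img_emb.size(0), device=img_emb.device)
        # Symmetric cross-entropy over image->text and text->image directions
        return 0.5 * (F.cross_entropy(logits, targets) +
                      F.cross_entropy(logits.T, targets))

loss_fn = ClipLoss()
a = F.normalize(torch.randn(8, 512), dim=-1)  # stand-in image embeddings
b = F.normalize(torch.randn(8, 512), dim=-1)  # stand-in text embeddings
print(loss_fn(a, b))  # scalar loss
```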
## Usage

### Text-to-Image Retrieval
```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoints
clip_path = hf_hub_download("broadinstitute/axonet-clip-stage2", "full_checkpoint/best.ckpt")
vae_path = hf_hub_download("broadinstitute/axonet-vae-stage1", "pytorch_model.bin")

# Load (requires the axonet package)
from axonet.models.d3_swc_vae import load_model
from axonet.models.clip_modules import SegVAE2D_CLIP
from axonet.models.text_encoders import TransformerTextEncoder, ProjectedTextEncoder
# ... (see examples/retrieval.py for the full loading code)

# Encode the text query
query = "pyramidal neuron from mouse hippocampus"
text_embed = text_encoder([query])

# Cosine similarities against a precomputed database of image embeddings
similarities = (text_embed @ image_embeds.T).squeeze(0)
top_matches = similarities.argsort(descending=True)[:10]
```
### Zero-shot Classification

```python
# Define candidate classes
cell_types = ["pyramidal neuron", "interneuron", "granule cell", "Purkinje cell"]
prompts = [f"a {cls}" for cls in cell_types]  # class names already include "neuron"/"cell"

# Encode the prompts
text_embeds = text_encoder(prompts)

# Classify a single image embedding
logits = image_embed @ text_embeds.T
probs = torch.softmax(logits / 0.07, dim=-1)
predicted_class = cell_types[probs.argmax()]
```
## Example Queries
Queries that work well with this model:
- "pyramidal neuron from mouse cortex"
- "Purkinje cell from cerebellum"
- "interneuron with dense axonal arbor"
- "bipolar neuron from retina"
- "large motor neuron from spinal cord"
- "neuron from human temporal cortex"
## Files

| File | Description |
|---|---|
| `pytorch_model.bin` | PyTorch state dict |
| `model.safetensors` | Safetensors format |
| `config.json` | Model configuration |
| `full_checkpoint/best.ckpt` | Full Lightning checkpoint |
## Dependencies
- broadinstitute/axonet-vae-stage1 - Stage 1 VAE encoder
- allenai/scibert_scivocab_uncased - Text encoder
## Citation

```bibtex
@misc{axonet2025,
  author = {Hall, Giles},
  title = {AxoNet: Multimodal Neuron Morphology Embeddings via 2D Projections},
  year = {2025},
  publisher = {HuggingFace},
  howpublished = {\url{https://huggingface.co/broadinstitute/axonet-clip-stage2}}
}
```
## License
MIT License