# DNA-Bacteria-JEPA v3.1

A self-supervised foundation model for bacterial genome representation learning using a Joint-Embedding Predictive Architecture (JEPA).

DNA-Bacteria-JEPA learns rich, transferable representations of bacterial DNA sequences without requiring any labels. The v3.1 model is trained on 500,000 sequences from 895 bacterial genomes and captures meaningful genomic features that transfer to downstream tasks such as species classification, antimicrobial resistance prediction, and functional annotation.

This repository contains the initial release (v3.0), trained on 20 genomes with a 4.5M-parameter encoder as a proof of concept. We are currently scaling the model to v3.1, with an 8.5M-parameter encoder trained on 895 genomes and 500,000 sequences on an A100 GPU. The scaled model weights and updated benchmarks will be released here upon completion of training.
## Table of Contents
- Quick Start
- Model Architecture
- Training
- Evaluation
- Intended Use
- Limitations
- Repository Structure
- Citation
## Quick Start

### Installation

```bash
pip install torch numpy
```
### Loading the Model

```python
import torch

# First instantiate the model architecture (see the training repository),
# then load the pretrained weights into it
checkpoint = torch.load("dna_bacteria_jepa_v3.1.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
### Extracting Representations

```python
import torch

def tokenize_sequence(sequence):
    """Convert a DNA sequence to nucleotide token indices."""
    nuc_to_idx = {"A": 1, "C": 2, "G": 3, "T": 4, "N": 5}
    tokens = [nuc_to_idx.get(c, 5) for c in sequence.upper()]  # unknown bases map to N
    return torch.tensor(tokens).unsqueeze(0)  # add batch dimension

# Example: extract embeddings for a DNA sequence
sequence = "ATCGATCGATCGATCGATCG..."
tokens = tokenize_sequence(sequence)

with torch.no_grad():
    representations = model.context_encoder(tokens)

# Use mean pooling for a sequence-level representation
seq_embedding = representations.mean(dim=1)  # shape: (1, 384)
```
### Fine-Tuning for Downstream Tasks

```python
import torch
import torch.nn as nn

class DownstreamClassifier(nn.Module):
    def __init__(self, pretrained_encoder, num_classes):
        super().__init__()
        self.encoder = pretrained_encoder
        self.classifier = nn.Sequential(
            nn.Linear(384, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        with torch.no_grad():  # freeze the encoder; remove for full fine-tuning
            features = self.encoder(x).mean(dim=1)
        return self.classifier(features)
```
## Model Architecture

DNA-Bacteria-JEPA is based on the Joint-Embedding Predictive Architecture (JEPA), adapted for genomic sequences. Rather than reconstructing raw input (as in masked language models), JEPA predicts latent representations of masked target regions from visible context, encouraging the model to learn high-level semantic features of DNA.
### Available Versions

| | v3.0 (this repo) | v3.1 (in training) |
|---|---|---|
| Encoder parameters | 4.5M | 8.5M |
| Genomes | 20 | 895 |
| Sequences | ~50,000 | 500,000 |
| Status | Released | Training on A100 |
### v3.1 Architecture (Scaled Model)
| Component | Details |
|---|---|
| Encoder parameters | 8.48M |
| Predictor parameters | 5.43M |
| Context Encoder | 6-layer Transformer, 6 heads, hidden dim 384, FFN dim 1024 |
| Target Encoder | EMA copy of context encoder (decay 0.996 → 1.0, cosine schedule) |
| Predictor | 4-layer Transformer, 6 heads, hidden dim 384 |
| Tokenization | Nucleotide-level (A, C, G, T, N); vocabulary size 9 |
| Max Sequence Length | 1,024 tokens |
| Regularization | SIGReg (log-determinant covariance regularization) |
| Auxiliary objectives | Reverse complement equivariance, supervised contrastive, GC adversary |
### Why JEPA for DNA?
Traditional masked language models for genomics (e.g., DNABERT, Nucleotide Transformer) reconstruct token-level inputs. JEPA instead operates in representation space, which:
- Avoids wasting capacity on low-level reconstruction details
- Learns more abstract, transferable features
- Is more robust to input noise and sequencing errors
- Prevents representation collapse via SIGReg regularization
### SIGReg Regularization
A key challenge in JEPA training is representation collapse, where the model learns to output constant representations regardless of input. DNA-Bacteria-JEPA v3.1 addresses this with SIGReg (log-determinant regularization of the covariance eigenspectrum), which provides theoretically infinite resistance to collapse by penalizing low-variance embedding dimensions. Training health is monitored via RankMe (effective dimensionality), which remains above 370/384 throughout stable training.
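The exact SIGReg implementation lives in the training repository; the sketch below only illustrates the two quantities in spirit: a log-determinant penalty on the embedding covariance, which blows up as dimensions lose variance, and RankMe, the entropy-based effective rank of the embedding matrix:

```python
import torch

def logdet_penalty(z, eps=1e-4):
    """Penalty grows as embedding dimensions lose variance (i.e., collapse)."""
    z = z - z.mean(dim=0, keepdim=True)
    cov = (z.T @ z) / (z.shape[0] - 1) + eps * torch.eye(z.shape[1])
    return -torch.logdet(cov)

def rankme(z, eps=1e-7):
    """Effective rank: exp of the entropy of the normalized singular values."""
    p = torch.linalg.svdvals(z)
    p = p / (p.sum() + eps) + eps
    return torch.exp(-(p * p.log()).sum())

healthy = torch.randn(512, 384)                  # full-rank embeddings
collapsed = torch.randn(1, 384).repeat(512, 1)   # constant output = collapse
```

For the collapsed batch, RankMe falls to roughly 1 and the log-determinant penalty explodes, which is why monitoring RankMe (target near 384) catches collapse early.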
### GC Content Adversary
Bacterial genomes vary widely in GC content (% of G and C bases), providing a trivial shortcut for discriminating species. To force the model to learn deeper genomic features, a gradient-reversal adversary actively strips GC-content information from the representations during training.
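Gradient reversal is commonly implemented with a small autograd function; this is a generic sketch, not the repository's exact code. The adversary head trains normally to predict GC content, while the encoder upstream receives negated, scaled gradients and learns to discard that signal:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; negates and scales gradients on the backward pass."""
    @staticmethod
    def forward(ctx, x, weight):
        ctx.weight = weight
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.weight * grad_output, None

# Reversed gradients flow from the adversary back into the encoder
embeddings = torch.randn(8, 384, requires_grad=True)
reversed_embeddings = GradReverse.apply(embeddings, 3.0)  # weight 3.0 as in v3.1
gc_prediction = reversed_embeddings.mean(dim=1)           # stand-in adversary head
gc_prediction.pow(2).mean().backward()
```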
## Training

### Data
| Property | Value |
|---|---|
| Total sequences | 500,000 |
| Source genomes | 895 bacterial species |
| Sequence preprocessing | Nucleotide-level tokenization |
| Training paradigm | Self-supervised (no labels required) |
### Training Configuration (v3.1)
| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW (weight decay 0.05) |
| Learning rate | 2e-4, cosine annealing with 10-epoch warmup |
| Effective batch size | 512 (256 × 2 gradient accumulation) |
| EMA decay (target encoder) | 0.996 → 1.0 (cosine schedule) |
| Masking strategy | Block masking, 4 blocks, ratio 0.15 → 0.50 (curriculum) |
| Block length | 3 → 15 tokens (curriculum) |
| Precision | bfloat16 |
| Hardware | NVIDIA A100-SXM4-40GB |
| Epochs | 200 |
| Framework | PyTorch |
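The block-masking step can be sketched as a small hypothetical helper (the real training code schedules `mask_ratio` and block length over epochs per the curriculum above):

```python
import torch

def sample_block_mask(seq_len, n_blocks=4, mask_ratio=0.15, generator=None):
    """Mask n_blocks contiguous blocks covering roughly mask_ratio of the sequence."""
    mask = torch.zeros(seq_len, dtype=torch.bool)
    block_len = max(1, round(seq_len * mask_ratio / n_blocks))
    for _ in range(n_blocks):
        start = torch.randint(0, seq_len - block_len + 1, (1,), generator=generator).item()
        mask[start:start + block_len] = True  # blocks may overlap
    return mask

mask = sample_block_mask(1024, n_blocks=4, mask_ratio=0.15)
```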
### Training Stability
v3.1 resolved representation collapse issues present in earlier versions through:
- SIGReg regularization (weight 10.0) replacing VICReg variance/covariance terms
- Lower learning rate (2e-4 vs 6e-4 in early attempts)
- GC adversary with boosted gradient reversal (weight 3.0) to prevent GC-content shortcut exploitation
Key training metrics tracked include:
- RankMe: Effective dimensionality of representations (target: close to 384)
- Representation std: Spread of the representations (collapse indicator; should grow steadily)
- GC|r|: Correlation with GC content (target: < 0.3)
- Invariance loss: L2 prediction error in latent space
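The GC|r| metric, for instance, can be computed as the maximum absolute Pearson correlation between any embedding dimension and per-sequence GC content. This is an illustrative computation, not the training code:

```python
import torch

def gc_content(sequences):
    """Fraction of G/C bases per sequence."""
    return torch.tensor([sum(c in "GC" for c in s.upper()) / len(s) for s in sequences])

def max_abs_gc_correlation(embeddings, gc):
    """Max |Pearson r| between any embedding dimension and GC content."""
    e = embeddings - embeddings.mean(dim=0)
    g = gc - gc.mean()
    r = (e * g[:, None]).sum(dim=0) / (e.norm(dim=0) * g.norm() + 1e-8)
    return r.abs().max().item()
```

A value below 0.3 indicates the adversary has largely stripped the GC shortcut from the representations.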
## Evaluation

### Pretraining Metrics
| Metric | Value |
|---|---|
| Final prediction loss | TBD |
| Representation std | TBD |
| RankMe | ~371/384 (epoch 40) |
| Collapse detected | No |
### Downstream Performance

*Benchmark results will be added as they become available.*
| Task | Dataset | Metric | Score |
|---|---|---|---|
| Species classification | TBD | Accuracy | TBD |
| AMR prediction | TBD | AUROC | TBD |
| Functional annotation | TBD | F1 | TBD |
### Representation Quality
To assess representation quality, we evaluate:
- Linear probing: Train a linear classifier on frozen representations
- k-NN accuracy: Nearest-neighbor classification in embedding space
- UMAP visualization: Cluster separation by taxonomy
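The k-NN probe, for example, needs only pairwise distances over frozen embeddings. A generic sketch (not the evaluation code itself):

```python
import torch

def knn_accuracy(train_emb, train_labels, test_emb, test_labels, k=5):
    """Classify each test embedding by majority vote of its k nearest training neighbors."""
    distances = torch.cdist(test_emb, train_emb)        # (n_test, n_train)
    neighbors = distances.topk(k, largest=False).indices
    votes = train_labels[neighbors]                     # (n_test, k)
    predictions = votes.mode(dim=1).values
    return (predictions == test_labels).float().mean().item()
```

Because it has no trainable parameters, k-NN accuracy directly reflects how well the embedding geometry separates classes.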
## Intended Use

### Primary Use Cases
- Representation learning: Extract general-purpose embeddings for bacterial DNA sequences
- Transfer learning: Fine-tune on downstream genomic tasks with limited labeled data
- Exploratory analysis: Visualize and cluster bacterial genomes in learned embedding space
- Research: Study self-supervised learning dynamics on genomic data
### Out-of-Scope Uses
- Clinical or diagnostic decision-making without further validation
- Non-bacterial genomes (eukaryotic, viral): the model was trained exclusively on bacterial data
- Real-time pathogen identification in production systems
## Limitations
- Bacterial genomes only: Trained on 895 bacterial species; generalization to unseen species or other domains (viral, eukaryotic) is not guaranteed.
- Sequence length: Limited to 1,024 nucleotide tokens. Longer genomic regions require chunking or sliding-window approaches.
- Training scale: At 8.5M parameters and 500K sequences, this is a proof-of-concept model. Performance will likely improve with scale.
- No fine-tuning benchmarks yet: Downstream task performance is under active evaluation.
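For inputs beyond the 1,024-token limit, a simple sliding-window helper (hypothetical, not part of the released code) can split the sequence before tokenization:

```python
def chunk_sequence(sequence, window=1024, stride=512):
    """Split a long DNA string into overlapping windows the encoder can handle."""
    if len(sequence) <= window:
        return [sequence]
    chunks = [sequence[start:start + window]
              for start in range(0, len(sequence) - window + 1, stride)]
    if (len(sequence) - window) % stride:  # cover the tail if strides don't align
        chunks.append(sequence[-window:])
    return chunks
```

Per-chunk embeddings can then be mean-pooled into a single representation for the region.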
## Repository Structure

```
.
├── config.json                 # Model configuration
├── README.md                   # This model card
├── dna_bacteria_jepa_v3.1.pt   # Model weights (PyTorch checkpoint)
└── tokenizer/
    └── vocab.json              # Nucleotide vocabulary
```
## Reproducibility

### Environment

```
python>=3.9
torch>=2.0
numpy>=1.24
```
### Training from Scratch

```bash
# Clone the training repository
git clone https://github.com/VUzan-bio/DNA-Bacteria-JEPA.git
cd DNA-Bacteria-JEPA

# Install dependencies
pip install -r requirements.txt

# Run pretraining
python scripts/01_pretrain_jepa.py \
    --data-path data/processed/pretrain_sequences_expanded.csv \
    --epochs 200 \
    --batch-size 256 \
    --grad-accum-steps 2 \
    --max-samples 500000 \
    --save-every 5 \
    --sigreg-weight 10.0 \
    --lr 0.0002 \
    --wandb-run-name jepa-v3.1
```
## Citation

If you use DNA-Bacteria-JEPA in your research, please cite:

```bibtex
@software{dna_bacteria_jepa_2025,
  title   = {DNA-Bacteria-JEPA: Self-Supervised Genomic Representation Learning
             with Joint-Embedding Predictive Architecture},
  author  = {Valentin Uzan},
  year    = {2025},
  version = {3.1},
  url     = {https://huggingface.co/VUzan/DNA-Bacteria-JEPA},
  license = {Apache-2.0}
}
```