
DNA-Bacteria-JEPA v3.1

A self-supervised foundation model for bacterial genome representation learning using Joint-Embedding Predictive Architecture (JEPA).

DNA-Bacteria-JEPA learns rich, transferable representations of bacterial DNA sequences without requiring any labels. Trained on 500,000 sequences from 895 bacterial genomes, the model captures meaningful genomic features that transfer to downstream tasks such as species classification, antimicrobial resistance prediction, and functional annotation.

This repository currently contains v3.0 of DNA-Bacteria-JEPA, a proof of concept trained on 20 genomes with 4.5M encoder parameters. We are scaling the model to v3.1, with 8.5M encoder parameters trained on 895 genomes and 500,000 sequences on an A100 GPU; the scaled weights and updated benchmarks will be released here once training completes.

Quick Start

Installation

pip install torch numpy

Loading the Model

import torch

# Instantiate the model architecture first (the code lives in the training
# repository); the checkpoint stores weights only, so `model` below must
# already be that nn.Module.
checkpoint = torch.load("dna_bacteria_jepa_v3.1.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

Extracting Representations

import torch

def tokenize_sequence(sequence):
    """Convert a DNA sequence to nucleotide token indices."""
    nuc_to_idx = {"A": 1, "C": 2, "G": 3, "T": 4, "N": 5}
    tokens = [nuc_to_idx.get(c, 5) for c in sequence.upper()]
    return torch.tensor(tokens).unsqueeze(0)

# Example: extract embeddings for a DNA sequence
sequence = "ATCGATCGATCGATCGATCG..."
tokens = tokenize_sequence(sequence)

with torch.no_grad():
    representations = model.context_encoder(tokens)
    # Use mean pooling for sequence-level representation
    seq_embedding = representations.mean(dim=1)  # Shape: (1, 384)

Fine-Tuning for Downstream Tasks

import torch.nn as nn

class DownstreamClassifier(nn.Module):
    def __init__(self, pretrained_encoder, num_classes):
        super().__init__()
        self.encoder = pretrained_encoder
        self.classifier = nn.Sequential(
            nn.Linear(384, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        with torch.no_grad():  # Freeze encoder, or remove for full fine-tuning
            features = self.encoder(x).mean(dim=1)
        return self.classifier(features)

Model Architecture

DNA-Bacteria-JEPA is based on the Joint-Embedding Predictive Architecture (JEPA), adapted for genomic sequences. Rather than reconstructing raw input (as in masked language models), JEPA predicts latent representations of masked target regions from visible context, encouraging the model to learn high-level semantic features of DNA.
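Schematically, one JEPA training step can be sketched as below. This is a minimal toy illustration, not the repository's code: the Transformer encoders are replaced by tiny linear stand-ins, the batch is random, and (unlike the real model) the context encoder here sees the full sequence rather than only unmasked tokens.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for the three JEPA components (hidden dim 384 as in the card).
dim = 384
context_encoder = nn.Sequential(nn.Embedding(9, dim), nn.Linear(dim, dim))
target_encoder = copy.deepcopy(context_encoder)  # EMA copy, never trained by SGD
for p in target_encoder.parameters():
    p.requires_grad_(False)
predictor = nn.Linear(dim, dim)

tokens = torch.randint(1, 6, (2, 128))           # toy batch of tokenized DNA
mask = torch.zeros(2, 128, dtype=torch.bool)
mask[:, 40:60] = True                            # one masked target block

ctx = context_encoder(tokens)                    # context representations
with torch.no_grad():
    targets = target_encoder(tokens)             # latent targets (no gradient)
pred = predictor(ctx)
loss = ((pred[mask] - targets[mask]) ** 2).mean()  # predict latents, not tokens
loss.backward()

# EMA update of the target encoder toward the context encoder
decay = 0.996
with torch.no_grad():
    for pt, pc in zip(target_encoder.parameters(), context_encoder.parameters()):
        pt.mul_(decay).add_((1.0 - decay) * pc)
```

Because the loss is computed in representation space, gradients flow only through the context encoder and predictor; the target encoder follows via the EMA update.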

Available Versions

| | v3.0 (this repo) | v3.1 (in training) |
| --- | --- | --- |
| Encoder parameters | 4.5M | 8.5M |
| Genomes | 20 | 895 |
| Sequences | ~50,000 | 500,000 |
| Status | Released | Training on A100 |

v3.1 Architecture (Scaled Model)

| Component | Details |
| --- | --- |
| Encoder parameters | 8.48M |
| Predictor parameters | 5.43M |
| Context encoder | 6-layer Transformer, 6 heads, hidden dim 384, FFN dim 1024 |
| Target encoder | EMA copy of the context encoder (decay 0.996 → 1.0, cosine schedule) |
| Predictor | 4-layer Transformer, 6 heads, hidden dim 384 |
| Tokenization | Nucleotide-level (A, C, G, T, N); vocabulary size 9 |
| Max sequence length | 1,024 tokens |
| Regularization | SIGReg (log-determinant covariance regularization) |
| Auxiliary objectives | Reverse-complement equivariance, supervised contrastive, GC adversary |

Why JEPA for DNA?

Traditional masked language models for genomics (e.g., DNABERT, Nucleotide Transformer) reconstruct token-level inputs. JEPA instead operates in representation space, which:

  • Avoids wasting capacity on low-level reconstruction details
  • Learns more abstract, transferable features
  • Is more robust to input noise and sequencing errors
  • Prevents representation collapse via SIGReg regularization

SIGReg Regularization

A key challenge in JEPA training is representation collapse, where the model learns to output constant representations regardless of input. DNA-Bacteria-JEPA v3.1 addresses this with SIGReg (log-determinant regularization of the covariance eigenspectrum), which provides theoretically infinite resistance to collapse by penalizing low-variance embedding dimensions. Training health is monitored via RankMe (effective dimensionality), which remains above 370/384 throughout stable training.
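The idea behind a log-determinant covariance penalty can be sketched as follows. This is a simplified illustration of the mechanism, not the exact SIGReg implementation, and `sigreg_loss` is a hypothetical name.

```python
import torch

def sigreg_loss(z, eps=1e-4):
    """Sketch of a log-determinant covariance regularizer: the penalty
    grows without bound as embedding dimensions lose variance, which is
    what resists collapse."""
    z = z - z.mean(dim=0, keepdim=True)          # center: (batch, dim)
    cov = (z.T @ z) / (z.shape[0] - 1)           # sample covariance (dim, dim)
    # -log det of the (regularized) covariance blows up as eigenvalues shrink
    return -torch.logdet(cov + eps * torch.eye(z.shape[1]))
```

Healthy, spread-out embeddings incur a small penalty; near-collapsed embeddings with tiny per-dimension variance incur a large one.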

GC Content Adversary

Bacterial genomes vary widely in GC content (% of G and C bases), providing a trivial shortcut for discriminating species. To force the model to learn deeper genomic features, a gradient-reversal adversary actively strips GC-content information from the representations during training.
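Gradient reversal can be illustrated with a standard custom `torch.autograd.Function`. This is a generic sketch of the technique, not the repository's code; the embeddings, head, and targets below are random stand-ins.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Gradient reversal: identity on the forward pass, gradient negated
    and scaled by `lambd` on the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

# The GC head learns to predict GC content from the embeddings; the reversed
# gradient pushes the encoder in the opposite direction, stripping that
# information out.
z = torch.randn(8, 384, requires_grad=True)      # stand-in embeddings
gc_head = torch.nn.Linear(384, 1)
gc_target = torch.rand(8, 1)                     # GC fraction per sequence
pred = gc_head(GradReverse.apply(z, 3.0))        # reversal weight 3.0, as in v3.1
loss = torch.nn.functional.mse_loss(pred, gc_target)
loss.backward()
```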

Training

Data

| Property | Value |
| --- | --- |
| Total sequences | 500,000 |
| Source genomes | 895 bacterial species |
| Sequence preprocessing | Nucleotide-level tokenization |
| Training paradigm | Self-supervised (no labels required) |

Training Configuration (v3.1)

| Hyperparameter | Value |
| --- | --- |
| Optimizer | AdamW (weight decay 0.05) |
| Learning rate | 2e-4, cosine annealing with 10-epoch warmup |
| Effective batch size | 512 (256 × 2 gradient accumulation) |
| EMA decay (target encoder) | 0.996 → 1.0 (cosine schedule) |
| Masking strategy | Block masking, 4 blocks, ratio 0.15 → 0.50 (curriculum) |
| Block length | 3 → 15 tokens (curriculum) |
| Precision | bfloat16 |
| Hardware | NVIDIA A100-SXM4-40GB |
| Epochs | 200 |
| Framework | PyTorch |
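The block-masking step can be sketched as below. This is a simplified illustration that fixes the block count and length and ignores the curriculum schedule; `sample_block_mask` is a hypothetical helper, not a function from the repository.

```python
import torch

def sample_block_mask(seq_len=1024, num_blocks=4, block_len=15):
    """Hide `num_blocks` contiguous spans of `block_len` tokens each.
    In training, the masking ratio and block length follow a curriculum
    (ratio 0.15 -> 0.50, block length 3 -> 15)."""
    mask = torch.zeros(seq_len, dtype=torch.bool)
    for _ in range(num_blocks):
        start = torch.randint(0, seq_len - block_len + 1, (1,)).item()
        mask[start:start + block_len] = True     # True = masked target position
    return mask
```

Blocks may overlap, so the effective masked fraction is at most `num_blocks * block_len / seq_len`.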

Training Stability

v3.1 resolved representation collapse issues present in earlier versions through:

  • SIGReg regularization (weight 10.0) replacing VICReg variance/covariance terms
  • Lower learning rate (2e-4 vs 6e-4 in early attempts)
  • GC adversary with boosted gradient reversal (weight 3.0) to prevent GC-content shortcut exploitation

Key training metrics tracked include:

  • RankMe: Effective dimensionality of representations (target: close to 384)
  • Representation std: Representation spread (collapse indicator β€” should grow steadily)
  • GC|r|: Correlation with GC content (target: < 0.3)
  • Invariance loss: L2 prediction error in latent space

Evaluation

Pretraining Metrics

| Metric | Value |
| --- | --- |
| Final prediction loss | TBD |
| Representation std | TBD |
| RankMe | ~371/384 (epoch 40) |
| Collapse detected | No ✅ |

Downstream Performance

Note: benchmark results will be added here as they become available.

| Task | Dataset | Metric | Score |
| --- | --- | --- | --- |
| Species classification | TBD | Accuracy | TBD |
| AMR prediction | TBD | AUROC | TBD |
| Functional annotation | TBD | F1 | TBD |

Representation Quality

To assess representation quality, we evaluate:

  • Linear probing: Train a linear classifier on frozen representations
  • k-NN accuracy: Nearest-neighbor classification in embedding space
  • UMAP visualization: Cluster separation by taxonomy

Intended Use

Primary Use Cases

  • Representation learning: Extract general-purpose embeddings for bacterial DNA sequences
  • Transfer learning: Fine-tune on downstream genomic tasks with limited labeled data
  • Exploratory analysis: Visualize and cluster bacterial genomes in learned embedding space
  • Research: Study self-supervised learning dynamics on genomic data

Out-of-Scope Uses

  • Clinical or diagnostic decision-making without further validation
  • Non-bacterial genomes (eukaryotic, viral) β€” the model was trained exclusively on bacterial data
  • Real-time pathogen identification in production systems

Limitations

  • Bacterial genomes only: Trained on 895 bacterial species; generalization to unseen species or other domains (viral, eukaryotic) is not guaranteed.
  • Sequence length: Limited to 1,024 nucleotide tokens. Longer genomic regions require chunking or sliding-window approaches.
  • Training scale: At 8.5M parameters and 500K sequences, this is a proof-of-concept model. Performance will likely improve with scale.
  • No fine-tuning benchmarks yet: Downstream task performance is under active evaluation.

Repository Structure

.
├── config.json                    # Model configuration
├── README.md                      # This model card
├── dna_bacteria_jepa_v3.1.pt      # Model weights (PyTorch checkpoint)
└── tokenizer/
    └── vocab.json                 # Nucleotide vocabulary

Reproducibility

Environment

python>=3.9
torch>=2.0
numpy>=1.24

Training from Scratch

# Clone the training repository
git clone https://github.com/VUzan-bio/DNA-Bacteria-JEPA.git
cd DNA-Bacteria-JEPA

# Install dependencies
pip install -r requirements.txt

# Run pretraining
python scripts/01_pretrain_jepa.py \
  --data-path data/processed/pretrain_sequences_expanded.csv \
  --epochs 200 \
  --batch-size 256 \
  --grad-accum-steps 2 \
  --max-samples 500000 \
  --save-every 5 \
  --sigreg-weight 10.0 \
  --lr 0.0002 \
  --wandb-run-name jepa-v3.1

Citation

If you use DNA-Bacteria-JEPA in your research, please cite:

@software{dna_bacteria_jepa_2025,
  title     = {DNA-Bacteria-JEPA: Self-Supervised Genomic Representation Learning
               with Joint-Embedding Predictive Architecture},
  author    = {Valentin Uzan},
  year      = {2025},
  version   = {3.1},
  url       = {https://huggingface.co/VUzan/DNA-Bacteria-JEPA},
  license   = {Apache-2.0}
}