PlantGenoAnn

PlantGenoAnn is a plant genomic segmentation model that enables the prediction of various plant genomic elements at single-nucleotide resolution. The model is built upon the PlantBiMoE architecture with a 1D U-Net segmentation head, specifically designed for automated plant genome annotation. It predicts gene structures—including genes, CDSs, and exons—on both the forward and reverse strands. In addition, PlantGenoAnn can serve as a long-context plant genomic foundation model (up to 49,152 bp), adaptable through fine-tuning to predict plant omic signal tracks such as RNA-seq or ATAC-seq.

Developed by: hu-lab

Model Sources

How to use

The model requires the mamba-ssm and causal-conv1d libraries for the core backbone. You can retrieve both genomic feature probabilities and sequence embeddings using the following snippet:

import torch
from transformers import AutoTokenizer, AutoModel

# Load model and tokenizer
repo_id = "qzzhang/PlantGenoAnn-model-plants"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# The number of DNA tokens (excluding the [CLS] and [SEP] tokens) must be divisible by 8,
# as required by the U-Net downsampling blocks. Both example sequences are 16 nt long.
sequences = ["ACTAGAGCGAGAGAAA", "TTTGAGAGCGCGCGGA"]

# Tokenize
tokenized_sequences = tokenizer(
    sequences, 
    return_tensors="pt", 
    padding="longest"
)["input_ids"]

# Infer
model.to("cuda")
model.eval()
with torch.no_grad():
    outs = model(input_ids=tokenized_sequences.to("cuda"))

# Obtain the logits over the genomic features
# Shape: [batch, sequence_length, num_features]
logits = outs.logits

# Get probabilities associated with CDS on the forward strand (+)
pos_strand_cds_probs = model.get_feature_logits(feature="CDS", strand="+", logits=logits).detach()
print(f"CDS probabilities on the forward strand: {pos_strand_cds_probs}")

# Get the sequence embeddings
# Shape: [batch, sequence_length, 1024]
hidden_states = outs.hidden_states.detach()
print(f"Sequence embeddings shape is: {hidden_states.shape}")
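Real sequences rarely have a length that is a multiple of 8, and many are longer than the 49,152 bp context. The helper below is a minimal sketch of one way to prepare such inputs before tokenization. It assumes one token per nucleotide and that the tokenizer accepts "N" as a padding base; neither is confirmed by this card, and the function names are hypothetical rather than part of the model's API.

```python
MAX_CONTEXT_BP = 49_152  # maximum context length stated in this model card
UNET_MULTIPLE = 8        # DNA token count must be divisible by 8


def pad_to_multiple(seq, multiple=UNET_MULTIPLE, pad_base="N"):
    """Right-pad seq with pad_base until its length is a multiple of `multiple`.

    Assumption: one token per nucleotide, and "N" is a valid input symbol.
    """
    remainder = len(seq) % multiple
    if remainder == 0:
        return seq
    return seq + pad_base * (multiple - remainder)


def split_into_windows(seq, window=MAX_CONTEXT_BP):
    """Split seq into non-overlapping windows of at most `window` bp.

    Each window is padded so that its length is divisible by 8.
    """
    if window % UNET_MULTIPLE != 0:
        raise ValueError("window must be divisible by 8")
    return [
        pad_to_multiple(seq[start:start + window])
        for start in range(0, len(seq), window)
    ]


if __name__ == "__main__":
    print(pad_to_multiple("ACGTA"))  # 5 nt padded up to 8 nt
    print([len(w) for w in split_into_windows("A" * 50_001)])
```

Predictions at padded positions should be discarded afterwards. Trimming to the nearest lower multiple of 8 is an alternative when losing a few trailing bases is acceptable.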

Architecture

PlantGenoAnn-model-plants consists of the PlantBiMoE encoder (a 116M-parameter plant genomic foundation model) coupled with a custom 1D U-Net segmentation head.
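The divisible-by-8 input requirement is consistent with a U-Net head that halves the sequence length three times and then doubles it three times. The sketch below traces those lengths under exactly that assumption (three stride-2 stages, a detail inferred here and not stated in this card) to show why other lengths would misalign the skip connections. Both function names are hypothetical.

```python
def unet_lengths(n_tokens, n_down=3):
    """Trace the sequence length through n_down halvings, then n_down doublings."""
    down = [n_tokens]
    for _ in range(n_down):
        down.append(down[-1] // 2)  # stride-2 downsampling floors odd lengths
    up = [down[-1]]
    for _ in range(n_down):
        up.append(up[-1] * 2)       # each upsampling stage doubles the length
    return down, up


def skip_connections_align(n_tokens, n_down=3):
    """True if every decoder length matches the encoder length at the same depth."""
    down, up = unet_lengths(n_tokens, n_down)
    return all(u == d for u, d in zip(up, reversed(down)))


if __name__ == "__main__":
    for n in (16, 20, 24):
        print(n, unet_lengths(n), skip_connections_align(n))
```

With 20 tokens the encoder lengths are 20, 10, 5, 2 while the decoder produces 2, 4, 8, 16, so the feature maps cannot be combined without cropping or padding.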

🛠️ Training Procedure

PlantGenoAnn-model-plants was trained for 30 hours on 4 × NVIDIA A800-80G GPUs, processing a total of 18B tokens. The training data was a high-quality dataset of 9 model plant genomes with their annotations. The model was optimized with AdamW (learning rate 1e-4, weight decay 0.01) and a cosine learning rate scheduler.
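As an illustration of the schedule named above, the following sketch computes a cosine-annealed learning rate starting from the stated 1e-4. The card gives no warmup length, final learning rate, or step count, so the absence of warmup, the `min_lr=0.0` default, and the `total_steps` values are assumptions made only for this example.

```python
import math

BASE_LR = 1e-4  # learning rate stated in this model card


def cosine_lr(step, total_steps, base_lr=BASE_LR, min_lr=0.0):
    """Cosine-annealed learning rate at `step`, with no warmup (an assumption)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp the step so the value stays at min_lr once training is finished.
    progress = min(max(step, 0), total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 1000  # hypothetical number of optimizer steps
    for step in (0, 250, 500, 750, 1000):
        print(step, cosine_lr(step, total))
```

In PyTorch, the same curve is what `torch.optim.lr_scheduler.CosineAnnealingLR` produces when paired with `torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)`. Whether the authors used that exact scheduler class is not stated here.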

BibTeX entry and citation info

Coming soon.
