PlantGenoANN
PlantGenoANN is a plant genomic segmentation model that enables the prediction of various plant genomic elements at single-nucleotide resolution. The model is built upon the PlantBiMoE architecture with a 1D U-Net segmentation head, specifically designed for automated plant genome annotation. It predicts gene structures—including genes, CDSs, and exons—on both the forward and reverse strands. In addition, PlantGenoANN can serve as a long-context plant genomic foundation model (up to 49,152 bp), adaptable through fine-tuning to predict plant omic signal tracks such as RNA-seq or ATAC-seq.
Developed by: hu-lab
Model Sources
- Repository: PlantGenoANN
- GitHub: https://github.com/qzzhang0131/PlantGenoANN
How to use
The model requires the mamba-ssm and causal-conv1d libraries for the core backbone. You can retrieve both genomic feature probabilities and sequence embeddings using the following snippet:
import torch
from transformers import AutoTokenizer, AutoModel
# Load model and tokenizer
repo_id = "qzzhang/PlantGenoANN"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
# The number of DNA tokens (excluding the [CLS] and [SEP] tokens) must be divisible by 8,
# as required by the U-Net downsampling blocks.
sequences = ["ACTAGAGCGAGAGAAA","TTTGAGAGCGCGCGGA"]
# Tokenize
tokenized_sequences = tokenizer(
    sequences,
    return_tensors="pt",
    padding="longest",
)["input_ids"]
# Infer
model.to("cuda")
model.eval()
with torch.no_grad():
    outs = model(input_ids=tokenized_sequences.to("cuda"))
# Obtain the logits over the genomic features
# Shape: [batch, sequence_length, num_features]
logits = outs.logits
# Get probabilities associated with CDS on the forward strand (+)
pos_strand_cds_probs = model.get_feature_logits(feature="CDS", strand="+", logits=logits).detach()
print(f"CDS probabilities on the forward strand: {pos_strand_cds_probs}")
# Get the sequence embeddings
# Shape: [batch, sequence_length, 1024]
hidden_states = outs.hidden_states.detach()
print(f"Sequence embeddings shape is: {hidden_states.shape}")
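The divisible-by-8 requirement and the 49,152 bp context limit mean that arbitrary input sequences usually need some preparation before tokenization. The following is a minimal sketch of one way to do that, not part of the PlantGenoANN code base. It assumes that the tokenizer emits one token per nucleotide and that "N" is an acceptable padding base; the helper names `pad_to_multiple` and `split_windows` are hypothetical.

```python
from typing import List

MAX_CONTEXT_BP = 49_152  # maximum context length stated in the model card
MULTIPLE = 8             # the DNA token count must be divisible by 8


def pad_to_multiple(seq: str, multiple: int = MULTIPLE, pad_base: str = "N") -> str:
    """Right-pad `seq` with `pad_base` until its length is divisible by `multiple`.

    Assumption: one token per nucleotide, and "N" is in the tokenizer vocabulary.
    """
    remainder = len(seq) % multiple
    if remainder == 0:
        return seq
    return seq + pad_base * (multiple - remainder)


def split_windows(seq: str, window: int = MAX_CONTEXT_BP) -> List[str]:
    """Split `seq` into consecutive, non-overlapping windows of at most `window` bases.

    Every window is padded so that its length is divisible by 8.
    """
    if window <= 0 or window % MULTIPLE != 0:
        raise ValueError("window must be a positive multiple of 8")
    chunks = [seq[i:i + window] for i in range(0, len(seq), window)]
    return [pad_to_multiple(chunk) for chunk in chunks]
```

The resulting windows can be passed to the tokenizer as the `sequences` list. Predictions at padded positions should be discarded, and non-overlapping windows may cut through genes at window boundaries, so overlapping windows could be preferable in practice.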
Architecture
PlantGenoANN is composed of the PlantBiMoE encoder (a 116M parameter foundation model) coupled with a custom U-Net segmentation head.
🛠️ Training Procedure
PlantGenoANN was trained for 30 hours on 4 NVIDIA A800-80G GPUs, processing a total of 18B tokens. The training data was a high-quality dataset of 9 model plant genomes with their annotations. The model was optimized with AdamW (learning rate 1e-4, weight decay 0.01) and a cosine learning rate scheduler.
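To illustrate what a cosine learning rate schedule does with the stated peak learning rate of 1e-4, here is a small pure-Python sketch. The model card does not state the warmup length, the minimum learning rate, or the total number of steps, so this sketch assumes no warmup and a decay to zero; the function name `cosine_lr` and its parameters are hypothetical.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4, min_lr: float = 0.0) -> float:
    """Learning rate at `step` under a cosine decay from `base_lr` to `min_lr`.

    Assumptions (not from the model card): no warmup, decay ends at `min_lr`.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp the step so that values outside [0, total_steps] stay well-defined.
    step = min(max(step, 0), total_steps)
    progress = step / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions the rate starts at 1e-4, reaches half of that value at the midpoint of training, and approaches zero at the final step. In a PyTorch training loop, `torch.optim.AdamW` combined with `torch.optim.lr_scheduler.CosineAnnealingLR` would give comparable behavior.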
BibTeX entry and citation info