EES-Transformer V2: Arabidopsis Tissue Classification Model

A dual-path transformer trained on 12,212 Arabidopsis thaliana RNA-seq samples for tissue classification and gene regulatory network inference from Extreme Expression Sets (EES).

Model Description

EES-Transformer V2 converts continuous gene expression profiles into discrete token sequences of genes at the expression extremes (a tokenization sketch follows this list) and processes them with a dual-path transformer that provides:

  • Tissue classification from gene expression alone (92.2% validation accuracy across 47 tissue types)
  • Masked language modeling with tissue conditioning, for learning gene representations
  • Gene regulatory network inference, via networks extracted post-hoc from the learned attention patterns
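
A minimal sketch of the tokenization step, assuming an EES keeps the top- and bottom-k genes of a ranked profile; build_ees_tokens and the cutoffs k_high/k_low are illustrative, and the exact construction is defined in the project repository:

import numpy as np

def build_ees_tokens(expression, gene_ids, k_high=256, k_low=256):
    # Rank genes by expression and keep only the extremes as discrete tokens.
    # k_high/k_low are hypothetical defaults, not the model's actual settings.
    order = np.argsort(expression)
    low = [gene_ids[i] for i in order[:k_low]]           # most repressed genes
    high = [gene_ids[i] for i in order[-k_high:][::-1]]  # most induced genes, strongest first
    return high + low                                    # one token sequence per sample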

Architecture

Gene Embeddings → Shared Encoder (4 layers)
                         ↓
           ┌─────────────┴─────────────┐
           ↓                           ↓
    Tissue Encoder (2 layers)    [+ Tissue Embedding]
           ↓                           ↓
    Tissue Head                  MLM Encoder (4 layers)
           ↓                           ↓
    Predict Tissue               MLM Head
    (from genes only)            (tissue-conditioned)
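
A hedged PyTorch sketch of this dual-path layout, using the sizes listed under Training Details; the module structure and names here (DualPathSketch, encoder, mean pooling for the tissue head) are assumptions, not the real EESTransformerV2 implementation:

import torch.nn as nn

def encoder(num_layers, d=768, heads=12):
    # Illustrative stand-in: a stack of standard transformer encoder layers.
    layer = nn.TransformerEncoderLayer(d, heads, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=num_layers)

class DualPathSketch(nn.Module):
    def __init__(self, vocab=46728, tissues=47, d=768):
        super().__init__()
        self.gene_emb = nn.Embedding(vocab, d)
        self.tissue_emb = nn.Embedding(tissues, d)
        self.shared = encoder(4, d)      # shared trunk
        self.tissue_enc = encoder(2, d)  # tissue branch
        self.mlm_enc = encoder(4, d)     # MLM branch
        self.tissue_head = nn.Linear(d, tissues)
        self.mlm_head = nn.Linear(d, vocab)

    def forward(self, gene_tokens, tissue_labels):
        h = self.shared(self.gene_emb(gene_tokens))
        # Tissue branch sees gene tokens only, so tissue must be predicted.
        tissue_logits = self.tissue_head(self.tissue_enc(h).mean(dim=1))
        # MLM branch is conditioned on the known tissue via an added embedding.
        h_mlm = h + self.tissue_emb(tissue_labels).unsqueeze(1)
        mlm_logits = self.mlm_head(self.mlm_enc(h_mlm))
        return tissue_logits, mlm_logits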

Training Details

Metric                        Value
Validation tissue accuracy    92.2%
Training tissue accuracy      99.2%
Number of tissues             47
Vocabulary size               46,728 tokens
Model parameters              144.6M
Hidden size                   768
Attention heads               12
Shared encoder layers         4
Tissue branch layers          2
MLM branch layers             4
Epochs trained                65 (best at epoch 63)
Training hardware             NVIDIA GeForce RTX 4090 (24 GB)
Training time                 ~19 hours
Precision                     FP32

Files

  • best_tissue_model.pt - Model checkpoint (state dict + optimizer state)
  • config.json - Model configuration
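
For orientation, a config.json consistent with the Training Details table might look like the sketch below; only the values come from the table, while the field names are assumptions (the authoritative schema is EESV2Config in the repository):

{
  "vocab_size": 46728,
  "num_tissues": 47,
  "hidden_size": 768,
  "num_attention_heads": 12,
  "shared_layers": 4,
  "tissue_layers": 2,
  "mlm_layers": 4
}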

Usage

import torch
import json
from src.model.ees_transformer_v2 import EESTransformerV2, EESV2Config

# Load config and model
with open('config.json') as f:
    config = EESV2Config(**json.load(f))

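# The checkpoint stores both model and optimizer state; only the model
# weights ('model_state_dict') are needed for inference.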
checkpoint = torch.load('best_tissue_model.pt', map_location='cpu')
model = EESTransformerV2(config)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
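
The GRN side is post-hoc: attention weights are turned into gene-gene edge scores. How attention maps are exposed by EESTransformerV2 is repo-specific, so attention_to_edges below is a hypothetical sketch that assumes you already hold attention weights of shape (layers, heads, seq, seq) for one EES sample:

import torch

def attention_to_edges(attn, tokens, top_n=100):
    # attn: (layers, heads, seq, seq) attention weights; tokens: gene IDs.
    a = attn.mean(dim=(0, 1))           # average over layers and heads -> (seq, seq)
    a = 0.5 * (a + a.T)                 # symmetrize into an undirected edge score
    scores = torch.triu(a, diagonal=1)  # keep each gene pair once
    flat = scores.flatten()
    idx = flat.topk(min(top_n, flat.numel())).indices
    rows, cols = idx // a.size(0), idx % a.size(0)
    return [(tokens[r], tokens[c], flat[i].item())
            for r, c, i in zip(rows.tolist(), cols.tolist(), idx.tolist())]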

Full code: https://github.com/k821209/ees-transformer

Citation

[Preprint forthcoming]

Authors

Joo-Seok Park, Yejin Lee, Yang Jae Kang

Research Institute of Molecular Alchemy & Division of Life Science, Gyeongsang National University, Republic of Korea
