C2S-UC-Gemma-2B
A fine-tuned Cell2Sentence model for Ulcerative Colitis (UC) single-cell RNA-seq cell type prediction.
This model understands the "language of biology" by treating gene expression lists as sentences. On a held-out test set of 500 cells it reaches 89.00% accuracy across 44 distinct cell types in human colon tissue, compared with 3.60% for the base model on the same task.
Model Description
- Task: Single-cell cell type prediction
- Base Model: vandijklab/C2S-Scale-Gemma-2-2B
- Finetuning Data: UC single-cell RNA-seq atlas (Human Colon)
- Cell Types: 44 fine-grained L3 cell types
- Input: List of genes ordered by expression level (Cell Sentence)
- Output: Cell Type Label
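The input format above can be illustrated with a minimal sketch that turns one cell's expression values into a cell sentence. The helper name `make_cell_sentence`, the choice to drop zero-count genes, and the toy values are assumptions for illustration; the preprocessing actually used for this model is not described in this card.

```python
from typing import Sequence


def make_cell_sentence(
    gene_names: Sequence[str],
    counts: Sequence[float],
    top_k: int = 200,
) -> str:
    """Rank genes by descending expression and join the top_k names with spaces."""
    if len(gene_names) != len(counts):
        raise ValueError("gene_names and counts must have the same length")
    # Assumption: genes with zero expression are left out of the sentence.
    expressed = [(name, value) for name, value in zip(gene_names, counts) if value > 0]
    # sorted() is stable, so genes with equal expression keep their input order.
    ranked = sorted(expressed, key=lambda pair: pair[1], reverse=True)
    return " ".join(name for name, _ in ranked[:top_k])


# Toy expression values for four genes in a single cell (illustrative only).
genes = ["MZB1", "JCHAIN", "ACTB", "IGHA1"]
counts = [12.0, 250.0, 0.0, 90.0]
sentence = make_cell_sentence(genes, counts)
print(sentence)  # JCHAIN IGHA1 MZB1
```

The resulting string is the kind of space-separated, descending-order gene list that the model takes as input.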
Performance
The model was evaluated on a held-out test set of 500 cells.
| Model | Accuracy |
|---|---|
| Original (pre-trained) | 3.60% |
| Fine-tuned (This Model) | 89.00% |
| Improvement | +85.40 percentage points (about 25x) |
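The figures in the table can be checked with a few lines of arithmetic. The sketch below also adds a rough 95% interval for the fine-tuned accuracy using the normal approximation; that interval is not reported in this card and is included only to show how much uncertainty a 500-cell test set implies.

```python
import math

n_test = 500          # held-out test cells
base_acc = 0.0360     # original (pre-trained) accuracy
tuned_acc = 0.8900    # fine-tuned accuracy

# Number of correctly labelled cells implied by each accuracy.
base_correct = round(base_acc * n_test)    # 18
tuned_correct = round(tuned_acc * n_test)  # 445

# Absolute gain in percentage points and relative gain as a ratio.
point_gain = (tuned_acc - base_acc) * 100
ratio = tuned_acc / base_acc

# Normal-approximation 95% half-width (an assumption, not a reported result).
half_width = 1.96 * math.sqrt(tuned_acc * (1 - tuned_acc) / n_test)

print(f"{base_correct} vs {tuned_correct} correct out of {n_test}")
print(f"gain: {point_gain:.2f} points, ratio: {ratio:.1f}x")   # 85.40 points, 24.7x
print(f"fine-tuned accuracy: 89.0% +/- {half_width * 100:.1f} points")  # about 2.7
```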
Usage
1. Installation
pip install transformers torch accelerate
2. Prediction Code
Use the following code to predict cell types. Note that the prompt format is critical for good performance.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the fine-tuned model and its tokenizer
model_name = "Jyx0208/C2S-UC-Gemma-2B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

def predict_cell_type(cell_sentence, top_k=200):
    """
    Args:
        cell_sentence: Space-separated string of gene names ordered by expression level (descending)
        top_k: Maximum number of genes, taken from the start of the sentence, to include in the prompt
    Returns:
        The predicted cell type label as a string
    """
    # Keep only the top_k genes so the gene count stated in the prompt matches the sentence
    genes = cell_sentence.split()[:top_k]
    num_genes = len(genes)
    truncated_sentence = " ".join(genes)
    # Create prompt (Cell2Sentence Standard Format)
    prompt = f"The following is a list of {num_genes} gene names ordered by descending expression level in a human cell. Your task is to give the cell type which this cell belongs to based on its gene expression.\nCell sentence: {truncated_sentence}.\nThe cell type corresponding to these genes is:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False
        )
    # Decode only the newly generated tokens, skipping the prompt tokens
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Keep the first line of the generated text as the prediction
    prediction = answer.strip().split('\n')[0]
    return prediction

# Example: Plasma IgA Cell
# (Top expressed genes: JCHAIN, IGHA1, IGHA2, MZB1, ...)
example_cell = "JCHAIN IGHA1 IGHA2 MZB1 TNFRSF17 SSR4 TXNDC5 FKBP11 SEC11C"
print(f"Prediction: {predict_cell_type(example_cell)}")
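To score predictions against known labels, an exact-match accuracy loop is usually enough. The sketch below is one way to do it, not the evaluation script used for this model: `normalize_label`, `exact_match_accuracy`, and the stand-in predictor `stub_predict` are hypothetical names, and the three toy examples are illustrative. In practice the `predict` argument would be a function that maps a cell sentence to a label using the model.

```python
from typing import Callable, Iterable, Tuple


def normalize_label(label: str) -> str:
    """Trim whitespace, drop a trailing period, and lowercase the label."""
    return label.strip().rstrip(".").strip().lower()


def exact_match_accuracy(
    examples: Iterable[Tuple[str, str]],
    predict: Callable[[str], str],
) -> float:
    """Fraction of (cell_sentence, true_label) pairs predicted exactly."""
    total = 0
    correct = 0
    for cell_sentence, true_label in examples:
        total += 1
        if normalize_label(predict(cell_sentence)) == normalize_label(true_label):
            correct += 1
    return correct / total if total else 0.0


# Toy examples and a stand-in predictor so the sketch runs without the model.
examples = [
    ("JCHAIN IGHA1 IGHA2 MZB1", "Plasma IgA"),
    ("MUC2 TFF3 CLCA1", "Goblet"),
    ("FOXP3 IL2RA CTLA4", "Treg"),
]
stub_outputs = {
    "JCHAIN IGHA1 IGHA2 MZB1": "Plasma IgA.",  # matches after normalization
    "MUC2 TFF3 CLCA1": "goblet",               # matches after normalization
    "FOXP3 IL2RA CTLA4": "Naive T",            # wrong label
}


def stub_predict(cell_sentence: str) -> str:
    return stub_outputs[cell_sentence]


accuracy = exact_match_accuracy(examples, stub_predict)
print(f"Accuracy: {accuracy:.2%}")  # Accuracy: 66.67%
```

Whether to normalize case and punctuation before comparing is a design choice; stricter matching would count "goblet" as an error.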
Training Details
- Dataset: Human Ulcerative Colitis Atlas (311,920 cells)
- Training Steps: 3,000
- Batch Size: 32 (effective)
- Learning Rate: 2e-5
- Hardware: NVIDIA RTX 4090
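The numbers above imply how much of the atlas the model saw during fine-tuning. The sketch below works this out under two assumptions that this card does not state: every step consumes a full effective batch, and all 311,920 cells were available for training. The per-device batch size of 4 is hypothetical and only illustrates how an effective batch of 32 can be reached with gradient accumulation on a single GPU.

```python
total_cells = 311_920        # cells in the atlas
training_steps = 3_000
effective_batch_size = 32

# Training examples consumed over the whole run.
examples_seen = training_steps * effective_batch_size  # 96000

# Fraction of one pass over the atlas (assumes all cells are training data).
epochs_equivalent = examples_seen / total_cells

# Hypothetical split of the effective batch on one GPU.
per_device_batch = 4
grad_accum_steps = effective_batch_size // per_device_batch  # 8

print(f"examples seen: {examples_seen}")
print(f"epochs (approx.): {epochs_equivalent:.2f}")  # about 0.31
print(f"gradient accumulation steps: {grad_accum_steps}")
```

Under these assumptions the run covers roughly a third of one epoch.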
Supported Cell Types (44 total)
Key cell types include:
- Immune: Plasma IgA, Plasma IgG, Naive T, Treg, CD8+ Tem, B cell
- Epithelial: Absorptive, Early_Absorptive, Goblet, Stem_TA_prolif
- Stromal: Inflammatory Fibro, Villus-top Fibro, Crypt-bottom Fibro
- Other: Active_Glia, Endothelial
Citation
If you use this model or the Cell2Sentence approach, please cite:
@article{cell2sentence2024,
title={Cell2Sentence: Teaching Large Language Models the Language of Biology},
author={Levine, Daniel and Rizvi, Syed and others},
journal={bioRxiv},
year={2024}
}
License
This model is fine-tuned from Gemma and is subject to the Gemma Terms of Use.