C2S-UC-Gemma-2B

A fine-tuned Cell2Sentence model for Ulcerative Colitis (UC) single-cell RNA-seq cell type prediction.

This model treats a cell's gene expression profile as a sentence: gene names listed in order of decreasing expression. On a held-out test set of 500 cells it reaches 89% accuracy across 44 distinct cell types in human colon tissue, compared with 3.6% for the base model on the same task.

Model Description

  • Task: Single-cell cell type prediction
  • Base Model: vandijklab/C2S-Scale-Gemma-2-2B
  • Finetuning Data: UC single-cell RNA-seq atlas (Human Colon)
  • Cell Types: 44 fine-grained L3 cell types
  • Input: List of genes ordered by expression level (Cell Sentence)
  • Output: Cell Type Label
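
The input format above can be illustrated with a minimal sketch that turns per-gene expression values into a cell sentence. The function name `make_cell_sentence`, the decision to drop genes with zero expression, and the toy counts are assumptions made for illustration; they are not necessarily the preprocessing used to train this model.

```python
def make_cell_sentence(expression, gene_names, top_k=200):
    """Return up to top_k expressed gene names, highest expression first."""
    if len(expression) != len(gene_names):
        raise ValueError("expression and gene_names must have the same length")
    # sorted() is stable, so genes with equal expression keep their input order.
    order = sorted(range(len(expression)), key=lambda i: -expression[i])
    # Assumption: genes with zero expression are left out of the sentence.
    ranked = [gene_names[i] for i in order if expression[i] > 0]
    return " ".join(ranked[:top_k])


# Toy example (hypothetical counts for four genes in one cell)
genes = ["MZB1", "JCHAIN", "ACTB", "IGHA1"]
counts = [12, 250, 0, 180]
print(make_cell_sentence(counts, genes))
```

The resulting string can be passed to the model as the cell sentence.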

Performance

The model was evaluated on a held-out test set of 500 cells.

Model                      Accuracy
Original (pre-trained)     3.60%
Fine-tuned (this model)    89.00%
Improvement                +85.40 percentage points (about 25x)
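
As a quick check of the reported figures, the arithmetic can be worked through in a few lines. The cell counts are derived here from the stated accuracies and the 500-cell test set; they are not reported separately in this card.

```python
n_cells = 500
base_acc, tuned_acc = 0.036, 0.89

# Correct predictions implied by each accuracy on the 500-cell test set
base_correct = round(base_acc * n_cells)    # 18 cells
tuned_correct = round(tuned_acc * n_cells)  # 445 cells

# Absolute gain in percentage points, and relative gain as a ratio
gain_points = (tuned_acc - base_acc) * 100
ratio = tuned_acc / base_acc

print(f"{base_correct} -> {tuned_correct} correct out of {n_cells}")
print(f"gain: {gain_points:.2f} points, ratio: {ratio:.1f}x")
```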

Usage

1. Installation

pip install transformers torch accelerate

2. Prediction Code

Use the following code to predict cell types. Note that the prompt format is critical for good performance.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model
model_name = "Jyx0208/C2S-UC-Gemma-2B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

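def predict_cell_type(cell_sentence, top_k=200):
    """
    Predict the cell type for a single cell.

    Args:
        cell_sentence: Space-separated string of gene names ordered by expression level (descending)
        top_k: Maximum number of leading genes to include in the prompt
    """
    # Keep only the top_k genes so the gene count stated in the prompt matches the sentence
    genes = cell_sentence.split()[:top_k]
    num_genes = len(genes)
    cell_sentence = " ".join(genes)

    # Create prompt (Cell2Sentence Standard Format)
    prompt = f"The following is a list of {num_genes} gene names ordered by descending expression level in a human cell. Your task is to give the cell type which this cell belongs to based on its gene expression.\nCell sentence: {cell_sentence}.\nThe cell type corresponding to these genes is:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can misalign if decoding does not reproduce the prompt exactly
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Keep the first line of the generated text as the prediction
    prediction = answer.strip().split('\n')[0]
    return prediction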

# Example: Plasma IgA Cell
# (Top expressed genes: JCHAIN, IGHA1, IGHA2, MZB1, ...)
example_cell = "JCHAIN IGHA1 IGHA2 MZB1 TNFRSF17 SSR4 TXNDC5 FKBP11 SEC11C"
print(f"Prediction: {predict_cell_type(example_cell)}")
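
To score predictions against known labels on your own cells, a small exact-match loop may be enough. The sketch below accepts any callable with the same signature as `predict_cell_type`; the name `evaluate_accuracy`, the stub predictor, and the toy cells are hypothetical, and exact string matching is an assumption about how accuracy is scored rather than the documented evaluation protocol.

```python
def evaluate_accuracy(predict_fn, cell_sentences, true_labels):
    """Fraction of cells whose predicted label exactly matches the true label."""
    if len(cell_sentences) != len(true_labels):
        raise ValueError("cell_sentences and true_labels must have the same length")
    if not cell_sentences:
        raise ValueError("need at least one cell to evaluate")
    correct = 0
    for sentence, label in zip(cell_sentences, true_labels):
        # Assumption: a prediction counts only on an exact match after stripping whitespace.
        if predict_fn(sentence).strip() == label.strip():
            correct += 1
    return correct / len(true_labels)


# Stub predictor so the example runs without loading the model (illustration only)
def stub_predict(sentence):
    return "Plasma IgA" if sentence.startswith("JCHAIN") else "Goblet"


cells = ["JCHAIN IGHA1 MZB1", "MUC2 TFF3 ZG16", "EPCAM KRT20"]
labels = ["Plasma IgA", "Goblet", "Absorptive"]
print(f"Accuracy: {evaluate_accuracy(stub_predict, cells, labels):.2%}")
```

In practice, `stub_predict` would be replaced by `predict_cell_type` once the model is loaded.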

Training Details

  • Dataset: Human Ulcerative Colitis Atlas (311,920 cells)
  • Training Steps: 3,000
  • Batch Size: 32 (effective)
  • Learning Rate: 2e-5
  • Hardware: NVIDIA RTX 4090
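
The figures above give a rough sense of how much data the fine-tuning run covered. The calculation below assumes that each training step consumes one effective batch and that the full 311,920-cell atlas was the training pool; neither is stated explicitly in this card, so treat the result as an estimate.

```python
dataset_cells = 311_920   # cells in the atlas
training_steps = 3_000
effective_batch = 32      # per-device batch size x gradient accumulation (split not stated)

# Assumption: one effective batch of cells per training step
examples_seen = training_steps * effective_batch   # 96,000 cells
epochs = examples_seen / dataset_cells

print(f"{examples_seen:,} cells seen, about {epochs:.2f} epochs")
```

Under these assumptions the run covers roughly a third of one pass over the atlas.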

Supported Cell Types (44 total)

Key cell types include:

  • Immune: Plasma IgA, Plasma IgG, Naive T, Treg, CD8+ Tem, B cell
  • Epithelial: Absorptive, Early_Absorptive, Goblet, Stem_TA_prolif
  • Stromal: Inflammatory Fibro, Villus-top Fibro, Crypt-bottom Fibro
  • Other: Active_Glia, Endothelial

Citation

If you use this model or the Cell2Sentence approach, please cite:

@article{cell2sentence2024,
  title={Cell2Sentence: Teaching Large Language Models the Language of Biology},
  author={Levine, Daniel and Rizvi, Syed and others},
  journal={bioRxiv},
  year={2024}
}

License

This model is fine-tuned from Gemma and is subject to the Gemma Terms of Use.
