Unify repo: default 24l-512d at root, add size variants via subfolder
Browse files. This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- 24l-512d/config.json +28 -0
- 24l-512d/configuration_genemamba.py +97 -0
- 24l-512d/model.safetensors +3 -0
- 24l-512d/modeling_genemamba.py +395 -0
- 24l-512d/modeling_outputs.py +81 -0
- 24l-512d/special_tokens_map.json +4 -0
- 24l-512d/tokenizer.json +0 -0
- 24l-512d/tokenizer_config.json +8 -0
- 24l-768d/config.json +28 -0
- 24l-768d/configuration_genemamba.py +97 -0
- 24l-768d/model.safetensors +3 -0
- 24l-768d/modeling_genemamba.py +395 -0
- 24l-768d/modeling_outputs.py +81 -0
- 24l-768d/special_tokens_map.json +4 -0
- 24l-768d/tokenizer.json +0 -0
- 24l-768d/tokenizer_config.json +8 -0
- 48l-512d/config.json +28 -0
- 48l-512d/configuration_genemamba.py +97 -0
- 48l-512d/model.safetensors +3 -0
- 48l-512d/modeling_genemamba.py +395 -0
- 48l-512d/modeling_outputs.py +81 -0
- 48l-512d/special_tokens_map.json +4 -0
- 48l-512d/tokenizer.json +0 -0
- 48l-512d/tokenizer_config.json +8 -0
- 48l-768d/config.json +28 -0
- 48l-768d/configuration_genemamba.py +97 -0
- 48l-768d/model.safetensors +3 -0
- 48l-768d/modeling_genemamba.py +395 -0
- 48l-768d/modeling_outputs.py +81 -0
- 48l-768d/special_tokens_map.json +4 -0
- 48l-768d/tokenizer.json +0 -0
- 48l-768d/tokenizer_config.json +8 -0
- README.md +133 -0
- config.json +28 -0
- configuration_genemamba.py +97 -0
- examples/00_preprocess_to_input_ids.py +75 -0
- examples/01_extract_embeddings.py +150 -0
- examples/downstream/10_finetune_classification.py +248 -0
- examples/downstream/11_zero_shot_logreg.py +98 -0
- examples/downstream/12_batch_integration_eval.py +79 -0
- examples/downstream/20_continue_pretraining_reference.py +265 -0
- examples/downstream/21_pretrain_from_scratch_reference.py +280 -0
- examples/downstream/README.md +35 -0
- examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py +378 -0
- examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py +161 -0
- examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py +197 -0
- model.safetensors +3 -0
- modeling_genemamba.py +395 -0
- modeling_outputs.py +81 -0
- special_tokens_map.json +4 -0
24l-512d/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "genemamba",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GeneMambaModel"
|
| 5 |
+
],
|
| 6 |
+
"vocab_size": 25426,
|
| 7 |
+
"max_position_embeddings": 2048,
|
| 8 |
+
"hidden_size": 512,
|
| 9 |
+
"num_hidden_layers": 24,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mamba_mode": "gate",
|
| 14 |
+
"embedding_pooling": "mean",
|
| 15 |
+
"num_labels": 2,
|
| 16 |
+
"pad_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"torch_dtype": "float32",
|
| 21 |
+
"transformers_version": "4.40.2",
|
| 22 |
+
"auto_map": {
|
| 23 |
+
"AutoConfig": "configuration_genemamba.GeneMambaConfig",
|
| 24 |
+
"AutoModel": "modeling_genemamba.GeneMambaModel",
|
| 25 |
+
"AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
|
| 26 |
+
"AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
|
| 27 |
+
}
|
| 28 |
+
}
|
24l-512d/configuration_genemamba.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GeneMamba model.
|
| 3 |
+
Defines all hyperparameters and settings for the GeneMamba architecture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for the GeneMamba model.

    Stores all hyperparameters of a GeneMamba model and inherits from
    ``PretrainedConfig``, so it can be serialized, loaded from pretrained
    checkpoints, and customized at construction time.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in the MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the truncated normal initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling token states into a cell embedding.
            Options: "CLS", "mean", "weighted".
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE(review): the previous identity attribute_map
    # ({"hidden_size": "hidden_size", ...}) was a no-op — attribute_map exists
    # to alias *different* attribute names — and has been removed.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward all special-token ids and num_labels to the base class so
        # PretrainedConfig's own bookkeeping (id2label setup, serialization,
        # generation defaults) sees them; previously only pad_token_id was
        # forwarded and the rest were set as plain attributes afterwards.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            num_labels=num_labels,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
|
24l-512d/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
|
| 3 |
+
size 262998656
|
24l-512d/modeling_genemamba.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
|
| 3 |
+
Includes backbone model and task-specific heads for various downstream tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.init import normal_, constant_
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 16 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
|
| 17 |
+
from transformers.models.auto import register_model_for_auto_class
|
| 18 |
+
|
| 19 |
+
from mamba_ssm import Mamba
|
| 20 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm
|
| 21 |
+
|
| 22 |
+
from .configuration_genemamba import GeneMambaConfig
|
| 23 |
+
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================
|
| 29 |
+
# Core Architecture Components
|
| 30 |
+
# ===========================
|
| 31 |
+
|
| 32 |
+
class EncoderLayer(nn.Module):
    """
    One residual Mamba encoder block.

    Wraps a single ``Mamba`` state-space layer and adds an identity skip
    connection around it.

    Args:
        hidden_size (int): Dimension of hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Mamba layer with a residual connection.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: ``mamba(X) + X``, same shape as the input.
        """
        return self.mamba(X) + X
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.

    At each depth level the sequence is processed in the original and in the
    reversed direction (the same layer weights are shared between the two
    directions), and the directional outputs are merged according to ``mode``.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # One bidirectionally-shared Mamba encoder layer per depth level.
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a projection from 2*hidden back to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence over its valid prefix, leaving padding in place.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): HF-style attention mask of shape
                (batch_size, seq_len) with 1/True for real tokens and 0/False
                for padding — the same convention the pooling code in
                GeneMambaModel uses for this tensor.

        Returns:
            torch.Tensor: Tensor where the first ``length`` positions of each
            sequence are reversed and padded positions keep their content.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding information: reverse the full sequence.
            return X.flip([1])

        # Number of real (non-pad) tokens per sequence. BUGFIX: the original
        # used (~mask).sum(dim=1), which counts *padding* for a 1-for-valid
        # attention mask (the convention used by the caller and the pooling
        # code) and is outright wrong for integer 0/1 masks, where bitwise
        # negation yields -1/-2. Summing the mask itself gives the valid
        # length; the long cast makes it safe for bool and int masks alike.
        lengths = mask.to(dtype=torch.long).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        # Valid positions map to (length - 1 - pos); padding maps to itself.
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process the sequence through the bidirectional Mamba stack.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Attention mask
                (1 = real token, 0 = padding), as passed by GeneMambaModel.

        Returns:
            torch.Tensor: Output after all layers and per-layer aggregation.
        """
        for layer in self.layers:
            # Reverse the valid prefix for the backward pass.
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes share the same layer weights.
            X_f = layer(X)
            X_b = layer(X_flip)

            # Undo the reversal so both outputs are position-aligned.
            X_b = self.flip_sequence(X_b, padding_mask)

            # Merge the two directions.
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned convex combination of the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ===========================
|
| 162 |
+
# Base Model Classes
|
| 163 |
+
# ===========================
|
| 164 |
+
|
| 165 |
+
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class shared by all GeneMamba models.

    Provides the configuration class, weight initialization, and the standard
    Hugging Face model interface hooks.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Linear and Embedding weights are drawn from a normal distribution with
        the configured ``initializer_range``, biases are zeroed, LayerNorm is
        reset to the identity transform, and the padding embedding row is
        zeroed so pad tokens contribute nothing.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.LayerNorm):
            constant_(module.bias, 0.0)
            constant_(module.weight, 1.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=std)
            if module.padding_idx is not None:
                # Keep the pad token's embedding at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            normal_(module.weight, std=std)
            if module.bias is not None:
                constant_(module.bias, 0.0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone: embeddings + bidirectional Mamba stack + final norm.

    Produces per-token hidden states and a pooled cell-level embedding, and
    is the core model wrapped by the task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding table; the padding row is zeroed by _init_weights.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba encoder stack.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final RMS layer normalization.
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the token embedding layer."""
        self.embeddings = value

    def _pool(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Reduce per-token states to one embedding per sequence.

        Supports "CLS" (first token) and "mean" (optionally mask-weighted);
        any other configured strategy raises ValueError.
        """
        strategy = self.config.embedding_pooling
        if strategy == "CLS":
            # First token acts as the [CLS] summary.
            return hidden_states[:, 0, :]
        if strategy == "mean":
            if attention_mask is None:
                return hidden_states.mean(dim=1)
            # Masked mean: only positions with mask == 1 contribute.
            mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask of shape
                (batch_size, seq_len); 1 for real tokens, 0 for padding.
            output_hidden_states (bool): Whether to also expose hidden states
                via ``hidden_states`` (note: only the final layer's states are
                available here, not one entry per layer).

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        embedded = self.embeddings(input_ids)
        encoded = self.mamba_mixer(embedded, attention_mask)
        encoded = self.norm(encoded)

        pooled = self._pool(encoded, attention_mask)

        return GeneMambaModelOutput(
            last_hidden_state=encoded,
            pooled_embedding=pooled,
            hidden_states=encoded if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ===========================
|
| 273 |
+
# Task-Specific Models
|
| 274 |
+
# ===========================
|
| 275 |
+
|
| 276 |
+
# BUGFIX: the original decorator registered this MLM head under "AutoModel",
# clashing with GeneMambaModel (config.json maps AutoModel -> GeneMambaModel).
# Register under AutoModelForMaskedLM, matching the auto_map entry.
@register_model_for_auto_class("AutoModelForMaskedLM")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model with a masked language modeling (MLM) head.

    Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Language modeling head projecting hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask (1 = real token).
            labels (torch.Tensor, optional): Target token ids for the MLM loss.
                Positions set to -100 are ignored (CrossEntropyLoss default),
                matching the Hugging Face MLM labels convention.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        # Per-token vocabulary logits.
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model with a linear classification head over the pooled embedding.

    Suitable for sequence-level tasks such as cell type annotation or tissue
    classification.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head: dropout followed by a linear projection.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for the
                cross-entropy classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss,
            and the pooled cell embedding.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        # Classify from the pooled (cell-level) representation.
        cell_embedding = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(cell_embedding))

        loss = None if labels is None else nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Register the masked-LM head so AutoModelForMaskedLM resolves to it
# (duplicates the auto_map entry in config.json). The original comment
# said "Register tokenizer class", but this registers a model, not a tokenizer.
# NOTE(review): confirm `register_model_for_auto_class` actually exists in
# transformers.models.auto for the pinned version (4.40.2); the documented
# API is `GeneMambaForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")`.
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
|
24l-512d/modeling_outputs.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom ModelOutput classes for GeneMamba.
|
| 3 |
+
Defines the output structure for different GeneMamba tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Output of the GeneMamba backbone model.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Extra hidden states, populated when ``output_hidden_states=True``.
            NOTE(review): as currently filled in by GeneMambaModel.forward this
            holds only the final layer's states (the same tensor as
            ``last_hidden_state``), not one entry per layer — confirm before
            relying on per-layer states.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding ("CLS" or "mean").
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: torch.FloatTensor = None
    embedding_pooling: str = "mean"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output of GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Cross-entropy classification loss; present only when labels were provided.

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states forwarded from the backbone when
            ``output_hidden_states=True``.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding fed into the classification head (before dropout).
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output of GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM cross-entropy loss; present only when labels were provided.

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states forwarded from the backbone when
            ``output_hidden_states=True``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
24l-512d/special_tokens_map.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "[PAD]",
|
| 3 |
+
"unk_token": "[UNK]"
|
| 4 |
+
}
|
24l-512d/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
24l-512d/tokenizer_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {},
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 5 |
+
"pad_token": "[PAD]",
|
| 6 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 7 |
+
"unk_token": "[UNK]"
|
| 8 |
+
}
|
24l-768d/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "genemamba",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GeneMambaModel"
|
| 5 |
+
],
|
| 6 |
+
"vocab_size": 25426,
|
| 7 |
+
"max_position_embeddings": 2048,
|
| 8 |
+
"hidden_size": 768,
|
| 9 |
+
"num_hidden_layers": 24,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mamba_mode": "gate",
|
| 14 |
+
"embedding_pooling": "mean",
|
| 15 |
+
"num_labels": 2,
|
| 16 |
+
"pad_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"torch_dtype": "float32",
|
| 21 |
+
"transformers_version": "4.40.2",
|
| 22 |
+
"auto_map": {
|
| 23 |
+
"AutoConfig": "configuration_genemamba.GeneMambaConfig",
|
| 24 |
+
"AutoModel": "modeling_genemamba.GeneMambaModel",
|
| 25 |
+
"AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
|
| 26 |
+
"AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
|
| 27 |
+
}
|
| 28 |
+
}
|
24l-768d/configuration_genemamba.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GeneMamba model.
|
| 3 |
+
Defines all hyperparameters and settings for the GeneMamba architecture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for the GeneMamba model.

    Stores the configuration of a GeneMamba model, inheriting from
    ``PretrainedConfig``. It can be used to instantiate models from pretrained
    checkpoints or to customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model (number of gene tokens / Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in the MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length.
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the weight initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling token states into a cell embedding.
            Options: "CLS", "mean", "weighted".
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID used for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE: the previous identity `attribute_map` ({"hidden_size": "hidden_size", ...})
    # was a no-op (attribute_map exists to *alias* attribute names) and has been removed.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward all special-token ids to PretrainedConfig so they are handled
        # by the base class exactly once (it assigns self.pad/bos/eos_token_id).
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        # Kept as a direct assignment (PretrainedConfig exposes num_labels as a
        # property that also builds id2label/label2id), matching prior behavior.
        self.num_labels = num_labels
|
24l-768d/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b423a3555eecacc88ff587c1d3f689a2caa05ede0a01d09dbaae175f23a2e7e1
|
| 3 |
+
size 508241792
|
24l-768d/modeling_genemamba.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
|
| 3 |
+
Includes backbone model and task-specific heads for various downstream tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.init import normal_, constant_
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 16 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
|
| 17 |
+
from transformers.models.auto import register_model_for_auto_class
|
| 18 |
+
|
| 19 |
+
from mamba_ssm import Mamba
|
| 20 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm
|
| 21 |
+
|
| 22 |
+
from .configuration_genemamba import GeneMambaConfig
|
| 23 |
+
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================
|
| 29 |
+
# Core Architecture Components
|
| 30 |
+
# ===========================
|
| 31 |
+
|
| 32 |
+
class EncoderLayer(nn.Module):
    """A single residual Mamba block.

    Wraps one ``Mamba`` state-space layer and adds a skip connection, so the
    layer learns a residual update on top of its input.

    Args:
        hidden_size (int): Dimension of hidden states (``d_model``).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyperparameters: state dim 64, conv kernel 4, 2x expansion.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba layer and add the residual.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Same shape as ``X``.
        """
        return self.mamba(X) + X
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.
    Processes sequences in both forward and reverse directions, then aggregates.

    Each layer is applied twice per step — once to the sequence as-is and once
    to a length-aware reversed copy — so forward and backward directions share
    the same layer weights.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # Aggregation modules for certain modes
        # ("mean"/"sum" need no parameters; "concat"/"gate" project 2h -> h).
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse a sequence based on actual length (ignoring padding).

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).
                NOTE(review): this method treats the mask as True-at-padding
                (boolean), whereas an HF-style ``attention_mask`` uses 1 for
                real tokens — confirm what callers actually pass.

        Returns:
            torch.Tensor: Reversed tensor (padding positions stay in place).
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # Simple flip
            return X.flip([1])

        # Flip based on actual sequence length (marked by mask)
        # NOTE(review): `~mask` is a logical NOT only for bool tensors; on an
        # integer 0/1 mask `~` is bitwise (~1 == -2, ~0 == -1), which would make
        # `lengths` negative — verify the mask dtype at the call site.
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        # Positions inside the valid prefix get mirrored indices; positions past
        # the valid length keep their original index (padding is not moved).
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process sequence through bidirectional Mamba layers.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (see
                ``flip_sequence`` for the expected convention).

        Returns:
            torch.Tensor: Output after processing all layers and aggregation.
        """

        for layer in self.layers:
            # Flip sequence for reverse processing (recomputed every layer
            # because X is re-bound to the aggregated output below)
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes — note both directions go through the
            # SAME layer instance (shared weights)
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip back the reverse output
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned per-dimension gate blending the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ===========================
|
| 162 |
+
# Base Model Classes
|
| 163 |
+
# ===========================
|
| 164 |
+
|
| 165 |
+
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Abstract base class shared by all GeneMamba models.

    Provides weight initialization and the standard Hugging Face
    ``PreTrainedModel`` interface (save/load, config handling).
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Normal init for weights, zero for biases.
            normal_(module.weight, std=std)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=std)
            # Keep the padding embedding at exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Identity-like start: unit scale, zero shift.
            constant_(module.weight, 1.0)
            constant_(module.bias, 0.0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Pipeline: token embedding -> bidirectional MambaMixer -> RMSNorm -> pooling.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Embedding layer (padding index stays zero via _init_weights)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Mamba layers with bidirectional aggregation
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len).
                NOTE(review): the mean-pooling branch below uses the HF
                convention (1 = real token), but the same tensor is forwarded
                to ``MambaMixer.flip_sequence``, which reads it as a
                True-at-padding mask — the two conventions conflict; confirm
                the intended semantics before relying on masked inputs.
            output_hidden_states (bool): Whether to output hidden states.
                NOTE(review): when True, the ``hidden_states`` field carries
                only the FINAL layer tensor, not a per-layer tuple as the
                output dataclass docstring suggests.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        # NOTE(review): config documents a "weighted" option, but only "CLS"
        # and "mean" are implemented here; anything else raises.
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over sequence (mask-weighted when a mask is given)
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ===========================
|
| 273 |
+
# Task-Specific Models
|
| 274 |
+
# ===========================
|
| 275 |
+
|
| 276 |
+
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model with a masked language modeling (MLM) head.
    Suitable for pretraining and domain adaptation.

    Note:
        The original source decorated this class with
        ``@register_model_for_auto_class("AutoModel")``. That name is not part
        of the transformers public API (the supported mechanism is the
        ``PreTrainedModel.register_for_auto_class`` classmethod, and the
        ``auto_map`` entries in ``config.json`` already cover Hub loading),
        and it would in any case have registered this MLM head under plain
        ``AutoModel``, clashing with ``GeneMambaModel``. The decorator has
        been removed.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Language modeling head: hidden_size -> vocab_size (weights not tied
        # to the input embeddings).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for MLM loss.
                Positions set to -100 are ignored (CrossEntropyLoss default).
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        # Per-token vocabulary logits: (batch, seq_len, vocab_size).
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,).
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model for sequence classification tasks.
    Ideal for cell type annotation, tissue classification, etc.

    Note:
        The original ``@register_model_for_auto_class(...)`` decorator has been
        removed: that name is not part of the transformers public API (the
        supported mechanism is the ``PreTrainedModel.register_for_auto_class``
        classmethod, and ``auto_map`` in ``config.json`` already covers Hub
        loading with ``trust_remote_code=True``).

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head on top of the pooled cell embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        # Classify from the pooled (cell-level) embedding, not per-token states.
        pooled_embedding = outputs.pooled_embedding
        logits = self.classifier(self.dropout(pooled_embedding))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            pooled_embedding=pooled_embedding,
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Register the task heads for AutoModel* loading (trust_remote_code=True).
# The previous code called `register_model_for_auto_class(...)`, which is not a
# transformers API; the supported mechanism is the classmethod below. The
# `auto_map` entries in config.json cover loading from the Hub as well.
GeneMambaForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
GeneMambaForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
|
24l-768d/modeling_outputs.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom ModelOutput classes for GeneMamba.
|
| 3 |
+
Defines the output structure for different GeneMamba tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            NOTE(review): the backbone's forward currently populates this with
            the final-layer tensor only, not a per-layer tuple — confirm.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    # None until set by the model's forward pass.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    # Name of the pooling strategy that produced pooled_embedding.
    embedding_pooling: str = "mean"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    # None until set by the model's forward pass.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    # None until set by the model's forward pass.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
24l-768d/special_tokens_map.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "[PAD]",
|
| 3 |
+
"unk_token": "[UNK]"
|
| 4 |
+
}
|
24l-768d/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
24l-768d/tokenizer_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {},
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 5 |
+
"pad_token": "[PAD]",
|
| 6 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 7 |
+
"unk_token": "[UNK]"
|
| 8 |
+
}
|
48l-512d/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "genemamba",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GeneMambaModel"
|
| 5 |
+
],
|
| 6 |
+
"vocab_size": 25426,
|
| 7 |
+
"max_position_embeddings": 2048,
|
| 8 |
+
"hidden_size": 512,
|
| 9 |
+
"num_hidden_layers": 48,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mamba_mode": "gate",
|
| 14 |
+
"embedding_pooling": "mean",
|
| 15 |
+
"num_labels": 2,
|
| 16 |
+
"pad_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"torch_dtype": "float32",
|
| 21 |
+
"transformers_version": "4.40.2",
|
| 22 |
+
"auto_map": {
|
| 23 |
+
"AutoConfig": "configuration_genemamba.GeneMambaConfig",
|
| 24 |
+
"AutoModel": "modeling_genemamba.GeneMambaModel",
|
| 25 |
+
"AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
|
| 26 |
+
"AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
|
| 27 |
+
}
|
| 28 |
+
}
|
48l-512d/configuration_genemamba.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GeneMamba model.
|
| 3 |
+
Defines all hyperparameters and settings for the GeneMamba architecture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for the GeneMamba model.

    Stores the configuration of a GeneMamba model, inheriting from
    ``PretrainedConfig``. It can be used to instantiate models from pretrained
    checkpoints or to customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model (number of gene tokens / Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in the MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length.
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the weight initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling token states into a cell embedding.
            Options: "CLS", "mean", "weighted".
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID used for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE: the previous identity `attribute_map` ({"hidden_size": "hidden_size", ...})
    # was a no-op (attribute_map exists to *alias* attribute names) and has been removed.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward all special-token ids to PretrainedConfig so they are handled
        # by the base class exactly once (it assigns self.pad/bos/eos_token_id).
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        # Kept as a direct assignment (PretrainedConfig exposes num_labels as a
        # property that also builds id2label/label2id), matching prior behavior.
        self.num_labels = num_labels
|
48l-512d/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1a715342c6cc00b20161a05941d9d181cca73c7ecc9cae17fd3a04bf92590a7d
|
| 3 |
+
size 421748360
|
48l-512d/modeling_genemamba.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
|
| 3 |
+
Includes backbone model and task-specific heads for various downstream tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.init import normal_, constant_
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 16 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
|
| 17 |
+
from transformers.models.auto import register_model_for_auto_class
|
| 18 |
+
|
| 19 |
+
from mamba_ssm import Mamba
|
| 20 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm
|
| 21 |
+
|
| 22 |
+
from .configuration_genemamba import GeneMambaConfig
|
| 23 |
+
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================
|
| 29 |
+
# Core Architecture Components
|
| 30 |
+
# ===========================
|
| 31 |
+
|
| 32 |
+
class EncoderLayer(nn.Module):
    """One Mamba block wrapped with a residual (skip) connection.

    Args:
        hidden_size (int): Width of the token representations (d_model).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # d_state/d_conv/expand match the values used throughout this repo.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba mixer and add the input back.

        Args:
            X (torch.Tensor): (batch_size, seq_len, hidden_size) activations.

        Returns:
            torch.Tensor: Same shape as ``X``.
        """
        return self.mamba(X) + X
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MambaMixer(nn.Module):
    """Bidirectional stack of Mamba encoder layers with configurable merging.

    Every layer processes the sequence in its natural order and, with the
    same weights, over the length-aware reversed sequence; the two results
    are merged according to ``mode`` before feeding the next layer.

    Args:
        mode (str): Aggregation mode: "mean", "sum", "concat", or "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super().__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size) for _ in range(num_hidden_layers)
        )

        # "concat" and "gate" need a projection from 2*d back to d.
        if mode in ("concat", "gate"):
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Reverse each sequence over its unpadded prefix, leaving padding in place.

        Args:
            X (torch.Tensor): (batch_size, seq_len, hidden_size) input.
            mask (torch.Tensor, optional): (batch_size, seq_len) padding mask.
                NOTE(review): True is interpreted as padding here
                (lengths = (~mask).sum), which is the opposite of the HF
                attention_mask convention (1 = real token) that
                GeneMambaModel.forward appears to pass in — confirm callers
                supply a boolean padding mask.

        Returns:
            torch.Tensor: Tensor with each valid prefix reversed.
        """
        batch, length, dim = X.size()

        if mask is None:
            # No padding information: reverse the full sequence.
            return X.flip([1])

        valid_lengths = (~mask).sum(dim=1)
        positions = torch.arange(length, device=X.device).unsqueeze(0).expand(batch, -1)
        within_valid = positions < valid_lengths.unsqueeze(1)
        # Inside the valid prefix, read from the mirrored index; elsewhere
        # keep the position unchanged (padding stays where it is).
        source_index = torch.where(
            within_valid,
            valid_lengths.unsqueeze(1) - 1 - positions,
            positions,
        )
        return torch.gather(X, 1, source_index.unsqueeze(-1).expand(-1, -1, dim))

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the bidirectional layer stack.

        Args:
            X (torch.Tensor): (batch_size, seq_len, hidden_size) input.
            padding_mask (torch.Tensor, optional): Padding mask forwarded to
                ``flip_sequence`` (see the convention note there).

        Returns:
            torch.Tensor: (batch_size, seq_len, hidden_size) output.
        """
        for layer in self.layers:
            reversed_input = self.flip_sequence(X, padding_mask)

            forward_out = layer(X)
            # The reverse pass shares the layer's weights; flip the result
            # back so positions align with the forward pass.
            backward_out = self.flip_sequence(layer(reversed_input), padding_mask)

            if self.mode == "mean":
                X = (forward_out + backward_out) / 2
            elif self.mode == "sum":
                X = forward_out + backward_out
            elif self.mode == "concat":
                X = self.aggr(torch.cat([forward_out, backward_out], dim=-1))
            elif self.mode == "gate":
                gate = torch.sigmoid(self.aggr(torch.cat([forward_out, backward_out], dim=-1)))
                X = gate * forward_out + (1 - gate) * backward_out
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ===========================
|
| 162 |
+
# Base Model Classes
|
| 163 |
+
# ===========================
|
| 164 |
+
|
| 165 |
+
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights (normal init for Linear/Embedding, identity for LayerNorm)."""
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            # Keep the padding row at exactly zero after random init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False has weight/bias
            # set to None; guard so initialization does not crash on it.
            if module.weight is not None:
                constant_(module.weight, 1.0)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        # NOTE(review): the backbone's RMSNorm is not an nn.LayerNorm and is
        # therefore left with its own default initialization — confirm intended.
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone: embeddings -> bidirectional Mamba stack -> RMSNorm.

    Produces per-token hidden states plus a pooled cell-level embedding; this
    is the core model wrapped by the task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Gene-token embedding table; the pad row is zeroed by _init_weights.
        self.embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )

        # Bidirectional Mamba layer stack with configurable aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
        )

        # Final normalization over the last layer's output.
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the gene-token embedding table."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the gene-token embedding table."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Encode token ids into hidden states and a pooled embedding.

        Args:
            input_ids (torch.Tensor): (batch_size, seq_len) token indices.
            attention_mask (torch.Tensor, optional): (batch_size, seq_len) mask.
                NOTE(review): MambaMixer.flip_sequence treats True entries of
                this mask as padding, while the pooling below treats 1 as a
                real token (HF convention) — the two code paths disagree;
                confirm which convention callers use.
            output_hidden_states (bool): Attach hidden states to the output.

        Returns:
            GeneMambaModelOutput: last_hidden_state, pooled_embedding, etc.
        """
        states = self.embeddings(input_ids)
        states = self.mamba_mixer(states, attention_mask)
        states = self.norm(states)

        pooling = self.config.embedding_pooling
        if pooling == "CLS":
            # Use the first token as the sequence summary.
            pooled = states[:, 0, :]
        elif pooling == "mean":
            if attention_mask is None:
                pooled = states.mean(dim=1)
            else:
                # Masked mean: average only over positions marked 1.
                weights = attention_mask.unsqueeze(-1).expand(states.shape).float()
                pooled = (states * weights).sum(dim=1) / weights.sum(dim=1)
        else:
            # "weighted" is mentioned in the config docs but not implemented here.
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        # NOTE(review): hidden_states below is just the final layer's output,
        # not a per-layer tuple as the output class docstring suggests.
        return GeneMambaModelOutput(
            last_hidden_state=states,
            pooled_embedding=pooled,
            hidden_states=states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ===========================
|
| 273 |
+
# Task-Specific Models
|
| 274 |
+
# ===========================
|
| 275 |
+
|
| 276 |
+
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba with a token-prediction head for masked language modeling (MLM).
    Suitable for pretraining and domain adaptation.

    NOTE(review): this class is decorated for "AutoModel" and additionally
    registered for "AutoModelForMaskedLM" at the bottom of the module —
    confirm which auto-class mapping is intended.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Untied LM head projecting hidden states back onto the vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Run the backbone and score every position against the vocabulary.

        Args:
            input_ids (torch.Tensor): (batch_size, seq_len) token ids.
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target ids; positions set to -100
                are ignored by CrossEntropyLoss's default ignore_index.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: logits (batch, seq, vocab) and optional loss.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        prediction_scores = self.lm_head(backbone_out.last_hidden_state)

        mlm_loss = None
        if labels is not None:
            mlm_loss = nn.CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        return GeneMambaMaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba with a linear classification head over the pooled embedding.
    Ideal for cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Dropout + linear head applied to the pooled (cell-level) embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Classify each sequence from its pooled embedding.

        Args:
            input_ids (torch.Tensor): (batch_size, seq_len) token ids.
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): (batch_size,) class indices;
                single-label classification is assumed (CrossEntropyLoss).
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: logits, optional loss, and the
            pooled embedding fed to the head.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        cell_embedding = backbone_out.pooled_embedding
        scores = self.classifier(self.dropout(cell_embedding))

        cls_loss = None
        if labels is not None:
            cls_loss = nn.CrossEntropyLoss()(scores, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=cls_loss,
            logits=scores,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Register the MLM head for AutoModelForMaskedLM (the old comment said
# "tokenizer", which was wrong — this registers a model class).
# NOTE(review): GeneMambaForMaskedLM is also decorated with
# @register_model_for_auto_class("AutoModel") above — confirm which
# auto-class mapping is intended.
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
|
48l-512d/modeling_outputs.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom ModelOutput classes for GeneMamba.
|
| 3 |
+
Defines the output structure for different GeneMamba tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            NOTE(review): GeneMambaModel.forward currently places only the
            final layer's hidden state here, not a per-layer tuple — confirm
            intended contents.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding ("CLS" or "mean").
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: torch.FloatTensor = None
    embedding_pooling: str = "mean"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            NOTE(review): the backbone currently forwards only the final
            layer's hidden state here, not a per-layer tuple — confirm.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before the classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            NOTE(review): the backbone currently forwards only the final
            layer's hidden state here, not a per-layer tuple — confirm.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
48l-512d/special_tokens_map.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "[PAD]",
|
| 3 |
+
"unk_token": "[UNK]"
|
| 4 |
+
}
|
48l-512d/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
48l-512d/tokenizer_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {},
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 5 |
+
"pad_token": "[PAD]",
|
| 6 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 7 |
+
"unk_token": "[UNK]"
|
| 8 |
+
}
|
48l-768d/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "genemamba",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GeneMambaModel"
|
| 5 |
+
],
|
| 6 |
+
"vocab_size": 25426,
|
| 7 |
+
"max_position_embeddings": 2048,
|
| 8 |
+
"hidden_size": 768,
|
| 9 |
+
"num_hidden_layers": 48,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mamba_mode": "gate",
|
| 14 |
+
"embedding_pooling": "mean",
|
| 15 |
+
"num_labels": 2,
|
| 16 |
+
"pad_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"torch_dtype": "float32",
|
| 21 |
+
"transformers_version": "4.40.2",
|
| 22 |
+
"auto_map": {
|
| 23 |
+
"AutoConfig": "configuration_genemamba.GeneMambaConfig",
|
| 24 |
+
"AutoModel": "modeling_genemamba.GeneMambaModel",
|
| 25 |
+
"AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
|
| 26 |
+
"AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
|
| 27 |
+
}
|
| 28 |
+
}
|
48l-768d/configuration_genemamba.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GeneMamba model.
|
| 3 |
+
Defines all hyperparameters and settings for the GeneMamba architecture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeneMambaConfig(PretrainedConfig):
|
| 11 |
+
"""
|
| 12 |
+
Configuration class for GeneMamba model.
|
| 13 |
+
|
| 14 |
+
This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
|
| 15 |
+
It can be used to instantiate models from pretrained checkpoints or customize model initialization.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
vocab_size (int, optional, defaults to 25426):
|
| 19 |
+
Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
|
| 20 |
+
|
| 21 |
+
hidden_size (int, optional, defaults to 512):
|
| 22 |
+
Dimensionality of the hidden/embedding layers (d_model in Mamba).
|
| 23 |
+
|
| 24 |
+
num_hidden_layers (int, optional, defaults to 24):
|
| 25 |
+
Number of Mamba layers (mamba_layer).
|
| 26 |
+
|
| 27 |
+
intermediate_size (int, optional, defaults to 2048):
|
| 28 |
+
Dimensionality of intermediate representations in MLP.
|
| 29 |
+
|
| 30 |
+
max_position_embeddings (int, optional, defaults to 2048):
|
| 31 |
+
Maximum sequence length (seq_len).
|
| 32 |
+
|
| 33 |
+
hidden_dropout_prob (float, optional, defaults to 0.1):
|
| 34 |
+
Dropout probability for hidden states.
|
| 35 |
+
|
| 36 |
+
initializer_range (float, optional, defaults to 0.02):
|
| 37 |
+
Standard deviation of truncated normal initializer.
|
| 38 |
+
|
| 39 |
+
mamba_mode (str, optional, defaults to "gate"):
|
| 40 |
+
Aggregation mode for bidirectional Mamba layers.
|
| 41 |
+
Options: "mean", "sum", "concat", "gate".
|
| 42 |
+
|
| 43 |
+
embedding_pooling (str, optional, defaults to "mean"):
|
| 44 |
+
Method for pooling to get cell embedding.
|
| 45 |
+
Options: "CLS", "mean", "weighted".
|
| 46 |
+
|
| 47 |
+
num_labels (int, optional, defaults to 2):
|
| 48 |
+
Number of labels for sequence classification tasks.
|
| 49 |
+
|
| 50 |
+
pad_token_id (int, optional, defaults to 1):
|
| 51 |
+
Token ID for padding.
|
| 52 |
+
|
| 53 |
+
bos_token_id (int, optional, defaults to None):
|
| 54 |
+
Token ID for beginning of sequence.
|
| 55 |
+
|
| 56 |
+
eos_token_id (int, optional, defaults to None):
|
| 57 |
+
Token ID for end of sequence.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
    # HF model identifier; must match "model_type" in config.json.
    model_type = "genemamba"
    # NOTE(review): these are identity mappings and therefore no-ops;
    # PretrainedConfig.attribute_map exists to alias *differing* attribute
    # names. This dict can likely be removed — TODO confirm.
    attribute_map = {
        "hidden_size": "hidden_size",
        "num_hidden_layers": "num_hidden_layers",
    }
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
vocab_size: int = 25426,
|
| 69 |
+
hidden_size: int = 512,
|
| 70 |
+
num_hidden_layers: int = 24,
|
| 71 |
+
intermediate_size: int = 2048,
|
| 72 |
+
max_position_embeddings: int = 2048,
|
| 73 |
+
hidden_dropout_prob: float = 0.1,
|
| 74 |
+
initializer_range: float = 0.02,
|
| 75 |
+
mamba_mode: str = "gate",
|
| 76 |
+
embedding_pooling: str = "mean",
|
| 77 |
+
num_labels: int = 2,
|
| 78 |
+
pad_token_id: int = 1,
|
| 79 |
+
bos_token_id: Optional[int] = None,
|
| 80 |
+
eos_token_id: Optional[int] = None,
|
| 81 |
+
**kwargs
|
| 82 |
+
):
|
| 83 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
| 84 |
+
|
| 85 |
+
self.vocab_size = vocab_size
|
| 86 |
+
self.hidden_size = hidden_size
|
| 87 |
+
self.num_hidden_layers = num_hidden_layers
|
| 88 |
+
self.intermediate_size = intermediate_size
|
| 89 |
+
self.max_position_embeddings = max_position_embeddings
|
| 90 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 91 |
+
self.initializer_range = initializer_range
|
| 92 |
+
self.mamba_mode = mamba_mode
|
| 93 |
+
self.embedding_pooling = embedding_pooling
|
| 94 |
+
self.num_labels = num_labels
|
| 95 |
+
self.pad_token_id = pad_token_id
|
| 96 |
+
self.bos_token_id = bos_token_id
|
| 97 |
+
self.eos_token_id = eos_token_id
|
48l-768d/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:728514a211350e69937d73398dffa4c6bbb7f59366fb6c8b39f27437a6a5af77
|
| 3 |
+
size 860161160
|
48l-768d/modeling_genemamba.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
|
| 3 |
+
Includes backbone model and task-specific heads for various downstream tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.init import normal_, constant_
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 16 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
|
| 17 |
+
from transformers.models.auto import register_model_for_auto_class
|
| 18 |
+
|
| 19 |
+
from mamba_ssm import Mamba
|
| 20 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm
|
| 21 |
+
|
| 22 |
+
from .configuration_genemamba import GeneMambaConfig
|
| 23 |
+
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================
|
| 29 |
+
# Core Architecture Components
|
| 30 |
+
# ===========================
|
| 31 |
+
|
| 32 |
+
class EncoderLayer(nn.Module):
    """
    Single Mamba encoder layer with a residual (skip) connection.

    Args:
        hidden_size (int): Dimension of hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed Mamba hyperparameters (d_state/d_conv/expand) from the original implementation.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Mamba block and add the input back (residual connection).

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Output of the same shape as the input.
        """
        residual = X
        return self.mamba(X) + residual
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.

    Each layer is applied to the sequence twice - once in the original order and
    once to a flipped copy - and the two directional outputs are merged according
    to ``mode``. NOTE(review): the *same* EncoderLayer instance processes both
    directions, i.e. forward and backward passes share weights.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a learned projection from 2*hidden back to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence based on its actual length (ignoring padding).

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).
                NOTE(review): lengths are computed as ``(~mask).sum``, so truthy
                entries are treated as *padding*. A HF-style attention mask
                (1 = real token) has the opposite polarity and would need to be
                inverted first - confirm the caller's convention.

        Returns:
            torch.Tensor: Tensor with the first ``length`` positions of each row
            reversed; positions beyond the valid prefix keep their index.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding info: reverse every row along the full time axis.
            return X.flip([1])

        # Number of non-padding positions per row (see polarity note above).
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        # Positions inside the valid prefix are mirrored; the rest stay put.
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        # Gather rearranges each row according to the reversed position map.
        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process sequence through bidirectional Mamba layers.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (see flip_sequence
                for the expected polarity).

        Returns:
            torch.Tensor: Output after processing all layers and aggregation.
        """

        for layer in self.layers:
            # Flip sequence for reverse processing
            X_flip = self.flip_sequence(X, padding_mask)

            # Same layer (shared weights) runs both directions.
            X_f = layer(X)
            X_b = layer(X_flip)

            # Undo the flip so both outputs are aligned position-wise.
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned per-feature gate blends the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ===========================
|
| 162 |
+
# Base Model Classes
|
| 163 |
+
# ===========================
|
| 164 |
+
|
| 165 |
+
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.

    Handles weight initialization and provides the standard interfaces
    (config handling, save/load) inherited from ``PreTrainedModel``.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights.

        Linear/Embedding weights are drawn from N(0, initializer_range);
        biases are zeroed, the padding embedding row is zeroed, and
        LayerNorm is reset to the identity transform.
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Guard both parameters: LayerNorm(elementwise_affine=False)
            # leaves weight/bias as None and constant_ would raise.
            if module.bias is not None:
                constant_(module.bias, 0.0)
            if module.weight is not None:
                constant_(module.weight, 1.0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Embedding layer (the padding row is zeroed by _init_weights).
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Mamba layers with bidirectional aggregation
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len).
                NOTE(review): this tensor is forwarded to MambaMixer as its
                ``padding_mask``, whose flip logic computes lengths as
                ``(~mask).sum`` - i.e. truthy entries are treated as *padding*.
                A standard HF attention mask (1 = real token) has the opposite
                polarity; confirm the convention used by the data pipeline.
            output_hidden_states (bool): Whether to output hidden states from all layers.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Masked mean pooling over the sequence
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                # clamp avoids a 0/0 NaN for rows whose mask is entirely zero
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ===========================
|
| 273 |
+
# Task-Specific Models
|
| 274 |
+
# ===========================
|
| 275 |
+
|
| 276 |
+
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model for masked language modeling (MLM).

    Wraps the GeneMamba backbone with a per-token projection onto the gene
    vocabulary. Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Token-level projection onto the vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Run the backbone and score every position against the vocabulary.

        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for MLM loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        token_scores = self.lm_head(backbone_out.last_hidden_state)

        mlm_loss = None
        if labels is not None:
            # Flatten (batch, seq) into one axis for token-level cross-entropy.
            flat_scores = token_scores.view(-1, self.config.vocab_size)
            mlm_loss = nn.CrossEntropyLoss()(flat_scores, labels.view(-1))

        hidden = backbone_out.hidden_states if output_hidden_states else None
        return GeneMambaMaskedLMOutput(
            loss=mlm_loss,
            logits=token_scores,
            hidden_states=hidden,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model for sequence classification tasks.

    Puts a dropout + linear head on top of the backbone's pooled cell
    embedding. Ideal for cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head on the pooled embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Classify a cell from its pooled backbone embedding.

        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        cell_embedding = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(cell_embedding))

        cls_loss = None
        if labels is not None:
            cls_loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=cls_loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Also register the MLM model under AutoModelForMaskedLM (in addition to the
# AutoModel registration via the class decorator). The previous comment said
# "tokenizer class", but this registers a *model* class.
# NOTE(review): confirm `register_model_for_auto_class` exists in the pinned
# transformers version; the public custom-model API is usually
# `Model.register_for_auto_class("AutoModelForMaskedLM")`.
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
|
48l-768d/modeling_outputs.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom ModelOutput classes for GeneMamba.
|
| 3 |
+
Defines the output structure for different GeneMamba tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            NOTE(review): the current backbone populates this field with the final
            hidden state only (not a per-layer tuple) - confirm before relying on it.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    # All tensor fields default to None so the output can be built partially.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: str = "mean"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
48l-768d/special_tokens_map.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "[PAD]",
|
| 3 |
+
"unk_token": "[UNK]"
|
| 4 |
+
}
|
48l-768d/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
48l-768d/tokenizer_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {},
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 5 |
+
"pad_token": "[PAD]",
|
| 6 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 7 |
+
"unk_token": "[UNK]"
|
| 8 |
+
}
|
README.md
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- genomics
|
| 5 |
+
- single-cell
|
| 6 |
+
- mamba
|
| 7 |
+
- biology
|
| 8 |
+
pipeline_tag: feature-extraction
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# GeneMamba
|
| 12 |
+
|
| 13 |
+
This repository contains a **default GeneMamba model** plus full usage assets:
|
| 14 |
+
- default model weights at repository root (**24l-512d**)
|
| 15 |
+
- custom modeling/config files for `trust_remote_code=True`
|
| 16 |
+
- preprocessing example from `h5ad` to `input_ids`
|
| 17 |
+
- tokenizer assets and id mapping files
|
| 18 |
+
|
| 19 |
+
Additional model sizes are provided as subfolders:
|
| 20 |
+
- `24l-512d` (same architecture class as default)
|
| 21 |
+
- `24l-768d`
|
| 22 |
+
- `48l-512d`
|
| 23 |
+
- `48l-768d`
|
| 24 |
+
|
| 25 |
+
## 1) Input format (very important)
|
| 26 |
+
|
| 27 |
+
GeneMamba input is **ranked gene token IDs** per cell:
|
| 28 |
+
1. Start from one cell expression vector
|
| 29 |
+
2. Keep genes with expression > 0
|
| 30 |
+
3. Sort genes by expression descending
|
| 31 |
+
4. Convert each gene ID (Ensembl, e.g. `ENSG00000000003`) to token ID
|
| 32 |
+
5. Use resulting list as `input_ids`
|
| 33 |
+
|
| 34 |
+
Each sample is one list of integers:
|
| 35 |
+
|
| 36 |
+
```python
|
| 37 |
+
{"input_ids": [145, 2088, 531, 91, ...]}
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
For batch input, shape is typically `(batch_size, seq_len)` after padding/truncation.
|
| 41 |
+
|
| 42 |
+
## 2) Where tokenizer and id mapping come from
|
| 43 |
+
|
| 44 |
+
- Main tokenizer used for model inference: `tokenizer.json`
|
| 45 |
+
- Original full tokenizer table: `tokenizer_assets/gene_tokenizer.json`
|
| 46 |
+
- Gene symbol -> token id mapping: `tokenizer_assets/symbol2id.pkl`
|
| 47 |
+
- Token id -> gene symbol mapping: `tokenizer_assets/id2symbol.pkl`
|
| 48 |
+
|
| 49 |
+
Special tokens:
|
| 50 |
+
- `[UNK]` = 0
|
| 51 |
+
- `[PAD]` = 1
|
| 52 |
+
|
| 53 |
+
## 3) Preprocess your data
|
| 54 |
+
|
| 55 |
+
See script:
|
| 56 |
+
- `examples/00_preprocess_to_input_ids.py`
|
| 57 |
+
|
| 58 |
+
Example:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python examples/00_preprocess_to_input_ids.py \
|
| 62 |
+
--h5ad /path/to/your_data.h5ad \
|
| 63 |
+
--tokenizer_json tokenizer.json \
|
| 64 |
+
--output_arrow ./my_data/sorted_gene_token_ids.arrow
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
This output Arrow file has one column: `input_ids`.
|
| 68 |
+
|
| 69 |
+
## 4) Load model and extract embedding
|
| 70 |
+
|
| 71 |
+
### Default load (24l-512d)
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
from transformers import AutoModel, AutoTokenizer
|
| 75 |
+
|
| 76 |
+
model = AutoModel.from_pretrained(
|
| 77 |
+
"mineself2016/GeneMamba",
|
| 78 |
+
trust_remote_code=True
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 82 |
+
"mineself2016/GeneMamba",
|
| 83 |
+
trust_remote_code=True
|
| 84 |
+
)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Load other sizes (via `subfolder`)
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
from transformers import AutoModel
|
| 91 |
+
|
| 92 |
+
model_24l_768d = AutoModel.from_pretrained(
|
| 93 |
+
"mineself2016/GeneMamba",
|
| 94 |
+
subfolder="24l-768d",
|
| 95 |
+
trust_remote_code=True,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
model_48l_512d = AutoModel.from_pretrained(
|
| 99 |
+
"mineself2016/GeneMamba",
|
| 100 |
+
subfolder="48l-512d",
|
| 101 |
+
trust_remote_code=True,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
model_48l_768d = AutoModel.from_pretrained(
|
| 105 |
+
"mineself2016/GeneMamba",
|
| 106 |
+
subfolder="48l-768d",
|
| 107 |
+
trust_remote_code=True,
|
| 108 |
+
)
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
More complete example:
|
| 112 |
+
- `examples/01_extract_embeddings.py`
|
| 113 |
+
|
| 114 |
+
## 5) Downstream task examples
|
| 115 |
+
|
| 116 |
+
See:
|
| 117 |
+
- `examples/downstream/README.md`
|
| 118 |
+
|
| 119 |
+
Included downstream tasks:
|
| 120 |
+
- cell type annotation fine-tuning
|
| 121 |
+
- zero-shot embedding + logistic regression
|
| 122 |
+
- batch integration proxy evaluation
|
| 123 |
+
- original legacy downstream scripts from `gene_mamba/analysis/cell_type_annotation`
|
| 124 |
+
|
| 125 |
+
## 6) Source of preprocessing logic
|
| 126 |
+
|
| 127 |
+
The preprocessing/tokenization pipeline is aligned with assets from:
|
| 128 |
+
- `/project/zhiwei/cq5/PythonWorkSpace/gene_mamba`
|
| 129 |
+
|
| 130 |
+
Key references used:
|
| 131 |
+
- tokenizer: `gene_tokenizer.json`
|
| 132 |
+
- mappings: `symbol2id.pkl`, `id2symbol.pkl`
|
| 133 |
+
- dataset build logic (Arrow + `input_ids`): `utils.py` (`build_dataset`)
|
config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "genemamba",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GeneMambaModel"
|
| 5 |
+
],
|
| 6 |
+
"vocab_size": 25426,
|
| 7 |
+
"max_position_embeddings": 2048,
|
| 8 |
+
"hidden_size": 512,
|
| 9 |
+
"num_hidden_layers": 24,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mamba_mode": "gate",
|
| 14 |
+
"embedding_pooling": "mean",
|
| 15 |
+
"num_labels": 2,
|
| 16 |
+
"pad_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"torch_dtype": "float32",
|
| 21 |
+
"transformers_version": "4.40.2",
|
| 22 |
+
"auto_map": {
|
| 23 |
+
"AutoConfig": "configuration_genemamba.GeneMambaConfig",
|
| 24 |
+
"AutoModel": "modeling_genemamba.GeneMambaModel",
|
| 25 |
+
"AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
|
| 26 |
+
"AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
|
| 27 |
+
}
|
| 28 |
+
}
|
configuration_genemamba.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GeneMamba model.
|
| 3 |
+
Defines all hyperparameters and settings for the GeneMamba architecture.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for GeneMamba model.

    This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
    It can be used to instantiate models from pretrained checkpoints or customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).

        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).

        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers (mamba_layer).

        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in MLP.

        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).

        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.

        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of truncated normal initializer.

        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".

        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling to get cell embedding.
            Options: "CLS", "mean" (the modeling code rejects anything else).

        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.

        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.

        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.

        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE: the previous identity attribute_map ({"hidden_size": "hidden_size", ...})
    # was a no-op and has been removed.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward all special-token ids to PretrainedConfig so they are
        # registered consistently (the original forwarded only pad_token_id,
        # leaving bos/eos invisible to generic HF utilities). The parent
        # __init__ sets self.pad_token_id / bos_token_id / eos_token_id.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        # Goes through PretrainedConfig's num_labels property setter, which
        # also builds a matching id2label mapping.
        self.num_labels = num_labels
|
examples/00_preprocess_to_input_ids.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import scanpy as sc
|
| 8 |
+
import pyarrow as pa
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_vocab(tokenizer_json_path: str):
    """Load the gene vocabulary from a HuggingFace-style tokenizer.json.

    Args:
        tokenizer_json_path: Path to tokenizer.json (expects a
            ``model.vocab`` mapping of gene symbol -> token id).

    Returns:
        Tuple ``(vocab, pad_id, unk_id)``: the symbol->id dict plus the ids
        for the ``[PAD]`` / ``[UNK]`` special tokens, defaulting to 1 / 0
        when those tokens are absent from the vocabulary.
    """
    # tokenizer.json is UTF-8 by convention; pin the encoding so the read
    # does not depend on the platform's locale default.
    with open(tokenizer_json_path, "r", encoding="utf-8") as f:
        tokenizer = json.load(f)
    vocab = tokenizer["model"]["vocab"]
    pad_id = vocab.get("[PAD]", 1)
    unk_id = vocab.get("[UNK]", 0)
    return vocab, pad_id, unk_id
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def ranked_gene_ids_for_cell(expr_values, gene_names, vocab):
    """Rank one cell's expressed genes by descending expression, as token ids.

    Args:
        expr_values: 1-D array of expression values for a single cell.
        gene_names: gene symbols aligned element-wise with ``expr_values``.
        vocab: symbol -> token id mapping; symbols missing from it are
            silently skipped.

    Returns:
        List of token ids ordered by decreasing expression; empty list when
        no gene has expression > 0.
    """
    nonzero_idx = np.where(expr_values > 0)[0]
    if len(nonzero_idx) == 0:
        return []

    genes = np.array(gene_names)[nonzero_idx]
    values = expr_values[nonzero_idx]

    # Stable sort keeps tied expression values in their original gene order,
    # so the ranking is deterministic across runs and platforms.
    order = np.argsort(-values, kind="stable")
    ranked_genes = genes[order]

    return [vocab[g] for g in ranked_genes if g in vocab]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main():
    """CLI entry point: convert an h5ad expression matrix into per-cell
    ranked token-id lists and write them as a one-column Arrow stream."""
    parser = argparse.ArgumentParser(description="Convert h5ad to GeneMamba input_ids (Arrow)")
    parser.add_argument("--h5ad", required=True, help="Input h5ad file")
    parser.add_argument("--tokenizer_json", required=True, help="Path to tokenizer.json or gene_tokenizer.json")
    parser.add_argument("--output_arrow", required=True, help="Output arrow file path")
    parser.add_argument("--max_cells", type=int, default=None, help="Optional: process first N cells only")
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    vocab, _, _ = load_vocab(args.tokenizer_json)

    gene_names = list(adata.var_names)
    if args.max_cells is None:
        n_cells = adata.n_obs
    else:
        n_cells = min(args.max_cells, adata.n_obs)

    matrix = adata.X
    all_token_ids = []
    for cell_idx in range(n_cells):
        cell_row = matrix[cell_idx]
        # Sparse rows expose .toarray(); dense rows are converted directly.
        if hasattr(cell_row, "toarray"):
            expr = cell_row.toarray().ravel()
        else:
            expr = np.asarray(cell_row).ravel()
        all_token_ids.append(ranked_gene_ids_for_cell(expr, gene_names, vocab))

    table = pa.Table.from_pandas(pd.DataFrame({"input_ids": all_token_ids}))

    output_path = Path(args.output_arrow)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with pa.OSFile(str(output_path), "wb") as sink, pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)

    print(f"Saved {len(all_token_ids)} cells to {output_path}")


if __name__ == "__main__":
    main()
|
examples/01_extract_embeddings.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase 1: Extract Cell Embeddings
|
| 3 |
+
Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python examples/01_extract_embeddings.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from transformers import AutoTokenizer, AutoModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main():
    """Load GeneMamba (falling back to a freshly initialized model when no
    local checkpoint can be found), embed a batch of simulated ranked-gene
    sequences, and run small downstream analyses (KMeans, PCA, cosine
    similarity), saving the embeddings to ``cell_embeddings.npy``.

    Returns:
        Tuple ``(model, cell_embeddings)`` for interactive reuse.
    """
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model and tokenizer
    # ============================================================
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example, we use a local model path
    # In practice, you would use: "username/GeneMamba-24l-512d"
    model_name = "GeneMamba-24l-512d"  # Change to HF Hub path when available

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True
        )
    except Exception as e:
        # Fallback path for demos without a downloaded checkpoint: build a
        # randomly initialized model from local config/model classes.
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")

        # For demonstration without actual checkpoint
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        # No tokenizer in the fallback path; the demo below only uses ids.
        tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ============================================================
    # Step 2: Prepare simulated single-cell data
    # ============================================================
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences
    # In practice, this would come from your scRNA-seq data
    # Genes should be ranked by expression (highest first)
    # NOTE(review): ids start at 2 — presumably 0/1 are reserved for special
    # tokens (pad_token_id=1 elsewhere in these examples); confirm against
    # the tokenizer vocabulary.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)

    print(f"✓ Created sample input:")
    print(f" - Batch size: {batch_size}")
    print(f" - Sequence length: {seq_len}")
    print(f" - Input shape: {input_ids.shape}")

    # ============================================================
    # Step 3: Inference - Extract embeddings
    # ============================================================
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # Get the pooled embedding (cell representation)
    # NOTE(review): assumes the model output exposes `pooled_embedding` and
    # `embedding_pooling` attributes (custom GeneMamba output class).
    cell_embeddings = outputs.pooled_embedding

    print(f"✓ Extraction complete!")
    print(f" - Cell embeddings shape: {cell_embeddings.shape}")
    print(f" - Pooling method used: {outputs.embedding_pooling}")
    print(f" - Embedding type: {cell_embeddings.dtype}")

    # ============================================================
    # Step 4: Example downstream analyses
    # ============================================================
    print("\n[Step 4] Example downstream uses...")

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search
    # Find the most similar cell to the first cell
    # (cell 0 compares against all cells including itself, so the argmax is
    # trivially cell 0 with similarity 1.0 — kept as-is for the demo).
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings
    )
    most_similar_idx = torch.argmax(similarities).item()
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # Example 4: Statistics
    print("\n[Step 5] Embedding statistics:")
    print(f" - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f" - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f" - Min: {cell_embeddings.min():.4f}")
    print(f" - Max: {cell_embeddings.max():.4f}")

    # ============================================================
    # Step 6: Save embeddings (optional)
    # ============================================================
    print("\n[Step 6] Saving embeddings...")

    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings


if __name__ == "__main__":
    model, embeddings = main()
|
examples/downstream/10_finetune_classification.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase 2: Downstream Task - Fine-tune for Classification
|
| 3 |
+
Demonstrates cell type annotation and other sequence classification tasks.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python examples/downstream/10_finetune_classification.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from torch.utils.data import Dataset, DataLoader
|
| 12 |
+
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GeneExpressionDataset(Dataset):
    """Torch dataset pairing ranked gene-token sequences with integer labels.

    In a real pipeline the sequences would be loaded from h5ad or another
    single-cell format; here they are supplied pre-tokenized.
    """

    def __init__(self, input_ids, labels, max_length=2048):
        self.input_ids = input_ids
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        sample = {
            "input_ids": torch.tensor(self.input_ids[idx], dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }
        return sample
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_mock_data(n_samples=1000, n_features=2048, n_classes=5, seed=None):
    """Create mock single-cell data for demonstration.

    Args:
        n_samples: total number of cells to simulate.
        n_features: sequence length (ranked gene tokens per cell).
        n_classes: number of distinct labels (e.g. cell types).
        seed: optional RNG seed for reproducible mock data (new, defaults to
            None so existing callers are unaffected).

    Returns:
        Tuple of (train, val, test) ``GeneExpressionDataset`` objects,
        split 70/15/15.
    """
    print("Creating mock dataset...")

    rng = np.random.default_rng(seed)

    # Create random ranked gene sequences.
    # Token ids start at 2 — ids 0/1 are assumed reserved for special tokens
    # (matches pad_token_id=1 used elsewhere in these examples).
    input_ids = rng.integers(2, 25426, (n_samples, n_features))

    # Random labels standing in for cell-type annotations
    labels = rng.integers(0, n_classes, n_samples)

    # Split into train/val/test (70/15/15)
    train_size = int(0.7 * n_samples)
    val_size = int(0.15 * n_samples)

    train_ids = input_ids[:train_size]
    train_labels = labels[:train_size]

    val_ids = input_ids[train_size:train_size + val_size]
    val_labels = labels[train_size:train_size + val_size]

    test_ids = input_ids[train_size + val_size:]
    test_labels = labels[train_size + val_size:]

    print(f"✓ Dataset created:")
    print(f" - Train: {len(train_ids)} samples")
    print(f" - Val: {len(val_ids)} samples")
    print(f" - Test: {len(test_ids)} samples")
    print(f" - Classes: {n_classes}")

    return (
        GeneExpressionDataset(train_ids, train_labels),
        GeneExpressionDataset(val_ids, val_labels),
        GeneExpressionDataset(test_ids, test_labels),
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
    """Fine-tune GeneMamba for sequence classification on mock data.

    Loads (or locally initializes) a classification-headed model, trains it
    with the HF ``Trainer`` on mock 70/15/15 splits, evaluates on the test
    split, saves the model, and smoke-tests reloading it.

    Returns:
        Tuple ``(model, trainer)`` for interactive reuse.
    """
    print("=" * 80)
    print("GeneMamba Phase 2: Downstream Classification")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model with classification head
    # ============================================================
    print("\n[Step 1] Loading pretrained model with classification head...")

    num_classes = 5

    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            "GeneMamba-24l-512d",
            num_labels=num_classes,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        # Fallback: build a randomly initialized classifier from the local
        # config/model classes when no checkpoint is available.
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        # Initialize locally
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForSequenceClassification

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            num_labels=num_classes,
        )
        model = GeneMambaForSequenceClassification(config)

    print(f"✓ Model loaded")
    print(f" - Classification head: input={model.config.hidden_size} → output={num_classes}")

    # ============================================================
    # Step 2: Prepare data
    # ============================================================
    print("\n[Step 2] Preparing dataset...")

    train_dataset, val_dataset, test_dataset = create_mock_data(
        n_samples=1000,
        n_features=2048,
        n_classes=num_classes,
    )

    # ============================================================
    # Step 3: Set up training arguments
    # ============================================================
    print("\n[Step 3] Setting up training...")

    output_dir = "./classification_results"

    # NOTE(review): `eval_strategy` is the newer transformers argument name
    # (older releases use `evaluation_strategy`) — confirm the pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none",  # Disable W&B logging
        seed=42,
    )

    print(f"✓ Training config:")
    print(f" - Output dir: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 4: Train using Trainer
    # ============================================================
    print("\n[Step 4] Training model...")

    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(eval_pred):
        """Compute evaluation metrics.

        Receives ``(logits, labels)`` from the Trainer; the "accuracy" key
        backs `metric_for_best_model` above.
        """
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)

        return {
            "accuracy": accuracy_score(labels, predictions),
            "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
            "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
            "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 5: Evaluate on test set
    # ============================================================
    print("\n[Step 5] Evaluating on test set...")

    test_results = trainer.evaluate(test_dataset)

    print(f"✓ Test Results:")
    for metric, value in test_results.items():
        if isinstance(value, float):
            print(f" - {metric}: {value:.4f}")

    # ============================================================
    # Step 6: Make predictions
    # ============================================================
    print("\n[Step 6] Making predictions...")

    predictions = trainer.predict(test_dataset)
    predicted_classes = np.argmax(predictions.predictions, axis=1)

    print(f"✓ Predictions made:")
    print(f" - Predicted classes: {len(predicted_classes)} samples")
    print(f" - Class distribution: {np.bincount(predicted_classes)}")

    # ============================================================
    # Step 7: Save model
    # ============================================================
    print("\n[Step 7] Saving model...")

    save_dir = "./my_genemamba_classifier"
    model.save_pretrained(save_dir)
    print(f"✓ Model saved to '{save_dir}'")

    # ============================================================
    # Step 8: Load and test saved model
    # ============================================================
    print("\n[Step 8] Testing model reloading...")

    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Test on a single batch (random ids; ids 0/1 assumed reserved for
    # special tokens — matches the mock-data generator above)
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (1, 2048))
        output = loaded_model(sample_input)
        logits = output.logits
        prediction = torch.argmax(logits, dim=1)

    print(f"✓ Loaded model test prediction: class {prediction.item()}")

    print("\n" + "=" * 80)
    print("Phase 2 Complete! Model ready for deployment.")
    print("=" * 80)

    return model, trainer


if __name__ == "__main__":
    model, trainer = main()
|
examples/downstream/11_zero_shot_logreg.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Zero-shot downstream baseline:
|
| 3 |
+
1) Extract frozen GeneMamba embeddings
|
| 4 |
+
2) Train LogisticRegression on train split
|
| 5 |
+
3) Evaluate on test split
|
| 6 |
+
|
| 7 |
+
Expected h5ad columns:
|
| 8 |
+
- obs['celltype']
|
| 9 |
+
- obs['partition'] with values in {'train', 'test'}
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import scanpy as sc
|
| 15 |
+
import torch
|
| 16 |
+
from sklearn.linear_model import LogisticRegression
|
| 17 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
| 18 |
+
from sklearn.preprocessing import LabelEncoder
|
| 19 |
+
from transformers import AutoModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Rank each cell's expressed genes by descending expression as token ids.

    Args:
        adata: AnnData-like object (cells x genes); ``adata.X`` rows may be
            dense or sparse (anything exposing ``.toarray()``).
        symbol2id: gene symbol -> token id mapping; unknown symbols skipped.
        seq_len: fixed output width; each row is truncated/right-padded.
        pad_id: token id used for padding.

    Returns:
        ``np.ndarray`` of shape (n_obs, seq_len), dtype int64.
    """
    gene_names = np.array(adata.var_names)
    X = adata.X
    out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for i in range(adata.n_obs):
        row = X[i]
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        nz = np.where(expr > 0)[0]
        if len(nz) == 0:
            continue

        genes = gene_names[nz]
        vals = expr[nz]
        # Stable sort keeps tied expression values in original gene order,
        # making the ranking deterministic across runs/platforms.
        order = np.argsort(-vals, kind="stable")
        ranked_genes = genes[order]
        ids = [symbol2id[g] for g in ranked_genes if g in symbol2id][:seq_len]
        out[i, : len(ids)] = ids

    return out
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main():
    """Zero-shot baseline: frozen GeneMamba embeddings + logistic regression.

    Embeds every cell in the h5ad, fits a LogisticRegression on the cells
    with ``obs['partition'] == 'train'`` and reports accuracy / F1 /
    precision / recall on the ``'test'`` split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", default=None, help="Optional .npy dumped dict path")
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    # Raise instead of assert so the validation survives `python -O`.
    if "celltype" not in adata.obs:
        raise ValueError("h5ad must include obs['celltype']")
    if "partition" not in adata.obs:
        raise ValueError("h5ad must include obs['partition']")

    if args.symbol2id_npy is None:
        raise ValueError("Please provide --symbol2id_npy (dict saved by np.save(..., allow_pickle=True))")

    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()

    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)
    labels = LabelEncoder().fit_transform(adata.obs["celltype"].values)

    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    # Fall back to CPU instead of crashing with .cuda() on GPU-less machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Embed all cells in mini-batches with the frozen model.
    embeds = []
    with torch.no_grad():
        for s in range(0, input_ids.shape[0], args.batch_size):
            batch = torch.tensor(input_ids[s : s + args.batch_size], dtype=torch.long, device=device)
            out = model(batch)
            embeds.append(out.pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(embeds, axis=0)

    train_mask = adata.obs["partition"].values == "train"
    test_mask = adata.obs["partition"].values == "test"

    X_train, y_train = embeds[train_mask], labels[train_mask]
    X_test, y_test = embeds[test_mask], labels[test_mask]

    clf = LogisticRegression(max_iter=2000)
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)

    print("accuracy:", accuracy_score(y_test, pred))
    print("micro_f1:", f1_score(y_test, pred, average="micro"))
    print("macro_f1:", f1_score(y_test, pred, average="macro"))
    print("precision_weighted:", precision_score(y_test, pred, average="weighted", zero_division=0))
    print("recall_weighted:", recall_score(y_test, pred, average="weighted", zero_division=0))


if __name__ == "__main__":
    main()
|
examples/downstream/12_batch_integration_eval.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch integration downstream example:
|
| 3 |
+
- Extract embeddings with frozen GeneMamba
|
| 4 |
+
- Evaluate simple batch mixing score proxy (silhouette by batch)
|
| 5 |
+
|
| 6 |
+
Expected h5ad columns:
|
| 7 |
+
- obs['batch']
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import scanpy as sc
|
| 13 |
+
import torch
|
| 14 |
+
from sklearn.metrics import silhouette_score
|
| 15 |
+
from sklearn.preprocessing import LabelEncoder
|
| 16 |
+
from transformers import AutoModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Rank each cell's expressed genes by descending expression as token ids.

    Args:
        adata: AnnData-like object (cells x genes); ``adata.X`` rows may be
            dense or sparse (anything exposing ``.toarray()``).
        symbol2id: gene symbol -> token id mapping; unknown symbols skipped.
        seq_len: fixed output width; each row is truncated/right-padded.
        pad_id: token id used for padding.

    Returns:
        ``np.ndarray`` of shape (n_obs, seq_len), dtype int64.
    """
    gene_names = np.array(adata.var_names)
    X = adata.X
    out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for i in range(adata.n_obs):
        row = X[i]
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        nz = np.where(expr > 0)[0]
        if len(nz) == 0:
            continue

        genes = gene_names[nz]
        vals = expr[nz]
        # Stable sort keeps tied expression values in original gene order,
        # making the ranking deterministic across runs/platforms.
        order = np.argsort(-vals, kind="stable")
        ranked_genes = genes[order]
        ids = [symbol2id[g] for g in ranked_genes if g in symbol2id][:seq_len]
        out[i, : len(ids)] = ids

    return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main():
    """Embed all cells with a frozen GeneMamba model and report the
    silhouette score computed on batch labels (a rough batch-mixing proxy;
    values near 0 suggest better mixing than large positive values).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    # Raise instead of assert so the validation survives `python -O`.
    if "batch" not in adata.obs:
        raise ValueError("h5ad must include obs['batch']")

    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()
    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)

    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    # Fall back to CPU instead of crashing with .cuda() on GPU-less machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Embed all cells in mini-batches with the frozen model.
    embeds = []
    with torch.no_grad():
        for s in range(0, input_ids.shape[0], args.batch_size):
            batch = torch.tensor(input_ids[s : s + args.batch_size], dtype=torch.long, device=device)
            out = model(batch)
            embeds.append(out.pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(embeds, axis=0)

    batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values)
    score = silhouette_score(embeds, batch_labels, metric="euclidean")

    print("silhouette_by_batch:", score)
    print("(Closer to 0 typically indicates better batch mixing than very high positive values.)")


if __name__ == "__main__":
    main()
|
examples/downstream/20_continue_pretraining_reference.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase 3: Continue Pretraining
|
| 3 |
+
Demonstrates how to continue pretraining GeneMamba on your own data using masked LM objective.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python examples/downstream/20_continue_pretraining_reference.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoModelForMaskedLM,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
Trainer,
|
| 16 |
+
TrainingArguments,
|
| 17 |
+
DataCollatorForLanguageModeling,
|
| 18 |
+
)
|
| 19 |
+
|
class PretrainingDataset(Dataset):
    """
    Dataset for pretraining/continued pretraining.

    Each item is a ranked-gene token sequence normalized to exactly
    ``max_length`` tokens: longer sequences are truncated, shorter ones are
    right-padded with ``pad_token_id``.
    """

    def __init__(self, input_ids_list, max_length=2048, pad_token_id=1):
        """
        Args:
            input_ids_list: sequence of 1-D int arrays/lists of token ids.
            max_length: fixed output length for every item.
            pad_token_id: token id used for right-padding. Defaults to 1,
                the value the original code hard-coded — confirm it matches
                the tokenizer's pad id.
        """
        self.input_ids_list = input_ids_list
        self.max_length = max_length
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        input_ids = self.input_ids_list[idx]

        # Pad or truncate to max_length so items can be stacked into batches.
        if len(input_ids) >= self.max_length:
            input_ids = input_ids[:self.max_length]
        else:
            input_ids = np.pad(
                input_ids,
                (0, self.max_length - len(input_ids)),
                constant_values=self.pad_token_id,
            )

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
        }
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock single-cell sequences for pretraining."""

    print("Creating mock pretraining dataset...")

    # Each "cell" is a random ranked-gene sequence; real inputs would come
    # from rank-ordering genes by expression in scRNA-seq profiles.
    sequences = [np.random.randint(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences of length {seq_len}")

    return sequences
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main():
|
| 71 |
+
print("=" * 80)
|
| 72 |
+
print("GeneMamba Phase 3: Continue Pretraining")
|
| 73 |
+
print("=" * 80)
|
| 74 |
+
|
| 75 |
+
# ============================================================
|
| 76 |
+
# Step 1: Load pretrained model for masked LM
|
| 77 |
+
# ============================================================
|
| 78 |
+
print("\n[Step 1] Loading model for masked LM...")
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
model = AutoModelForMaskedLM.from_pretrained(
|
| 82 |
+
"GeneMamba-24l-512d",
|
| 83 |
+
trust_remote_code=True,
|
| 84 |
+
local_files_only=True,
|
| 85 |
+
)
|
| 86 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 87 |
+
"GeneMamba-24l-512d",
|
| 88 |
+
trust_remote_code=True,
|
| 89 |
+
local_files_only=True,
|
| 90 |
+
)
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"Note: Could not load from hub ({e})")
|
| 93 |
+
print("Using local initialization...")
|
| 94 |
+
|
| 95 |
+
# Initialize locally
|
| 96 |
+
from configuration_genemamba import GeneMambaConfig
|
| 97 |
+
from modeling_genemamba import GeneMambaForMaskedLM
|
| 98 |
+
|
| 99 |
+
config = GeneMambaConfig(
|
| 100 |
+
vocab_size=25426,
|
| 101 |
+
hidden_size=512,
|
| 102 |
+
num_hidden_layers=24,
|
| 103 |
+
)
|
| 104 |
+
model = GeneMambaForMaskedLM(config)
|
| 105 |
+
tokenizer = None
|
| 106 |
+
|
| 107 |
+
print(f"✓ Model loaded")
|
| 108 |
+
print(f" - Architecture: {model.config.num_hidden_layers} layers, "
|
| 109 |
+
f"hidden_size={model.config.hidden_size}")
|
| 110 |
+
|
| 111 |
+
# ============================================================
|
| 112 |
+
# Step 2: Prepare pretraining data
|
| 113 |
+
# ============================================================
|
| 114 |
+
print("\n[Step 2] Preparing pretraining dataset...")
|
| 115 |
+
|
| 116 |
+
sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)
|
| 117 |
+
|
| 118 |
+
# Split train/eval
|
| 119 |
+
train_size = int(0.9 * len(sequences))
|
| 120 |
+
train_sequences = sequences[:train_size]
|
| 121 |
+
eval_sequences = sequences[train_size:]
|
| 122 |
+
|
| 123 |
+
train_dataset = PretrainingDataset(train_sequences)
|
| 124 |
+
eval_dataset = PretrainingDataset(eval_sequences)
|
| 125 |
+
|
| 126 |
+
print(f"✓ Datasets created:")
|
| 127 |
+
print(f" - Training: {len(train_dataset)} samples")
|
| 128 |
+
print(f" - Evaluation: {len(eval_dataset)} samples")
|
| 129 |
+
|
| 130 |
+
# ============================================================
|
| 131 |
+
# Step 3: Set up data collator for MLM
|
| 132 |
+
# ============================================================
|
| 133 |
+
print("\n[Step 3] Setting up data collator...")
|
| 134 |
+
|
| 135 |
+
if tokenizer is not None:
|
| 136 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 137 |
+
tokenizer=tokenizer,
|
| 138 |
+
mlm=True,
|
| 139 |
+
mlm_probability=0.15, # Mask 15% of tokens
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
# Custom collator if no tokenizer available
|
| 143 |
+
class CustomDataCollator:
|
| 144 |
+
def __call__(self, batch):
|
| 145 |
+
input_ids = torch.stack([item["input_ids"] for item in batch])
|
| 146 |
+
|
| 147 |
+
# Create masked labels (for MLM loss)
|
| 148 |
+
labels = input_ids.clone()
|
| 149 |
+
mask = torch.rand(input_ids.shape) < 0.15
|
| 150 |
+
|
| 151 |
+
# Set input to [MASK] token (id=0)
|
| 152 |
+
input_ids[mask] = 0
|
| 153 |
+
|
| 154 |
+
# Set labels to -100 where not masked (loss ignores these)
|
| 155 |
+
labels[~mask] = -100
|
| 156 |
+
|
| 157 |
+
return {"input_ids": input_ids, "labels": labels}
|
| 158 |
+
|
| 159 |
+
data_collator = CustomDataCollator()
|
| 160 |
+
|
| 161 |
+
print(f"✓ Data collator ready (MLM probability: 0.15)")
|
| 162 |
+
|
| 163 |
+
# ============================================================
|
| 164 |
+
# Step 4: Set up training arguments
|
| 165 |
+
# ============================================================
|
| 166 |
+
print("\n[Step 4] Setting up training...")
|
| 167 |
+
|
| 168 |
+
output_dir = "./pretrain_results"
|
| 169 |
+
|
| 170 |
+
training_args = TrainingArguments(
|
| 171 |
+
output_dir=output_dir,
|
| 172 |
+
num_train_epochs=2,
|
| 173 |
+
per_device_train_batch_size=16,
|
| 174 |
+
per_device_eval_batch_size=16,
|
| 175 |
+
learning_rate=2e-5,
|
| 176 |
+
weight_decay=0.01,
|
| 177 |
+
warmup_steps=500,
|
| 178 |
+
logging_steps=100,
|
| 179 |
+
eval_strategy="epoch",
|
| 180 |
+
save_strategy="epoch",
|
| 181 |
+
load_best_model_at_end=True,
|
| 182 |
+
metric_for_best_model="eval_loss",
|
| 183 |
+
report_to="none", # Disable W&B
|
| 184 |
+
seed=42,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
print(f"✓ Training config:")
|
| 188 |
+
print(f" - Output dir: {output_dir}")
|
| 189 |
+
print(f" - Epochs: {training_args.num_train_epochs}")
|
| 190 |
+
print(f" - Batch size: {training_args.per_device_train_batch_size}")
|
| 191 |
+
print(f" - Learning rate: {training_args.learning_rate}")
|
| 192 |
+
print(f" - MLM masking: 15%")
|
| 193 |
+
|
| 194 |
+
# ============================================================
|
| 195 |
+
# Step 5: Train
|
| 196 |
+
# ============================================================
|
| 197 |
+
print("\n[Step 5] Starting continued pretraining...")
|
| 198 |
+
|
| 199 |
+
trainer = Trainer(
|
| 200 |
+
model=model,
|
| 201 |
+
args=training_args,
|
| 202 |
+
train_dataset=train_dataset,
|
| 203 |
+
eval_dataset=eval_dataset,
|
| 204 |
+
data_collator=data_collator,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
train_result = trainer.train()
|
| 208 |
+
|
| 209 |
+
print(f"✓ Training complete!")
|
| 210 |
+
print(f" - Final training loss: {train_result.training_loss:.4f}")
|
| 211 |
+
|
| 212 |
+
# ============================================================
|
| 213 |
+
# Step 6: Evaluate
|
| 214 |
+
# ============================================================
|
| 215 |
+
print("\n[Step 6] Evaluating on held-out set...")
|
| 216 |
+
|
| 217 |
+
eval_results = trainer.evaluate()
|
| 218 |
+
|
| 219 |
+
print(f"✓ Evaluation Results:")
|
| 220 |
+
for metric, value in eval_results.items():
|
| 221 |
+
if isinstance(value, (int, float)):
|
| 222 |
+
print(f" - {metric}: {value:.4f}")
|
| 223 |
+
|
| 224 |
+
# ============================================================
|
| 225 |
+
# Step 7: Save model
|
| 226 |
+
# ============================================================
|
| 227 |
+
print("\n[Step 7] Saving continued pretrained model...")
|
| 228 |
+
|
| 229 |
+
save_dir = "./genemamba_continued_pretrain"
|
| 230 |
+
model.save_pretrained(save_dir)
|
| 231 |
+
if tokenizer is not None:
|
| 232 |
+
tokenizer.save_pretrained(save_dir)
|
| 233 |
+
|
| 234 |
+
print(f"✓ Model saved to '{save_dir}'")
|
| 235 |
+
|
| 236 |
+
# ============================================================
|
| 237 |
+
# Step 8: Test model inference
|
| 238 |
+
# ============================================================
|
| 239 |
+
print("\n[Step 8] Testing inference on masked input...")
|
| 240 |
+
|
| 241 |
+
model.eval()
|
| 242 |
+
|
| 243 |
+
# Create sample input with masked tokens
|
| 244 |
+
sample_input = torch.randint(2, 25426, (1, 2048))
|
| 245 |
+
sample_input[0, :10] = 0 # Mask first 10 tokens
|
| 246 |
+
|
| 247 |
+
with torch.no_grad():
|
| 248 |
+
outputs = model(sample_input)
|
| 249 |
+
logits = outputs.logits
|
| 250 |
+
predictions = torch.argmax(logits, dim=-1)
|
| 251 |
+
|
| 252 |
+
print(f"✓ Sample predictions generated")
|
| 253 |
+
print(f" - Input shape: {sample_input.shape}")
|
| 254 |
+
print(f" - Output logits shape: {logits.shape}")
|
| 255 |
+
print(f" - Top predicted genes (tokens): {predictions[0, :10].tolist()}")
|
| 256 |
+
|
| 257 |
+
print("\n" + "=" * 80)
|
| 258 |
+
print("Phase 3 Complete! Model ready for downstream tasks or further training.")
|
| 259 |
+
print("=" * 80)
|
| 260 |
+
|
| 261 |
+
return model, trainer
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
model, trainer = main()
|
examples/downstream/21_pretrain_from_scratch_reference.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase 4: Train from Scratch
|
| 3 |
+
Demonstrates how to initialize and train a GeneMamba model from scratch.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
    python examples/downstream/21_pretrain_from_scratch_reference.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoConfig,
|
| 14 |
+
Trainer,
|
| 15 |
+
TrainingArguments,
|
| 16 |
+
DataCollatorForLanguageModeling,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PretrainingDataset(Dataset):
|
| 21 |
+
"""Dataset for pretraining."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, input_ids_list, max_length=2048):
|
| 24 |
+
self.input_ids_list = input_ids_list
|
| 25 |
+
self.max_length = max_length
|
| 26 |
+
|
| 27 |
+
def __len__(self):
|
| 28 |
+
return len(self.input_ids_list)
|
| 29 |
+
|
| 30 |
+
def __getitem__(self, idx):
|
| 31 |
+
input_ids = self.input_ids_list[idx]
|
| 32 |
+
|
| 33 |
+
# Pad or truncate
|
| 34 |
+
if len(input_ids) >= self.max_length:
|
| 35 |
+
input_ids = input_ids[:self.max_length]
|
| 36 |
+
else:
|
| 37 |
+
input_ids = np.pad(
|
| 38 |
+
input_ids,
|
| 39 |
+
(0, self.max_length - len(input_ids)),
|
| 40 |
+
constant_values=1
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return {
|
| 44 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 45 |
+
}
|
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock pretraining data."""

    print("Creating mock pretraining dataset for from-scratch training...")

    # Random ranked-gene token sequences standing in for real scRNA-seq data.
    sequences = [np.random.randint(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences")

    return sequences
def main():
    """Run the from-scratch pretraining demo end to end.

    Builds a small GeneMamba config, initializes a masked-LM model from it,
    trains on mock data with the HF Trainer, evaluates, saves, reloads via
    AutoModelForMaskedLM, and smoke-tests inference.

    Returns:
        (model, trainer, config): trained model, Trainer, and the config.
    """
    print("=" * 80)
    print("GeneMamba Phase 4: Train from Scratch")
    print("=" * 80)

    # ============================================================
    # Step 1: Create config from scratch
    # ============================================================
    print("\n[Step 1] Creating model configuration...")

    from configuration_genemamba import GeneMambaConfig
    from modeling_genemamba import GeneMambaForMaskedLM

    config = GeneMambaConfig(
        vocab_size=25426,
        hidden_size=256,  # Smaller for faster demo
        num_hidden_layers=12,  # Reduced for demo
        intermediate_size=1024,
        max_position_embeddings=2048,
        mamba_mode="gate",
        embedding_pooling="mean",
        num_labels=2,
        hidden_dropout_prob=0.1,
        initializer_range=0.02,
    )

    print(f"✓ Config created:")
    print(f"  - Model type: {config.model_type}")
    print(f"  - Hidden size: {config.hidden_size}")
    print(f"  - Num layers: {config.num_hidden_layers}")
    print(f"  - Vocab size: {config.vocab_size}")
    print(f"  - Mode: {config.mamba_mode}")

    # ============================================================
    # Step 2: Initialize model from config
    # ============================================================
    print("\n[Step 2] Initializing model from config...")

    model = GeneMambaForMaskedLM(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"✓ Model initialized:")
    print(f"  - Total parameters: {total_params / 1e6:.2f}M")
    print(f"  - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ============================================================
    # Step 3: Prepare data
    # ============================================================
    print("\n[Step 3] Preparing training data...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split
    train_size = int(0.8 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print(f"✓ Datasets created:")
    print(f"  - Train: {len(train_dataset)}")
    print(f"  - Eval: {len(eval_dataset)}")

    # ============================================================
    # Step 4: Data collator for MLM
    # ============================================================
    print("\n[Step 4] Setting up data collator...")

    class CustomDataCollator:
        """Custom collator for MLM without tokenizer."""

        def __call__(self, batch):
            input_ids = torch.stack([item["input_ids"] for item in batch])

            # Create labels for MLM
            labels = input_ids.clone()

            # Mask 15% of tokens
            # NOTE(review): positions are sampled uniformly, so pad tokens
            # (id=1) can also be masked — confirm intended.
            mask = torch.rand(input_ids.shape) < 0.15
            input_ids[mask] = 0  # [MASK] token

            # Don't compute loss on non-masked tokens
            labels[~mask] = -100

            return {"input_ids": input_ids, "labels": labels}

    data_collator = CustomDataCollator()
    print(f"✓ Data collator ready")

    # ============================================================
    # Step 5: Training arguments
    # ============================================================
    print("\n[Step 5] Setting up training...")

    output_dir = "./from_scratch_pretrain"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    print(f"✓ Training config:")
    print(f"  - Output: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 6: Train
    # ============================================================
    print("\n[Step 6] Starting training from scratch...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 7: Evaluate
    # ============================================================
    print("\n[Step 7] Evaluating...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ============================================================
    # Step 8: Save model and config
    # ============================================================
    print("\n[Step 8] Saving model...")

    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    config.save_pretrained(save_dir)

    print(f"✓ Model and config saved to '{save_dir}'")
    print(f"  Files created:")
    print(f"  - config.json")
    print(f"  - model.safetensors (or pytorch_model.bin)")

    # ============================================================
    # Step 9: Reload and verify
    # ============================================================
    print("\n[Step 9] Reloading model from checkpoint...")

    from transformers import AutoModelForMaskedLM

    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )

    loaded_model.eval()

    # Test inference
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        sample_input[:, :10] = 0  # Mask first 10 tokens

        outputs = loaded_model(sample_input)
        logits = outputs.logits

    print(f"✓ Model reloaded and tested!")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Logits shape: {logits.shape}")

    # ============================================================
    # Step 10: Optional - Convert to different format
    # ============================================================
    print("\n[Step 10] Model ready for conversion/deployment!")
    print(f"✓ You can now:")
    print(f"  1. Push to Hugging Face Hub:")
    print(f"     model.push_to_hub('your-username/GeneMamba-custom')")
    print(f"  2. Use with downstream tasks:")
    print(f"     AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print(f"  3. Extract embeddings:")
    print(f"     AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 4 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, config
|
| 278 |
+
|
| 279 |
+
if __name__ == "__main__":
|
| 280 |
+
model, trainer, config = main()
|
examples/downstream/README.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Downstream Examples
|
| 2 |
+
|
| 3 |
+
This folder now contains both **ready-to-run** examples and **legacy scripts** from the original GeneMamba project.
|
| 4 |
+
|
| 5 |
+
## Ready-to-run scripts
|
| 6 |
+
|
| 7 |
+
- `10_finetune_classification.py`
|
| 8 |
+
Fine-tune `AutoModelForSequenceClassification` for cell-type annotation.
|
| 9 |
+
|
| 10 |
+
- `11_zero_shot_logreg.py`
|
| 11 |
+
Freeze GeneMamba, extract `pooled_embedding`, train LogisticRegression on train split, evaluate on test split.
|
| 12 |
+
|
| 13 |
+
- `12_batch_integration_eval.py`
|
| 14 |
+
Batch integration proxy evaluation using silhouette score by `obs['batch']`.
|
| 15 |
+
|
| 16 |
+
## Reference training scripts
|
| 17 |
+
|
| 18 |
+
- `20_continue_pretraining_reference.py`
|
| 19 |
+
- `21_pretrain_from_scratch_reference.py`
|
| 20 |
+
|
| 21 |
+
## Legacy scripts from original repo
|
| 22 |
+
|
| 23 |
+
- `legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py`
|
| 24 |
+
- `legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py`
|
| 25 |
+
- `legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py`
|
| 26 |
+
|
| 27 |
+
## Required h5ad conventions
|
| 28 |
+
|
| 29 |
+
For downstream compatibility, standardize columns in `adata.obs`:
|
| 30 |
+
|
| 31 |
+
- `celltype` for label
|
| 32 |
+
- `batch` for batch id
|
| 33 |
+
- `partition` in `{train, test}` for train/test split
|
| 34 |
+
|
| 35 |
+
This matches conventions described in the original `dataset/downstream/README.md`.
|
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Trainer
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import pyarrow as pa
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from matplotlib import pyplot as plt
|
| 11 |
+
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from transformers import AutoTokenizer, TrainingArguments
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
| 18 |
+
from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
|
| 19 |
+
|
| 20 |
+
from dotmap import DotMap
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import os
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
# from trange import trange
|
| 27 |
+
|
| 28 |
+
sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
|
| 29 |
+
from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
|
| 30 |
+
from utils import permute_genes_by_expression
|
| 31 |
+
from utils2 import standardize_columns
|
| 32 |
+
|
| 33 |
+
import importlib
|
| 34 |
+
importlib.reload(sys.modules['models'])
|
| 35 |
+
importlib.reload(sys.modules['utils'])
|
| 36 |
+
importlib.reload(sys.modules['utils2'])
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# %%
|
| 40 |
+
DATA_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/"
|
| 41 |
+
# CHECKPOINT_PATH = "/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_512d/1/3m/checkpoint-31250"
|
| 42 |
+
TOKENIZER_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json"
|
| 43 |
+
SAVE_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/embeddings/cell"
|
| 44 |
+
|
| 45 |
+
# %%
|
| 46 |
+
import argparse
|
| 47 |
+
|
| 48 |
+
parser = argparse.ArgumentParser()
|
| 49 |
+
parser.add_argument("--dataset_name", type=str)
|
| 50 |
+
parser.add_argument("--ckpt_path", type = str)
|
| 51 |
+
parser.add_argument("--seq_len", type=int, default=2048)
|
| 52 |
+
parser.add_argument("--batch_size", type=int, default=24)
|
| 53 |
+
parser.add_argument("--num_epochs", type=int, default=5)
|
| 54 |
+
parser.add_argument("--test_size", type = float, default=0.1)
|
| 55 |
+
parser.add_argument("--split", type=lambda x: x.lower() in ["true", "1", "yes"], default=False,)
|
| 56 |
+
|
| 57 |
+
args = parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# args = DotMap({
|
| 61 |
+
# "dataset_name": "ms",
|
| 62 |
+
# "seq_len": 512,
|
| 63 |
+
# "batch_size": 24,
|
| 64 |
+
# "num_epochs": 5,
|
| 65 |
+
# "test_size": 0.1
|
| 66 |
+
# })
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
#%%
|
| 70 |
+
CHECKPOINT_PATH = args.ckpt_path
|
| 71 |
+
model_name = CHECKPOINT_PATH.split("/")[-4]
|
| 72 |
+
mamba_layer = int(model_name.split("_")[1][:-1])
|
| 73 |
+
d_model = int(model_name.split("_")[2][:-1])
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# make the sub directories to save the results
|
| 77 |
+
SAVE_PATH = os.path.join(SAVE_PATH, model_name)
|
| 78 |
+
sub_directories = ["predictions", "metrics", "figures", "repr"]
|
| 79 |
+
for sub_dir in sub_directories:
|
| 80 |
+
os.makedirs(os.path.join(SAVE_PATH, sub_dir), exist_ok=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# %%
|
| 84 |
+
import scanpy as sc
|
| 85 |
+
|
| 86 |
+
# Load the .h5ad file
|
| 87 |
+
dataset_name = args.dataset_name
|
| 88 |
+
# adata = sc.read_h5ad(os.path.join(DATA_PATH ,f'{dataset_name}.h5ad'))
|
| 89 |
+
|
| 90 |
+
adata = None
|
| 91 |
+
|
| 92 |
+
if args.split:
|
| 93 |
+
adata = sc.read_h5ad(os.path.join(DATA_PATH ,f'split/{dataset_name}_split.h5ad'))
|
| 94 |
+
print(f"Read data from {dataset_name}_split.h5ad")
|
| 95 |
+
dataset_name = dataset_name + "_split"
|
| 96 |
+
else:
|
| 97 |
+
adata = sc.read_h5ad(os.path.join(DATA_PATH ,f'processed/{dataset_name}_processed.h5ad'))
|
| 98 |
+
print(f"Read data from {dataset_name}_processed.h5ad")
|
| 99 |
+
|
| 100 |
+
# Display basic information about the data
|
| 101 |
+
print(adata)
|
| 102 |
+
|
| 103 |
+
# %%
|
| 104 |
+
# adata = standardize_columns(adata, dataset_name)
|
| 105 |
+
# assert "batch" in adata.obs.columns and "celltype" in adata.obs.columns
|
| 106 |
+
|
| 107 |
+
# %%
|
| 108 |
+
from sklearn.preprocessing import LabelEncoder
|
| 109 |
+
|
| 110 |
+
y_names = np.array(adata.obs['celltype'].values.tolist())
|
| 111 |
+
|
| 112 |
+
label_encoder = LabelEncoder()
|
| 113 |
+
y = label_encoder.fit_transform(y_names)
|
| 114 |
+
|
| 115 |
+
num_class = len(label_encoder.classes_)
|
| 116 |
+
|
| 117 |
+
# %%
|
| 118 |
+
from transformers import PretrainedConfig
|
| 119 |
+
|
| 120 |
+
config = PretrainedConfig.from_dict({
|
| 121 |
+
"d_model": d_model,
|
| 122 |
+
"mamba_layer": mamba_layer,
|
| 123 |
+
})
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# %%
|
| 127 |
+
model_cell_cls = GeneMamba2ForCellClassification(config, model_path=CHECKPOINT_PATH, tokenizer_path = TOKENIZER_PATH, args=None, output_dim_cls = num_class, hidden_dim= 512, num_layers_cls = 4)
|
| 128 |
+
|
| 129 |
+
# %%
|
| 130 |
+
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model_cell_cls.tokenizer, model_cell_cls.symbol2id)
|
| 131 |
+
permuted_gene_ids
|
| 132 |
+
|
| 133 |
+
# %%
|
| 134 |
+
seq_len = args.seq_len
|
| 135 |
+
|
| 136 |
+
input_data = permuted_gene_ids[:, :seq_len]
|
| 137 |
+
|
| 138 |
+
# %%
|
| 139 |
+
model_cell_cls.tokenizer.cls_token_id
|
| 140 |
+
|
| 141 |
+
# %%
|
| 142 |
+
torch.tensor([model_cell_cls.tokenizer.cls_token_id for _ in range(input_data.shape[0])])
|
| 143 |
+
|
| 144 |
+
# %%
|
| 145 |
+
model_cell_cls.tokenizer.cls_token_id
|
| 146 |
+
|
| 147 |
+
# %%
|
| 148 |
+
input_data.shape[0]
|
| 149 |
+
|
| 150 |
+
# %%
|
| 151 |
+
input_data
|
| 152 |
+
|
| 153 |
+
# %%
|
| 154 |
+
# add the cls token to the input data
|
| 155 |
+
input_data = np.hstack([np.array([model_cell_cls.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1), input_data])
|
| 156 |
+
input_data
|
| 157 |
+
|
| 158 |
+
# %%
|
| 159 |
+
input_data.shape
|
| 160 |
+
|
| 161 |
+
#%%
|
| 162 |
+
from sklearn.model_selection import train_test_split
|
| 163 |
+
import numpy as np
|
| 164 |
+
|
| 165 |
+
def manual_stratified_split(X, y, test_size=0.1, random_state=None):
    """Stratified train/test split that tolerates singleton classes.

    Classes with a single sample go entirely to the training set (sklearn's
    stratified split cannot place a lone sample in both partitions).

    Returns (X_train, X_test, y_train, y_test) as numpy arrays.
    """
    train_X, test_X, train_y, test_y = [], [], [], []

    for label in np.unique(y):
        idx = np.where(y == label)[0]
        if len(idx) > 1:
            tr_idx, te_idx = train_test_split(idx, test_size=test_size, random_state=random_state)
        else:
            # A lone sample cannot be split; keep it for training.
            tr_idx, te_idx = idx, []
        train_X.extend(X[tr_idx])
        train_y.extend(y[tr_idx])
        test_X.extend(X[te_idx])
        test_y.extend(y[te_idx])

    return np.array(train_X), np.array(test_X), np.array(train_y), np.array(test_y)
|
| 186 |
+
|
| 187 |
+
# %%
|
| 188 |
+
# from sklearn.model_selection import train_test_split
|
| 189 |
+
|
| 190 |
+
# X_train, X_test, y_train, y_test = manual_stratified_split(input_data, y, test_size=args.test_size, random_state=42)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
#%%
|
| 194 |
+
# train and test split is done and stored in the adata.obs["partition"] column, so we can extract the train and test data from there
|
| 195 |
+
|
| 196 |
+
# The train/test partition was precomputed and stored in adata.obs["partition"],
# so the split is a simple boolean selection.
train_mask = adata.obs["partition"] == "train"
test_mask = adata.obs["partition"] == "test"

X_train = input_data[train_mask]
X_test = input_data[test_mask]
y_train = y[train_mask]
y_test = y[test_mask]

X_train.shape, X_test.shape, y_train.shape, y_test.shape
|
| 202 |
+
|
| 203 |
+
# %%
|
| 204 |
+
from torch.utils.data import DataLoader, Dataset
|
| 205 |
+
|
| 206 |
+
class GeneDataset(Dataset):
    """Minimal (sample, label) dataset wrapping two aligned sequences."""

    def __init__(self, data, y):
        self.data = data
        self.labels = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.labels[idx]
        return sample, label
|
| 216 |
+
|
| 217 |
+
# Wrap the splits; only the training loader is shuffled — the eval and
# representation loaders must keep the original cell order.
train_dataset = GeneDataset(X_train, y_train)
test_dataset = GeneDataset(X_test, y_test)
all_dataset = GeneDataset(input_data, y)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
all_loader = DataLoader(all_dataset, batch_size=args.batch_size, shuffle=False)
|
| 224 |
+
|
| 225 |
+
# %%
|
| 226 |
+
from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
|
| 227 |
+
|
| 228 |
+
# %%
|
| 229 |
+
def compute_metrics(y_pred, y_prob, y_true):
    """Return a dict of classification metrics.

    y_prob is accepted for the (currently disabled) AUC-ROC computation so
    the call signature stays stable.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "Micro-F1 score": f1_score(y_true, y_pred, average='micro'),
        "Macro-F1 score": f1_score(y_true, y_pred, average='macro'),
        "precision": precision_score(y_true, y_pred, average='weighted'),
        "recall": recall_score(y_true, y_pred, average='weighted'),
        # "auc_roc": roc_auc_score(y_true, y_prob, multi_class = 'ovr'),
    }
|
| 240 |
+
|
| 241 |
+
# %%
|
| 242 |
+
# %%
epochs = args.num_epochs
optimizer = torch.optim.Adam(model_cell_cls.parameters(), lr=1e-4)
loss = torch.nn.CrossEntropyLoss()

# Fix: move the model to its device once, instead of once per batch as before.
model_cell_cls = model_cell_cls.to(model_cell_cls.device)

for epoch in range(epochs):
    # ---- training ----
    model_cell_cls.train()
    for i, batch in enumerate(train_loader):
        data = batch[0].to(model_cell_cls.device)
        target = batch[1].to(model_cell_cls.device)

        optimizer.zero_grad()
        output = model_cell_cls(data, None)
        loss_val = loss(output, target)
        loss_val.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Epoch {epoch}, Iteration {i}, Loss: {loss_val}")

    # ---- evaluation on the held-out split ----
    model_cell_cls.eval()
    with torch.no_grad():
        pred_prob = []
        pred_label = []
        targets = []
        cell_repr = []

        for i, batch in enumerate(test_loader):
            data = batch[0].to(model_cell_cls.device)
            target = batch[1].to(model_cell_cls.device)

            # return_cls=True also yields the pooled [CLS] representation.
            output, output_test_repr = model_cell_cls(data, None, return_cls=True)
            cell_repr.append(output_test_repr.cpu().numpy())

            pred_prob.append(torch.nn.functional.softmax(output, dim=1).cpu().numpy())
            _, predicted = torch.max(output, 1)
            pred_label.append(predicted.cpu().numpy())
            targets.append(target.cpu().numpy())

        pred_prob = np.concatenate(pred_prob)
        pred_label = np.concatenate(pred_label)
        targets = np.concatenate(targets)
        cell_repr = np.concatenate(cell_repr)

        # Persist per-epoch predictions for later analysis.
        np.save(os.path.join(SAVE_PATH, f"predictions/pred_prob_{dataset_name}_{epoch}.npy"), pred_prob)
        np.save(os.path.join(SAVE_PATH, f"predictions/pred_label_{dataset_name}_{epoch}.npy"), pred_label)
        np.save(os.path.join(SAVE_PATH, f"predictions/targets_{dataset_name}_{epoch}.npy"), targets)

        metrics = compute_metrics(pred_label, pred_prob, targets)

        with open(os.path.join(SAVE_PATH, f"metrics/metrics_{dataset_name}_{epoch}.txt"), "w") as f:
            print(metrics, file=f)
        print(metrics)

        # 2-D PCA scatter of the [CLS] representations, colored by true label.
        from sklearn.decomposition import PCA

        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(cell_repr)

        plt.figure(figsize=(8, 8))
        plt.scatter(pca_result[:, 0], pca_result[:, 1], c=targets)
        plt.savefig(os.path.join(SAVE_PATH, f"figures/scatter_{dataset_name}_{epoch}.png"))

model_cell_cls.eval()
|
| 321 |
+
|
| 322 |
+
def cell_embeddings(data_loader, model_cell_cls):
    """Collect pooled [CLS] representations for every batch in data_loader.

    Args:
        data_loader: yields (input_ids, labels) pairs; labels are ignored here.
        model_cell_cls: classifier whose forward(..., return_cls=True) returns
            (logits, cls_representation) — and exposes a .device attribute.

    Returns:
        np.ndarray of shape (num_cells, hidden_dim).
    """
    # Fixes vs. original: model moved to device once (not per batch), the
    # unused label tensor is no longer transferred, and inference runs under
    # torch.no_grad() so no autograd graph is built.
    model_cell_cls = model_cell_cls.to(model_cell_cls.device)
    cell_repr = []

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            data = batch[0].to(model_cell_cls.device)
            _, output_test_repr = model_cell_cls(data, None, return_cls=True)
            cell_repr.append(output_test_repr.cpu().numpy())
            if i % 10 == 0:
                print(f"Processed {i} batches")

    return np.concatenate(cell_repr)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Extract and persist [CLS] embeddings for each split; the empty tag maps the
# full dataset to "<dataset>_cell_repr.npy". Each array is dropped right after
# saving to keep peak memory down.
for tag, loader in (("_test", test_loader), ("_train", train_loader), ("", all_loader)):
    repr_array = cell_embeddings(loader, model_cell_cls)
    np.save(os.path.join(SAVE_PATH, f"repr/{dataset_name}{tag}_cell_repr.npy"), repr_array)
    del repr_array
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# %%
|
| 359 |
+
# original_data = adata.X.toarray()
|
| 360 |
+
# original_data.shape
|
| 361 |
+
|
| 362 |
+
# %%
|
| 363 |
+
# draw the scatter figure on the original data
|
| 364 |
+
# from sklearn.decomposition import PCA
|
| 365 |
+
|
| 366 |
+
# pca = PCA(n_components=2)
|
| 367 |
+
# pca_result = pca.fit_transform(original_data)
|
| 368 |
+
|
| 369 |
+
# plt.figure(figsize=(8, 8))
|
| 370 |
+
|
| 371 |
+
# plt.scatter(pca_result[:, 0], pca_result[:, 1], c = y)
|
| 372 |
+
# plt.show()
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# %%
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
|
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Trainer
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import pyarrow as pa
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from matplotlib import pyplot as plt
|
| 11 |
+
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from transformers import AutoTokenizer, TrainingArguments
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
| 18 |
+
from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
|
| 19 |
+
|
| 20 |
+
from dotmap import DotMap
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import os
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
|
| 27 |
+
|
| 28 |
+
from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
|
| 29 |
+
from utils import permute_genes_by_expression, build_downstream_dataset
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
import importlib
|
| 33 |
+
importlib.reload(sys.modules['models'])
|
| 34 |
+
importlib.reload(sys.modules['utils'])
|
| 35 |
+
|
| 36 |
+
# %%
|
| 37 |
+
import scanpy as sc

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str)
args2 = parser.parse_args()

# Load the preprocessed .h5ad file for one of the supported benchmark datasets.
dataset_name = args2.dataset_name
assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')
assert "celltype" in adata.obs
print(adata)

# %%
from sklearn.preprocessing import LabelEncoder

# Integer-encode the cell-type labels.
y_names = np.array(adata.obs['celltype'].values.tolist())
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_names)
num_class = len(label_encoder.classes_)
|
| 67 |
+
|
| 68 |
+
# %%
|
| 69 |
+
from transformers import PretrainedConfig

# Backbone hyper-parameters for the pretrained 24-layer / 512-dim checkpoint.
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})

# %%
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

# %%
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)

# %%
num_samples = permuted_gene_ids.shape[0]
num_avaliable_gpu = torch.cuda.device_count()
|
| 87 |
+
|
| 88 |
+
# %%
|
| 89 |
+
from dotmap import DotMap

# Fine-tuning hyper-parameters; DotMap gives attribute-style access.
args = DotMap({
    "learning_rate": 5e-5,
    "batch_size": 16,
    "gradient_accumulation_steps": 1,
    "optim": "adamw_torch",
    "seq_len": 2048,
    "num_samples": num_samples,
    "num_gpus": num_avaliable_gpu,
    "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/debug",
})

# %%
# Truncate each cell to seq_len tokens.
input_data = permuted_gene_ids[:, :args.seq_len]

#%%
# Ensure the tokenizer has a [CLS] token, then prepend its id to every cell.
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

cls_column = np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1)
input_data = np.hstack([cls_column, input_data])
input_data.shape
|
| 124 |
+
|
| 125 |
+
# %%
|
| 126 |
+
sample_dataset = build_downstream_dataset(input_data, model.tokenizer)
sample_dataset

# %%
# Fix: the original assigned TrainingArguments back onto `args`, clobbering the
# DotMap config it was still reading from. Use a distinct name instead.
training_args = TrainingArguments(
    learning_rate=args.learning_rate,
    num_train_epochs=4,
    per_device_train_batch_size=args.batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    optim=args.optim,
    output_dir=os.path.join(args.output_dir, dataset_name),
    # Log/checkpoint roughly 10 times per epoch.
    logging_steps=args.num_samples // args.batch_size // 10,
    save_steps=args.num_samples // args.batch_size // 10,
)

# %%
model.finetune(sample_dataset, training_args)
|
| 149 |
+
|
| 150 |
+
# %%
|
| 151 |
+
# ckpt_pth = get_last_checkpoint(os.path.join(args.output_dir, dataset_name))
|
| 152 |
+
# ckpt_pth
|
| 153 |
+
|
| 154 |
+
# #%%
|
| 155 |
+
# model = GeneMamba2(config, model_path=, tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)
|
| 156 |
+
|
| 157 |
+
#%%
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Trainer
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import pyarrow as pa
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from matplotlib import pyplot as plt
|
| 11 |
+
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from transformers import AutoTokenizer, TrainingArguments
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
| 18 |
+
from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
|
| 19 |
+
|
| 20 |
+
from dotmap import DotMap
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import os
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
|
| 27 |
+
|
| 28 |
+
from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
|
| 29 |
+
from utils import permute_genes_by_expression, build_downstream_dataset, get_last_checkpoint
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
import importlib
|
| 33 |
+
importlib.reload(sys.modules['models'])
|
| 34 |
+
importlib.reload(sys.modules['utils'])
|
| 35 |
+
|
| 36 |
+
# %%
|
| 37 |
+
import scanpy as sc

# Dataset selection is hard-coded for this zero-shot run.
dataset_name = "pbmc12k"
assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')
assert "celltype" in adata.obs
print(adata)

# %%
from transformers import PretrainedConfig

# Backbone hyper-parameters for the pretrained 24-layer / 512-dim checkpoint.
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})

# %%
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)
|
| 69 |
+
|
| 70 |
+
# %%
|
| 71 |
+
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)
permuted_gene_ids

# %%
num_samples = permuted_gene_ids.shape[0]
num_avaliable_gpu = torch.cuda.device_count()

# %%
from dotmap import DotMap

args = DotMap({
    "learning_rate": 5e-5,
    "batch_size": 16,
    "gradient_accumulation_steps": 1,
    "optim": "adamw_torch",
    "seq_len": 2048,
    "num_samples": num_samples,
    "num_gpus": num_avaliable_gpu,
    "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned",
})

#%%
# Re-instantiate the backbone and grow its embedding table so the fine-tuned
# (CLS-augmented) state dict loaded below fits.
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)
model.resize_token_embeddings()
|
| 103 |
+
|
| 104 |
+
#%%
|
| 105 |
+
def get_last_checkpoint(output_dir):
    """Return the path of the highest-numbered "checkpoint-N" entry.

    Args:
        output_dir: directory containing HuggingFace-style checkpoint folders.

    Returns:
        Path (joined onto output_dir) of the latest checkpoint.

    Raises:
        FileNotFoundError: if output_dir contains no "checkpoint-N" entries.
    """
    # Fix: the original matched any name containing "checkpoint" and then
    # crashed in int() on names like "checkpoint-best", and raised a bare
    # IndexError on an empty directory.
    steps = [
        int(name.split("-")[1])
        for name in os.listdir(output_dir)
        if name.startswith("checkpoint-") and name.split("-")[1].isdigit()
    ]
    if not steps:
        raise FileNotFoundError(f"No checkpoint-N directories found in {output_dir}")
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")
|
| 113 |
+
|
| 114 |
+
ckpt_pth = f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/{dataset_name}"
|
| 115 |
+
|
| 116 |
+
last_checkpoint = get_last_checkpoint(ckpt_pth)
|
| 117 |
+
state_dict_pth = os.path.join(last_checkpoint, "model.safetensors")
|
| 118 |
+
|
| 119 |
+
print(state_dict_pth)
|
| 120 |
+
|
| 121 |
+
#%%
|
| 122 |
+
from safetensors.torch import load_file
|
| 123 |
+
|
| 124 |
+
state_dict = load_file(state_dict_pth)
|
| 125 |
+
|
| 126 |
+
model.model.load_state_dict(state_dict)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# %%
|
| 130 |
+
# Truncate each cell to seq_len tokens.
input_data = permuted_gene_ids[:, :args.seq_len]
input_data.shape

#%%
# Ensure a [CLS] token exists, then prepend its id to every cell.
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

cls_column = np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1)
input_data = np.hstack([cls_column, input_data])
input_data.shape
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
#%%
|
| 148 |
+
from torch.utils.data import DataLoader, Dataset
|
| 149 |
+
|
| 150 |
+
class GeneDataset(Dataset):
    """Bare sequence dataset: each item is one tokenized cell, no label."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
#%%
|
| 162 |
+
# No shuffling: embeddings must stay aligned with adata's cell order.
all_dataset = GeneDataset(input_data)
all_loader = DataLoader(all_dataset, batch_size=args.batch_size, shuffle=False)
|
| 164 |
+
|
| 165 |
+
# %%
|
| 166 |
+
def cell_embeddings(data_loader, model):
    """Extract the [CLS] (position-0) hidden state for every cell.

    Args:
        data_loader: yields batches of token-id tensors of shape (B, S).
        model: backbone whose output exposes .hidden_states of shape (B, S, H)
            and which has a .device attribute.

    Returns:
        np.ndarray of shape (num_cells, H).
    """
    cell_repr = []

    # Fix: run inference under torch.no_grad() so no autograd graph is built.
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            batch = batch.to(model.device)
            outputs = model(batch)

            # Position 0 holds the prepended [CLS] token.
            cls_representation = outputs.hidden_states[:, 0, :]
            cell_repr.append(cls_representation.cpu().numpy())

            if i % 10 == 0:
                print(f"Processed {i} batches")

    return np.concatenate(cell_repr)
|
| 183 |
+
|
| 184 |
+
# %%
|
| 185 |
+
model = model.to("cuda")
|
| 186 |
+
model.eval()
|
| 187 |
+
|
| 188 |
+
# %%
|
| 189 |
+
cell_repr = cell_embeddings(all_loader, model)
|
| 190 |
+
cell_repr.shape
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# cell_repr = np.concatenate(cell_repr)
|
| 194 |
+
# %%
|
| 195 |
+
np.save(f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/embeddings/fine-tuned/{dataset_name}_cell_repr.npy", cell_repr)
|
| 196 |
+
|
| 197 |
+
# %%
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
|
| 3 |
+
size 262998656
|
modeling_genemamba.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
|
| 3 |
+
Includes backbone model and task-specific heads for various downstream tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.init import normal_, constant_
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 16 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
|
| 17 |
+
from transformers.models.auto import register_model_for_auto_class
|
| 18 |
+
|
| 19 |
+
from mamba_ssm import Mamba
|
| 20 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm
|
| 21 |
+
|
| 22 |
+
from .configuration_genemamba import GeneMambaConfig
|
| 23 |
+
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================
|
| 29 |
+
# Core Architecture Components
|
| 30 |
+
# ===========================
|
| 31 |
+
|
| 32 |
+
class EncoderLayer(nn.Module):
    """
    One residual Mamba block: applies a single Mamba SSM layer and adds its
    output back onto the input.

    Args:
        hidden_size (int): Dimension of hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyper-parameters (d_state=64, d_conv=4, expand=2) match
        # the pretrained checkpoints.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Mamba(X) + X, same shape as X.
        """
        return self.mamba(X) + X
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MambaMixer(nn.Module):
    """
    Stack of bidirectional Mamba encoder layers.

    Each layer processes the sequence left-to-right, and also right-to-left by
    flipping the sequence, running the same layer, and flipping back. The two
    directions are then merged according to ``mode``.

    Args:
        mode (str): Direction-aggregation mode: "mean", "sum", "concat", or "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of stacked Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super().__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Stack of residual Mamba blocks.
        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size) for _ in range(num_hidden_layers)
        )

        # "concat" and "gate" need a projection from 2*hidden back to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence over its true (unpadded) length.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape
                (batch_size, seq_len) where True marks padding positions.

        Returns:
            torch.Tensor: Tensor with the first ``length`` positions of each
            sequence reversed; padding positions are left in place.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding information: plain whole-sequence flip.
            return X.flip([1])

        # Number of real (non-padding) tokens per sequence.
        lengths = (~mask).sum(dim=1)
        positions = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        within_length = positions < lengths.unsqueeze(1)
        # Map position p -> length-1-p inside the valid prefix, identity elsewhere.
        source = torch.where(
            within_length,
            lengths.unsqueeze(1) - 1 - positions,
            positions
        )

        return torch.gather(X, 1, source.unsqueeze(-1).expand(-1, -1, embedding_dim))

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process a sequence through the bidirectional Mamba stack.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (True = padding).

        Returns:
            torch.Tensor: Output after all layers and direction aggregation.
        """
        for layer in self.layers:
            # Reverse pass: flip, run the *same* layer (shared weights), flip back.
            backward_in = self.flip_sequence(X, padding_mask)
            forward_out = layer(X)
            backward_out = self.flip_sequence(layer(backward_in), padding_mask)

            if self.mode == "mean":
                X = (forward_out + backward_out) / 2
            elif self.mode == "sum":
                X = forward_out + backward_out
            elif self.mode == "concat":
                X = self.aggr(torch.cat([forward_out, backward_out], dim=-1))
            elif self.mode == "gate":
                # Learned per-feature convex combination of the two directions.
                gate = torch.sigmoid(self.aggr(torch.cat([forward_out, backward_out], dim=-1)))
                X = gate * forward_out + (1 - gate) * backward_out
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ===========================
|
| 162 |
+
# Base Model Classes
|
| 163 |
+
# ===========================
|
| 164 |
+
|
| 165 |
+
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights.

        Linear and Embedding weights are drawn from N(0, initializer_range);
        biases are zeroed and LayerNorm weights are set to 1. The embedding
        row for the padding index is zeroed so padding contributes nothing.
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Fix: guard against LayerNorm without affine parameters
            # (elementwise_affine=False or bias=False), where .bias/.weight
            # are None and the original unconditional init crashed.
            if module.bias is not None:
                constant_(module.bias, 0.0)
            if module.weight is not None:
                constant_(module.weight, 1.0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding table; the padding row is kept at zero.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba stack with configurable direction aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape
                (batch_size, seq_len); HF convention, 1/True = real token.
            output_hidden_states (bool): Whether to output hidden states from all layers.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Fix: MambaMixer.flip_sequence expects a *padding* mask (True = pad,
        # since it computes lengths as (~mask).sum()), while attention_mask
        # follows the HF convention used by the pooling branch below
        # (1/True = real token). The original code passed attention_mask
        # through unchanged, which produced wrong sequence lengths (and, for
        # int 0/1 masks, negative values from bitwise ~). Convert here.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, padding_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over valid (unmasked) positions only.
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            # NOTE(review): this stores the final tensor, not a per-layer tuple.
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ===========================
|
| 273 |
+
# Task-Specific Models
|
| 274 |
+
# ===========================
|
| 275 |
+
|
| 276 |
+
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model with a masked language modeling (MLM) head.
    Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Projects hidden states onto the vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for the MLM loss
                (positions set to -100 are ignored by CrossEntropyLoss).
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        logits = self.lm_head(backbone_out.last_hidden_state)

        loss = None
        if labels is not None:
            # Flatten (batch, seq) into one axis for token-level cross entropy.
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model with a sequence classification head.
    Ideal for cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head: dropout + linear over the pooled cell embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for the classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        cell_embedding = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(cell_embedding))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Also register the MLM model under AutoModelForMaskedLM (the class decorator
# above already registered it as AutoModel).
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
|
modeling_outputs.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom ModelOutput classes for GeneMamba.
|
| 3 |
+
Defines the output structure for different GeneMamba tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import torch
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            NOTE(review): GeneMambaModel currently populates this with the final
            hidden-state tensor rather than a per-layer tuple — confirm before
            relying on the documented tuple shape.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    # Field order matters: ModelOutput's tuple indexing follows declaration order.
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: torch.FloatTensor = None
    # Pooling strategy used by the backbone ("CLS" or "mean").
    embedding_pooling: str = "mean"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    # Field order matters: ModelOutput's tuple indexing follows declaration order.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Pooled cell embedding fed to the classifier (useful for clustering, UMAP, etc.).
    pooled_embedding: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    # Field order matters: ModelOutput's tuple indexing follows declaration order.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "[PAD]",
|
| 3 |
+
"unk_token": "[UNK]"
|
| 4 |
+
}
|