# GeneMamba/configuration_genemamba.py
"""
Configuration for the GeneMamba model.

Defines the hyperparameters and settings of the GeneMamba architecture.
"""
from transformers import PretrainedConfig
from typing import Optional
class GeneMambaConfig(PretrainedConfig):
"""
    Configuration class for the GeneMamba model.

    This class stores the configuration of a GeneMamba model and inherits from `PretrainedConfig`.
    It is used to instantiate a GeneMamba model according to the specified arguments, which define
    the model architecture; it can be loaded from a pretrained checkpoint or customized at
    initialization.

Args:
vocab_size (int, optional, defaults to 25426):
Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
hidden_size (int, optional, defaults to 512):
Dimensionality of the hidden/embedding layers (d_model in Mamba).
num_hidden_layers (int, optional, defaults to 24):
Number of Mamba layers (mamba_layer).
intermediate_size (int, optional, defaults to 2048):
Dimensionality of intermediate representations in MLP.
max_position_embeddings (int, optional, defaults to 2048):
Maximum sequence length (seq_len).
hidden_dropout_prob (float, optional, defaults to 0.1):
Dropout probability for hidden states.
initializer_range (float, optional, defaults to 0.02):
Standard deviation of truncated normal initializer.
mamba_mode (str, optional, defaults to "gate"):
Aggregation mode for bidirectional Mamba layers.
Options: "mean", "sum", "concat", "gate".
embedding_pooling (str, optional, defaults to "mean"):
Method for pooling to get cell embedding.
Options: "CLS", "mean", "weighted".
num_labels (int, optional, defaults to 2):
Number of labels for sequence classification tasks.
pad_token_id (int, optional, defaults to 1):
Token ID for padding.
bos_token_id (int, optional, defaults to None):
Token ID for beginning of sequence.
eos_token_id (int, optional, defaults to None):
Token ID for end of sequence.
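
    Example:
        A minimal usage sketch; only `GeneMambaConfig` itself is exercised here, and the smaller
        variant shown last is a hypothetical configuration, not a released checkpoint.

        >>> config = GeneMambaConfig()  # default 24-layer, 512-dim configuration
        >>> config.num_hidden_layers, config.hidden_size
        (24, 512)
        >>> config.mamba_mode, config.embedding_pooling
        ('gate', 'mean')
        >>> small = GeneMambaConfig(num_hidden_layers=12, hidden_size=256)  # hypothetical size variant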
"""
model_type = "genemamba"
    # Map canonical Hugging Face attribute names to this config's attribute names.
    # Both entries are identity mappings here, kept explicit so code that resolves
    # attributes through `attribute_map` still works.
    attribute_map = {
        "hidden_size": "hidden_size",
        "num_hidden_layers": "num_hidden_layers",
    }
def __init__(
self,
vocab_size: int = 25426,
hidden_size: int = 512,
num_hidden_layers: int = 24,
intermediate_size: int = 2048,
max_position_embeddings: int = 2048,
hidden_dropout_prob: float = 0.1,
initializer_range: float = 0.02,
mamba_mode: str = "gate",
embedding_pooling: str = "mean",
num_labels: int = 2,
pad_token_id: int = 1,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
**kwargs
):
        # Register all special token IDs with the base config.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.hidden_dropout_prob = hidden_dropout_prob
self.initializer_range = initializer_range
self.mamba_mode = mamba_mode
self.embedding_pooling = embedding_pooling
self.num_labels = num_labels
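

# Minimal usage sketch: build the default config, print it, and round-trip it
# through the standard PretrainedConfig serialization helpers. The output
# directory name below is hypothetical.
if __name__ == "__main__":
    config = GeneMambaConfig(mamba_mode="gate", embedding_pooling="mean")
    print(config.to_json_string())

    # Save and reload to confirm the config round-trips cleanly.
    config.save_pretrained("./genemamba-config-demo")  # hypothetical output directory
    reloaded = GeneMambaConfig.from_pretrained("./genemamba-config-demo")
    assert reloaded.hidden_size == config.hidden_size
    assert reloaded.mamba_mode == config.mamba_mode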