""" Configuration for GeneMamba model. Defines all hyperparameters and settings for the GeneMamba architecture. """ from transformers import PretrainedConfig from typing import Optional class GeneMambaConfig(PretrainedConfig): """ Configuration class for GeneMamba model. This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig. It can be used to instantiate models from pretrained checkpoints or customize model initialization. Args: vocab_size (int, optional, defaults to 25426): Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs). hidden_size (int, optional, defaults to 512): Dimensionality of the hidden/embedding layers (d_model in Mamba). num_hidden_layers (int, optional, defaults to 24): Number of Mamba layers (mamba_layer). intermediate_size (int, optional, defaults to 2048): Dimensionality of intermediate representations in MLP. max_position_embeddings (int, optional, defaults to 2048): Maximum sequence length (seq_len). hidden_dropout_prob (float, optional, defaults to 0.1): Dropout probability for hidden states. initializer_range (float, optional, defaults to 0.02): Standard deviation of truncated normal initializer. mamba_mode (str, optional, defaults to "gate"): Aggregation mode for bidirectional Mamba layers. Options: "mean", "sum", "concat", "gate". embedding_pooling (str, optional, defaults to "mean"): Method for pooling to get cell embedding. Options: "CLS", "mean", "weighted". num_labels (int, optional, defaults to 2): Number of labels for sequence classification tasks. pad_token_id (int, optional, defaults to 1): Token ID for padding. bos_token_id (int, optional, defaults to None): Token ID for beginning of sequence. eos_token_id (int, optional, defaults to None): Token ID for end of sequence. """ model_type = "genemamba" attribute_map = { "hidden_size": "hidden_size", "num_hidden_layers": "num_hidden_layers", } def __init__( self, vocab_size: int = 25426, hidden_size: int = 512, num_hidden_layers: int = 24, intermediate_size: int = 2048, max_position_embeddings: int = 2048, hidden_dropout_prob: float = 0.1, initializer_range: float = 0.02, mamba_mode: str = "gate", embedding_pooling: str = "mean", num_labels: int = 2, pad_token_id: int = 1, bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.intermediate_size = intermediate_size self.max_position_embeddings = max_position_embeddings self.hidden_dropout_prob = hidden_dropout_prob self.initializer_range = initializer_range self.mamba_mode = mamba_mode self.embedding_pooling = embedding_pooling self.num_labels = num_labels self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id