| """ |
| Configuration for GeneMamba model. |
| Defines all hyperparameters and settings for the GeneMamba architecture. |
| """ |
|
|
| from transformers import PretrainedConfig |
| from typing import Optional |


class GeneMambaConfig(PretrainedConfig):
| """ |
| Configuration class for GeneMamba model. |
| |
| This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig. |
| It can be used to instantiate models from pretrained checkpoints or customize model initialization. |
| |
| Args: |
| vocab_size (int, optional, defaults to 25426): |
| Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs). |
| |
| hidden_size (int, optional, defaults to 512): |
| Dimensionality of the hidden/embedding layers (d_model in Mamba). |
| |
| num_hidden_layers (int, optional, defaults to 24): |
| Number of Mamba layers (mamba_layer). |
| |
| intermediate_size (int, optional, defaults to 2048): |
| Dimensionality of intermediate representations in MLP. |
| |
| max_position_embeddings (int, optional, defaults to 2048): |
| Maximum sequence length (seq_len). |
| |
| hidden_dropout_prob (float, optional, defaults to 0.1): |
| Dropout probability for hidden states. |
| |
| initializer_range (float, optional, defaults to 0.02): |
| Standard deviation of truncated normal initializer. |
| |
| mamba_mode (str, optional, defaults to "gate"): |
| Aggregation mode for bidirectional Mamba layers. |
| Options: "mean", "sum", "concat", "gate". |
| |
| embedding_pooling (str, optional, defaults to "mean"): |
| Method for pooling to get cell embedding. |
| Options: "CLS", "mean", "weighted". |
| |
| num_labels (int, optional, defaults to 2): |
| Number of labels for sequence classification tasks. |
| |
| pad_token_id (int, optional, defaults to 1): |
| Token ID for padding. |
| |
| bos_token_id (int, optional, defaults to None): |
| Token ID for beginning of sequence. |
| |
| eos_token_id (int, optional, defaults to None): |
| Token ID for end of sequence. |
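
    Example:
        >>> # Defaults match the values documented above.
        >>> config = GeneMambaConfig()
        >>> config.hidden_size
        512
        >>> # Hyperparameters can be overridden at construction time
        >>> # (the values below are illustrative, not recommended settings).
        >>> small = GeneMambaConfig(hidden_size=256, num_hidden_layers=12, mamba_mode="concat")
        >>> small.num_hidden_layers
        12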
| """ |

    model_type = "genemamba"

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Special token IDs and num_labels are handled by PretrainedConfig,
        # so route them through the parent initializer.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            num_labels=num_labels,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
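
        # Lightweight sanity checks (an added safeguard): the accepted values
        # mirror the options documented in the class docstring.
        if mamba_mode not in ("mean", "sum", "concat", "gate"):
            raise ValueError(
                f"mamba_mode must be one of 'mean', 'sum', 'concat', 'gate'; "
                f"got {mamba_mode!r}"
            )
        if embedding_pooling not in ("CLS", "mean", "weighted"):
            raise ValueError(
                f"embedding_pooling must be one of 'CLS', 'mean', 'weighted'; "
                f"got {embedding_pooling!r}"
            )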
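

# Minimal usage sketch (illustrative): exercises the standard PretrainedConfig
# save/load round trip. The "./genemamba_config" directory is a hypothetical
# local path, and the hyperparameter overrides are arbitrary example values.
if __name__ == "__main__":
    # Defaults, then a customized variant.
    config = GeneMambaConfig()
    small = GeneMambaConfig(
        hidden_size=256,
        num_hidden_layers=12,
        mamba_mode="concat",
        embedding_pooling="CLS",
    )

    # PretrainedConfig provides JSON serialization: save_pretrained() writes
    # config.json into the directory, from_pretrained() reloads it.
    small.save_pretrained("./genemamba_config")
    reloaded = GeneMambaConfig.from_pretrained("./genemamba_config")
    assert reloaded.hidden_size == 256
    assert reloaded.mamba_mode == "concat"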