from typing import Literal

from transformers import PretrainedConfig


class HelixmRNAConfig(PretrainedConfig):
    """Configuration for the Helix-mRNA model.

    Parameters
    ----------
    max_length : int, optional, default=12288
        The maximum length of the input sequence.
    **kwargs
        Additional keyword arguments forwarded to ``PretrainedConfig``
        (e.g. ``num_hidden_layers`` and ``layers_block_type_string``,
        which ``layers_block_type`` reads).
    """

    # Helical's underlying architecture is Mamba2, so the HF model type
    # needs to be set accordingly.
    model_type = "mamba2"
    model_name: Literal["helical-ai/Helix-mRNA"] = "helical-ai/Helix-mRNA"

    # Single-character layer codes used in ``layers_block_type_string``
    # mapped to their block names.
    _LAYER_CODES = {"M": "mamba", "*": "attention", "+": "mlp"}

    def __init__(
        self,
        max_length: int = 12288,
        **kwargs,
    ):
        # Lightweight dict consumed by helical's model wrapper; kept in
        # addition to the attributes PretrainedConfig sets from kwargs.
        self.config = {
            "model_name": self.model_name,
            "max_length": max_length,
        }
        super().__init__(**kwargs)

    @property
    def layers_block_type(self):
        """Decode ``layers_block_type_string`` into per-layer block names.

        Returns
        -------
        list of str
            One of ``"mamba"``, ``"attention"``, ``"mlp"`` per layer.

        Raises
        ------
        ValueError
            If ``num_hidden_layers`` disagrees with the encoded string's
            length, or if the string contains an unknown layer code
            (previously unknown codes were silently skipped, yielding a
            list shorter than ``num_hidden_layers``).
        """
        encoded = self.layers_block_type_string
        if self.num_hidden_layers != len(encoded):
            raise ValueError(
                f"num_hidden_layers should be equal to the number of layers in layers_block_type_string, but got {self.num_hidden_layers} and {len(encoded)}"
            )
        try:
            return [self._LAYER_CODES[code] for code in encoded]
        except KeyError as err:
            raise ValueError(
                f"Unknown layer code {err.args[0]!r} in layers_block_type_string; "
                f"expected one of {sorted(self._LAYER_CODES)}"
            ) from err