# Helix-mRNA-Wrapper / configuration_helix_mrna.py
# Uploaded via huggingface_hub by Taykhoom (commit a8c6b7e, verified)
from typing import Literal
from transformers import PretrainedConfig
class HelixmRNAConfig(PretrainedConfig):
    """Configuration for the Helix-mRNA model.

    Parameters
    ----------
    max_length : int, optional, default=12288
        The maximum length of the input sequence.
    **kwargs
        Additional keyword arguments forwarded to
        ``transformers.PretrainedConfig`` (e.g. ``num_hidden_layers``,
        ``layers_block_type_string``, which ``layers_block_type`` reads).
    """

    # Helical's underlying backbone is Mamba2, so transformers must treat
    # this config as a "mamba2" model.
    model_type = "mamba2"
    model_name: Literal["helical-ai/Helix-mRNA"] = "helical-ai/Helix-mRNA"

    # Single-character layer codes used in ``layers_block_type_string``
    # mapped to the block names transformers expects.
    _LAYER_CODES = {"M": "mamba", "*": "attention", "+": "mlp"}

    def __init__(
        self,
        max_length: int = 12288,
        **kwargs,
    ):
        # Lightweight summary dict kept alongside the full transformers
        # config; stored before super().__init__ so it survives kwargs
        # processing unchanged.
        self.config = {
            "model_name": self.model_name,
            "max_length": max_length,
        }
        super().__init__(**kwargs)

    @property
    def layers_block_type(self) -> list:
        """Decode ``layers_block_type_string`` into per-layer block types.

        Returns
        -------
        list of str
            One of ``"mamba"``, ``"attention"`` or ``"mlp"`` per hidden
            layer, in order.

        Raises
        ------
        ValueError
            If the string length disagrees with ``num_hidden_layers``, or
            if an unknown layer code is encountered (previously unknown
            codes were silently skipped, yielding a too-short list).
        """
        if self.num_hidden_layers != len(self.layers_block_type_string):
            raise ValueError(
                f"num_hidden_layers should be equal to the number of layers in layers_block_type_string, but got {self.num_hidden_layers} and {len(self.layers_block_type_string)}"
            )
        layers = []
        for code in self.layers_block_type_string:
            try:
                layers.append(self._LAYER_CODES[code])
            except KeyError:
                raise ValueError(
                    f"Unknown layer code {code!r} in layers_block_type_string; "
                    f"expected one of {sorted(self._LAYER_CODES)}"
                ) from None
        return layers