| from transformers import PretrainedConfig |
| from transformers.utils import logging |
| from transformers.models.esm import EsmConfig |
| from transformers.models.bert import BertConfig |
|
|
# Module-level logger obtained through the `transformers` logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
|
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`]. It bundles the configuration of
    the protein encoder (stored as an [`EsmConfig`]) and of the text encoder (stored as a [`BertConfig`]).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`PretrainedConfig`], *optional*):
            Configuration options used to initialize [`EsmForProteinRepresentation`]. A dictionary is passed as
            keyword arguments to [`EsmConfig`]; a configuration object is first converted with its `to_dict()`
            method. If `None`, a default [`EsmConfig`] is created.
        text_config (`dict` or [`PretrainedConfig`], *optional*):
            Configuration options used to initialize [`BertForPubMed`]. A dictionary is passed as keyword arguments
            to [`BertConfig`]; a configuration object is first converted with its `to_dict()` method. If `None`, a
            default [`BertConfig`] is created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig.__init__`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, PretrainedConfig):
            # Also accept an already-built config object; convert it to a plain dict the same way
            # `from_protein_text_configs` does, so `EsmConfig(**...)` below does not fail on a non-mapping.
            protein_config = protein_config.to_dict()

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `BertConfig` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # Same dict-or-object handling as for `protein_config` above.
            text_config = text_config.to_dict()

        # Sub-configs are always stored as freshly built config objects, never as the caller's dict/object.
        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration and a text model
        configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder; serialized with `to_dict()` before being passed on.
            text_config ([`BertConfig`]):
                Configuration of the text encoder; serialized with `to_dict()` before being passed on.
            kwargs (*optional*):
                Additional keyword arguments forwarded to the class constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)