| ''' |
| Author: Qiguang Chen |
| Date: 2023-01-11 10:39:26 |
| LastEditors: Qiguang Chen |
| LastEditTime: 2023-02-18 19:33:34 |
| Description: |
| |
| ''' |
| from common.utils import InputData |
| from model.encoder.base_encoder import BaseEncoder, BiEncoder |
| from model.encoder.pretrained_encoder import PretrainedEncoder |
| from model.encoder.non_pretrained_encoder import NonPretrainedEncoder |
|
|
| class AutoEncoder(BaseEncoder): |
| |
| def __init__(self, **config): |
| """automatedly load encoder by 'encoder_name' |
| Args: |
| config (dict): |
| encoder_name (str): support ["lstm", "self-attention-lstm", "bi-encoder"] and other pretrained model in hugging face |
| **args (Any): other configuration items corresponding to each module. |
| """ |
| super().__init__() |
| self.config = config |
| if config.get("encoder_name"): |
| encoder_name = config.get("encoder_name").lower() |
| if encoder_name in ["lstm", "self-attention-lstm"]: |
| self.__encoder = NonPretrainedEncoder(**config) |
| elif encoder_name == "bi-encoder": |
| self.__encoder= BiEncoder(self.__init__(**config["intent_encoder"]), self.__init__(**config["intent_encoder"])) |
| else: |
| self.__encoder = PretrainedEncoder(**config) |
| else: |
| raise ValueError("There is no Encoder Name in config.") |
|
|
| def forward(self, inputs: InputData): |
| return self.__encoder(inputs) |