| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PreTrainedEncoder, PreTrainedDecoder |
| from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput |
| from transformers.utils import logging |
|
|
# Module-level logger, created through the `logging` helper imported from
# `transformers.utils` and keyed on this module's dotted name.
logger = logging.get_logger(__name__)
|
|
class CSUMLMEncoder(PreTrainedEncoder):
    """Encoder component of the CSUMLM model (skeleton).

    The constructor only delegates to the base class and ``forward`` has no
    implementation yet, so this class cannot be used for inference or
    training in its current state.

    NOTE(review): confirm that ``PreTrainedEncoder`` is actually exported by
    the installed ``transformers`` package; if it is not, the import at the
    top of this file fails before this class is ever defined.
    """

    def __init__(self, config):
        """Create the encoder.

        Args:
            config: Model configuration object, forwarded unchanged to the
                base-class constructor.
        """
        super().__init__(config)
        # TODO: create the encoder sub-modules here; none are defined yet.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder (not implemented).

        Every parameter is optional, defaults to ``None`` and is currently
        unused; the signature is kept so call sites can already be written
        against it.

        Raises:
            NotImplementedError: Always. No computation has been written, so
                there is no ``encoder_outputs`` value to return. Failing
                explicitly replaces the ``NameError`` that returning that
                unbound name would otherwise produce.
        """
        # TODO: compute `encoder_outputs` from the inputs and return it.
        raise NotImplementedError(
            "CSUMLMEncoder.forward is not implemented: "
            "`encoder_outputs` is never computed."
        )
|
|
class CSUMLMDecoder(PreTrainedDecoder):
    """Decoder component of the CSUMLM model (skeleton).

    The constructor only delegates to the base class and ``forward`` has no
    implementation yet, so this class cannot be used for inference or
    training in its current state.

    NOTE(review): confirm that ``PreTrainedDecoder`` is actually exported by
    the installed ``transformers`` package; if it is not, the import at the
    top of this file fails before this class is ever defined.
    """

    def __init__(self, config):
        """Create the decoder.

        Args:
            config: Model configuration object, forwarded unchanged to the
                base-class constructor.
        """
        super().__init__(config)
        # TODO: create the decoder sub-modules here; none are defined yet.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the decoder (not implemented).

        Every parameter is optional, defaults to ``None`` and is currently
        unused; the signature is kept so call sites can already be written
        against it.

        Raises:
            NotImplementedError: Always. No computation has been written, so
                there is no ``decoder_outputs`` value to return. Failing
                explicitly replaces the ``NameError`` that returning that
                unbound name would otherwise produce.
        """
        # TODO: compute `decoder_outputs` from the inputs and return it.
        raise NotImplementedError(
            "CSUMLMDecoder.forward is not implemented: "
            "`decoder_outputs` is never computed."
        )
|
|
class CSUMLMModel(PreTrainedModel):
    """Top-level CSUMLM model (skeleton).

    Owns one ``CSUMLMEncoder``, one ``CSUMLMDecoder`` and one
    ``MultimodalFusion`` module, all built from the same ``config``.
    ``forward`` has no implementation yet.

    NOTE(review): ``MultimodalFusion`` is neither imported nor defined in the
    visible part of this file -- confirm where it comes from, otherwise
    constructing this model raises ``NameError``.
    """

    def __init__(self, config):
        """Build the model's sub-modules.

        Args:
            config: Model configuration object. It is forwarded to the
                base-class constructor and to each sub-module constructor.
        """
        super().__init__(config)
        self.encoder = CSUMLMEncoder(config)
        self.decoder = CSUMLMDecoder(config)
        self.multimodal_fusion = MultimodalFusion(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the full model (not implemented).

        Every parameter is optional, defaults to ``None`` and is currently
        unused; the signature is kept so call sites can already be written
        against it.

        Raises:
            NotImplementedError: Always. The encoder, decoder and fusion
                module are never invoked, so there is no ``output`` value to
                return. Failing explicitly replaces the ``NameError`` that
                returning that unbound name would otherwise produce.
        """
        # TODO: wire encoder -> multimodal fusion -> decoder and return the
        # result. NOTE(review): the ordering of those stages is a guess;
        # confirm against the model design before implementing.
        raise NotImplementedError(
            "CSUMLMModel.forward is not implemented: `output` is never computed."
        )
|
|
| |
# Import-time side effect: invoke `register_for_auto_class` on the model class
# with its default arguments.
CSUMLMModel.register_for_auto_class()