| import torch |
|
|
class DecoderWithLMhead(torch.nn.Module):
    """Combines the decoder and the language-model head into a single module."""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, *inputs):
        input_ids, attention_mask, encoder_hidden_states = inputs[:3]

        # The remaining inputs are the flattened past key/values: four tensors
        # per decoder layer (self-attention key/value, then cross-attention
        # key/value), regrouped here into the per-layer 4-tuples the decoder expects.
        list_pkv = inputs[3:]
        past_key_values = tuple(list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4))

        decoder_output = self.decoder(
            input_ids=input_ids,
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )

        # Rescale by d_model ** -0.5 before the LM head, mirroring what
        # T5ForConditionalGeneration does when word embeddings are tied.
        lm_head_out = self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5))

        # decoder_output[1] holds the updated past key/values.
        return lm_head_out, decoder_output[1]
|
|
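# The flat `*inputs` convention in DecoderWithLMhead.forward lets every past
# key/value tensor appear as a separate graph input when the module is traced,
# e.g. for ONNX export. A minimal helper sketch for producing that flat list,
# assuming the legacy `transformers` layout where past_key_values is a tuple
# of 4-tuples per decoder layer; the helper name is illustrative:
def flatten_past_key_values(past_key_values):
    """Flatten per-layer 4-tuples of past key/values into one flat list."""
    return [tensor for layer in past_key_values for tensor in layer]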
|
|
class T5Encoder(torch.nn.Module):
    """Wraps the encoder so it outputs only the last hidden state."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, *inputs, **kwargs):
        # Index [0] selects last_hidden_state from the encoder output.
        return self.encoder(*inputs, **kwargs)[0]
|
|
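# T5Encoder returns a bare tensor rather than a ModelOutput, which keeps a
# traced graph simple. A minimal export sketch; the model name, output file,
# and opset below are illustrative assumptions, not taken from this module:
def export_encoder_example():
    from transformers import T5ForConditionalGeneration  # assumed dependency

    model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
    encoder = T5Encoder(model.encoder)

    input_ids = torch.ones(1, 8, dtype=torch.long)       # dummy token ids
    attention_mask = torch.ones(1, 8, dtype=torch.long)  # no padding

    torch.onnx.export(
        encoder,
        (input_ids, attention_mask),
        "t5-encoder.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["hidden_states"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "hidden_states": {0: "batch", 1: "sequence"},
        },
        opset_version=12,
    )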
|
|
class DecoderWithLMheadInitial(torch.nn.Module):
    """Combines the decoder and the LM head for the first decoding step,
    before any past key/values exist."""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, input_ids, attention_mask, encoder_hidden_states):
        decoder_output = self.decoder(
            input_ids=input_ids,
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Same rescaling as DecoderWithLMhead; decoder_output[1] holds the
        # freshly created past key/values.
        return (
            self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
            decoder_output[1],
        )
|
|
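# End-to-end sketch tying the three wrappers together for two greedy decoding
# steps. Assumes the Hugging Face `transformers` T5 implementation with the
# legacy tuple layout for past key/values; the model name and dummy token ids
# are illustrative:
def generation_step_example():
    from transformers import T5ForConditionalGeneration  # assumed dependency

    model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
    encoder = T5Encoder(model.encoder)
    first_step = DecoderWithLMheadInitial(model.decoder, model.lm_head, model.config)
    later_steps = DecoderWithLMhead(model.decoder, model.lm_head, model.config)

    input_ids = torch.ones(1, 8, dtype=torch.long)  # dummy encoder input
    attention_mask = torch.ones(1, 8, dtype=torch.long)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    with torch.no_grad():
        hidden = encoder(input_ids, attention_mask)

        # First step: no cached key/values exist yet.
        logits, past = first_step(decoder_input_ids, attention_mask, hidden)
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)

        # Later steps: feed only the new token plus the flattened cache.
        flat_past = flatten_past_key_values(past)
        logits, past = later_steps(next_token, attention_mask, hidden, *flat_past)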