| import torch |
| import tqdm |
| from torch import nn |
| from transformers import MT5EncoderModel, MT5PreTrainedModel |
|
|
class MT5EncoderWithProjection(MT5PreTrainedModel):
    """MT5 encoder producing one fixed-size embedding per input sequence.

    Token states from the MT5 encoder are mean-pooled over non-padding
    positions (masked mean) and then passed through a bias-free linear
    projection of size ``d_model -> d_model``.
    """

    def __init__(self, config):
        """Build the encoder and projection head from an MT5 config.

        Args:
            config: MT5 configuration; ``config.d_model`` sets both the
                input and output width of the projection layer.
        """
        super().__init__(config)
        self.config = config
        self.mt5_encoder = MT5EncoderModel(config)
        # Bias-free projection keeps the embedding space centered;
        # same in/out width, so it is a learned linear re-mapping.
        self.projection = nn.Linear(config.d_model, config.d_model, bias=False)
        # HF convention: initialize weights / apply final processing.
        self.post_init()

    def forward(self, **input_args):
        """Encode a batch and return pooled, projected embeddings.

        Args:
            **input_args: Keyword arguments forwarded to the MT5 encoder
                (e.g. ``input_ids``, ``attention_mask``). If
                ``attention_mask`` is omitted, every token is treated as
                valid.

        Returns:
            Tensor of shape ``(batch, d_model)`` — masked mean of the
            encoder's last hidden states, linearly projected.
        """
        hidden_states = self.mt5_encoder(**input_args).last_hidden_state
        mask = input_args.get('attention_mask')
        if mask is None:
            # No mask supplied: pool over all token positions.
            mask = torch.ones(
                hidden_states.shape[:2],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )
        # Cast so fp16/bf16 hidden states are not promoted by an int64 mask.
        mask = mask.to(hidden_states.dtype)
        summed = torch.sum(hidden_states * mask[:, :, None], dim=1)
        # Clamp the token count to avoid division by zero (NaN/inf) for
        # rows whose attention mask is all zeros.
        counts = torch.sum(mask, dim=1)[:, None].clamp(min=1e-9)
        batch_embeddings = summed / counts
        return self.projection(batch_embeddings)
|
|