| |
| |
|
|
| |
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from torch import Tensor |
|
|
| from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids |
|
|
|
|
class XLMRobertaEmbeddings(nn.Module):
    """Sum of word, optional position and optional token-type embeddings.

    The position and token-type tables are only created when their
    respective size argument is positive; otherwise that term is skipped
    in ``forward``.
    """

    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """Build the embedding tables.

        Args:
            embed_dim: Size of every embedding vector.
            vocab_size: Number of rows in the word embedding table.
            max_position_embeddings: Number of rows in the position
                embedding table. If <= 0, there are no position embeddings.
            type_vocab_size: Number of rows in the token-type embedding
                table. If <= 0, there are no token type embeddings.
            padding_idx: Forwarded to the word ``nn.Embedding``; also used
                in ``forward`` to derive default position ids.
            device: Forwarded to every ``nn.Embedding``.
            dtype: Forwarded to every ``nn.Embedding``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(
                type_vocab_size, embed_dim, **factory_kwargs
            )

    def _default_position_ids(self, input_ids):
        """Return position ids to use when the caller supplied none."""
        padding_idx = self.word_embeddings.padding_idx
        if padding_idx is None:
            # Without a padding index the padding-relative numbering cannot
            # be computed, so fall back to plain 0..seqlen-1 positions
            # shared by every row of the batch.
            seqlen = input_ids.shape[1]
            positions = torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
            return positions.unsqueeze(0).expand_as(input_ids)
        return create_position_ids_from_input_ids(
            input_ids, padding_idx=padding_idx
        ).to(input_ids.device)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """Embed ``input_ids`` and add the enabled auxiliary embeddings.

        Args:
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen). When omitted, ids are derived
                from ``input_ids`` and the word embedding's padding index,
                or are 0..seqlen-1 if there is no padding index.
            token_type_ids: (batch, seqlen). When omitted, token type 0 is
                used for every position.

        Returns:
            Tensor of shape (batch, seqlen, embed_dim).
        """
        # Unpacking doubles as a check that input_ids is two-dimensional.
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = self._default_position_ids(input_ids)
            embeddings = embeddings + self.position_embeddings(position_ids)
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # A (seqlen,) vector of zeros; its embedding broadcasts
                # over the batch dimension in the addition below.
                token_type_ids = torch.zeros(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        return embeddings
|
|