import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class LexicalConfig(PretrainedConfig):
    """Configuration for a lexical (bag-of-token-embeddings) encoder."""

    model_type = "lexical_embedding"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=2048,
        padding_idx=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx


class LexicalHFModel(PreTrainedModel):
    """Embeds tokens, mean-pools them, and L2-normalizes the result."""

    config_class = LexicalConfig

    def __init__(self, config):
        # PreTrainedModel.__init__ already stores `config` as self.config,
        # so no separate assignment is needed here.
        super().__init__(config)
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embed_dim,
            padding_idx=config.padding_idx,
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        embeds = self.embedding(input_ids)

        # Treat every position as valid when no mask is supplied.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Mean-pool over the sequence dimension, counting only unmasked
        # positions; the clamp guards against division by zero.
        mask_expanded = attention_mask.unsqueeze(-1).expand(embeds.size()).float()
        sum_embeddings = torch.sum(embeds * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        mean_pooled = sum_embeddings / sum_mask

        # Unit-normalize so dot products behave as cosine similarities.
        return F.normalize(mean_pooled, p=2, dim=1)
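

# --- Usage sketch (illustrative addition, not part of the model code) ---
# A minimal example of running the model end to end, assuming a BERT-style
# tokenizer: "bert-base-uncased" is an assumption here, chosen because it
# matches the default vocab_size (30522) and padding_idx (0) above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = LexicalHFModel(LexicalConfig())

    batch = tokenizer(
        ["a lexical embedding model", "hello world"],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(**batch)
    print(embeddings.shape)  # torch.Size([2, 2048]); each row has unit norm

    # Subclassing PreTrainedModel/PretrainedConfig gives the standard
    # save/load round trip for free ("lexical-demo" is a local demo path):
    model.save_pretrained("lexical-demo")
    reloaded = LexicalHFModel.from_pretrained("lexical-demo")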