| import torch |
| import torch.nn as nn |
| from sentence_transformers import models |
|
|
class CustTrans(models.Transformer):
    """Sentence-transformers ``Transformer`` module with per-task token scaling.

    Each known task type owns one fixed vector in ``TaskEmbedding``. In
    :meth:`forward` the hidden states produced by ``auto_model`` are
    multiplied element-wise by the vector of the active task. The default
    vectors built by :meth:`_rebuild_taskembedding` are rows of ``1 - I``,
    so task ``i`` zeroes hidden dimension ``i`` and keeps all the others.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying transformer and the default task embeddings.

        All arguments are forwarded unchanged to ``models.Transformer``.
        """
        super().__init__(*args, **kwargs)
        # Task applied by forward() when no explicit task_type is passed.
        self.curr_task_type = None
        self._rebuild_taskembedding(['sts', 'quora'])

    def forward(self, inputs, task_type=None):
        """Encode ``inputs`` and apply the task-specific projection.

        Args:
            inputs: dict of keyword arguments for ``auto_model`` (presumably
                tokenizer output such as ``input_ids`` / ``attention_mask`` --
                confirm against callers). It is updated in place.
            task_type: task whose embedding vector is applied. ``None`` falls
                back to ``self.curr_task_type``.

        Returns:
            The same ``inputs`` dict with ``'token_embeddings'`` set to the
            projected hidden states. If the task type is not one of
            ``self.task_types`` (including when it is still ``None``), the
            raw hidden states are stored instead.
        """
        enc = self.auto_model(**inputs).last_hidden_state

        if task_type is None:
            task_type = self.curr_task_type

        if task_type in self.task_types:
            # The lookup index must live on the embedding's own device.
            idx = torch.tensor(
                self.task_types.index(task_type),
                device=self.TaskEmbedding.weight.device,
            )
            hyp = self.TaskEmbedding(idx)
            inputs['token_embeddings'] = self._project(enc, hyp)
        else:
            # Unknown or unset task: pass the hidden states through unchanged.
            inputs['token_embeddings'] = enc

        return inputs

    def _set_curr_task_type(self, task_type):
        """Set the task used by :meth:`forward` when none is given explicitly."""
        self.curr_task_type = task_type

    def _set_taskembedding_grad(self, value):
        """Enable (``True``) or freeze (``False``) training of the task vectors."""
        self.TaskEmbedding.weight.requires_grad = value

    def _set_transformer_grad(self, value):
        """Enable (``True``) or freeze (``False``) every ``auto_model`` parameter."""
        for param in self.auto_model.parameters():
            param.requires_grad = value

    def _rebuild_taskembedding(self, task_types):
        """Recreate ``TaskEmbedding`` with one fixed vector per task type.

        Row ``i`` is ``1 - e_i``, i.e. all ones except a zero at position
        ``i`` (for ``i`` below the hidden size). Projecting with task ``i``
        therefore zeroes hidden dimension ``i``. The vector width equals the
        underlying model's hidden size.

        The embedding is created frozen (the ``from_pretrained`` default)
        until :meth:`_set_taskembedding_grad` is called. It is placed on the
        same device as ``auto_model``, so a rebuild after ``.to(device)``
        keeps :meth:`forward` working.

        Args:
            task_types: sequence of task identifiers. Its order defines each
                task's row index.
        """
        # Store a copy so later mutation of the caller's list cannot
        # desynchronise task names from embedding rows.
        self.task_types = list(task_types)
        dim = self.get_word_embedding_dimension()
        self.task_emb = 1 - torch.eye(len(self.task_types), dim)
        # from_pretrained is a classmethod: call it on the class rather than
        # on a throwaway, randomly initialised nn.Embedding instance.
        embedding = nn.Embedding.from_pretrained(self.task_emb)
        first_param = next(self.auto_model.parameters(), None)
        if first_param is not None:
            embedding = embedding.to(first_param.device)
        self.TaskEmbedding = embedding

    def _project(self, v, normal_hyper):
        """Scale ``v`` element-wise by ``normal_hyper``, broadcasting over leading dims.

        With the default ``1 - e_i`` vectors this equals the orthogonal
        projection of ``v`` onto the hyperplane whose normal is ``e_i``. For
        arbitrary (e.g. trained) vectors it is a plain element-wise product.
        """
        return v * normal_hyper
|
|