import os
from typing import List

import torch
import torch.nn as nn
from transformers import AutoModel


def mask_pooling(model_output, attention_mask):
    """Mean-pool token embeddings over positions selected by ``attention_mask``.

    Args:
        model_output: HuggingFace model output; element 0 holds the token
            embeddings — assumed shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) tensor; nonzero entries mark the
            tokens that participate in the average.

    Returns:
        (batch, hidden) tensor of masked mean embeddings.
    """
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Clamp the denominator so a fully-masked row cannot divide by zero.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


class LanguageModel(nn.Module):
    """Thin wrapper around a HuggingFace encoder with a configurable readout."""

    def __init__(self, modelname: str, device: str, readout: str):
        """
        Args:
            modelname: HuggingFace model identifier passed to ``AutoModel``.
            device: torch device string that inputs (and loaded checkpoints)
                are mapped to.
            readout: one of ``'cls'``, ``'mean'``, or ``'ch'``.
        """
        super().__init__()
        self.device = device
        self.modelname = modelname
        self.readout_fn = readout
        self.model = AutoModel.from_pretrained(modelname)
        self.hidden_size = self.model.config.hidden_size

    def readout(self, model_inputs, model_outputs, readout_masks=None):
        """Reduce per-token encoder outputs to one vector per sequence.

        Args:
            model_inputs: the tokenized batch fed to the encoder; only
                ``'attention_mask'`` is read (for the ``'mean'`` readout).
            model_outputs: the encoder output object.
            readout_masks: optional per-token mask, required by ``'ch'``.

        Returns:
            (batch, hidden) tensor of sequence representations.

        Raises:
            ValueError: on an unknown readout, an unsupported model for
                ``'cls'``, or ``'ch'`` without ``readout_masks``.
        """
        if self.readout_fn == 'cls':
            # BERT-style models place the summary token first; XLNet last.
            if 'bert' in self.modelname or 'deberta' in self.modelname:
                text_representations = model_outputs.last_hidden_state[:, 0]
            elif 'xlnet' in self.modelname:
                text_representations = model_outputs.last_hidden_state[:, -1]
            else:
                raise ValueError('Invalid model name {} for the cls readout.'.format(self.modelname))
        elif self.readout_fn == 'mean':
            text_representations = mask_pooling(model_outputs, model_inputs['attention_mask'])
        elif self.readout_fn == 'ch':
            # Previously this fell through to the generic error below,
            # hiding the actual cause (a missing mask).
            if readout_masks is None:
                raise ValueError("The 'ch' readout requires readout_masks.")
            text_representations = mask_pooling(model_outputs, readout_masks)
        else:
            raise ValueError('Invalid readout function.')
        return text_representations

    def _lm_forward(self, tokens):
        """Run the encoder on ``tokens`` and apply the configured readout."""
        tokens = tokens.to(self.device)
        # 'readout_mask' is consumed here; the underlying HF model would
        # reject it as an unexpected keyword argument.
        if 'readout_mask' in tokens:
            readout_mask = tokens.pop('readout_mask')
        else:
            readout_mask = None
        outputs = self.model(**tokens)
        return self.readout(tokens, outputs, readout_mask)

    def forward(self):
        # Subclasses define the actual forward pass.
        raise NotImplementedError

    def save_pretrained(self, modeldir):
        """Save the full state dict to ``<modeldir>/checkpoint.pt``."""
        model_filename = os.path.join(modeldir, 'checkpoint.pt')
        torch.save(self.state_dict(), model_filename)

    def load_pretrained(self, modeldir):
        """Load weights previously written by :meth:`save_pretrained`."""
        model_filename = os.path.join(modeldir, 'checkpoint.pt')
        # map_location fixes loading GPU-saved checkpoints on a CPU-only
        # host and honors self.device instead of the checkpoint's device.
        self.load_state_dict(torch.load(model_filename, map_location=self.device))


class MultiHeadLanguageModel(LanguageModel):
    """LanguageModel with one linear classification head per task."""

    def __init__(self, modelname: str, device: str, readout: str, num_classes: List[int]):
        """
        Args:
            num_classes: output dimension of each head; head ``i`` is a
                ``Linear(hidden_size, num_classes[i])``.
        """
        super().__init__(modelname, device, readout)
        self.num_classes = num_classes
        self.lns = nn.ModuleList(
            [nn.Linear(self.hidden_size, num_class) for num_class in num_classes]
        )

    def forward(self, input_tokens, input_head_indices, class_tokens, class_head_indices):
        """Route each input's representation through its task head.

        Args:
            input_tokens: tokenized batch for the encoder.
            input_head_indices: per-example head index tensor.
            class_tokens, class_head_indices: accepted for interface
                compatibility; unused here.

        Returns:
            dict mapping head index (int) -> predictions for the examples
            assigned to that head.
        """
        head_indices = torch.unique(input_head_indices)
        text_representations = self._lm_forward(input_tokens)
        final_preds = {}
        # Every index returned by torch.unique occurs in input_head_indices,
        # so each selection below is non-empty (the former empty-tensor
        # fallback branch was unreachable and has been removed).
        for i in head_indices:
            selected = input_head_indices == i
            final_preds[i.item()] = self.lns[i.item()](text_representations[selected])
        return final_preds