File size: 3,471 Bytes
0a55f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import torch
import torch.nn as nn

from typing import List
from transformers import AutoModel

def mask_pooling(model_output, attention_mask):
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Args:
        model_output: Hugging Face model output; element 0 holds the token
            embeddings of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) tensor, 1 for real tokens, 0 for padding.

    Returns:
        (batch, hidden) tensor: the mean of the unmasked token embeddings.
    """
    embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so it can weight each feature.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp protects against division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

class LanguageModel(nn.Module):
    """Wraps a Hugging Face encoder with a configurable sentence-level readout.

    Args:
        modelname: Hugging Face model identifier passed to AutoModel.from_pretrained.
        device: Device string (e.g. 'cpu', 'cuda') tokens are moved to in _lm_forward.
        readout: Readout strategy: 'cls', 'mean', or 'ch' (custom-mask pooling).
    """

    def __init__(self, 
                 modelname: str, 
                 device: str, 
                 readout: str
        ):
        super().__init__()
        self.device = device
        self.modelname = modelname
        self.readout_fn = readout

        self.model = AutoModel.from_pretrained(modelname)
        self.hidden_size = self.model.config.hidden_size

    def readout(self, model_inputs, model_outputs, readout_masks=None):
        """Reduce per-token encoder outputs to one vector per sequence.

        Args:
            model_inputs: Tokenizer output dict; 'attention_mask' is read for 'mean'.
            model_outputs: Encoder output with a last_hidden_state attribute.
            readout_masks: Optional (batch, seq_len) mask, required for 'ch'.

        Returns:
            (batch, hidden_size) tensor of text representations.

        Raises:
            ValueError: On an unknown readout, an unsupported model for 'cls',
                or a missing readout_masks for 'ch'.
        """
        if self.readout_fn == 'cls':
            if 'bert' in self.modelname or 'deberta' in self.modelname:
                # BERT-style models put the [CLS] summary token first.
                text_representations = model_outputs.last_hidden_state[:, 0]
            elif 'xlnet' in self.modelname:
                # XLNet places its summary token at the end of the sequence.
                text_representations = model_outputs.last_hidden_state[:, -1]
            else:
                raise ValueError('Invalid model name {} for the cls readout.'.format(self.modelname))
        elif self.readout_fn == 'mean':
            text_representations = mask_pooling(model_outputs, model_inputs['attention_mask'])
        elif self.readout_fn == 'ch':
            # Previously a missing mask fell through to the generic "invalid
            # readout" error, which was misleading; raise a precise message.
            if readout_masks is None:
                raise ValueError("The 'ch' readout requires readout_masks.")
            text_representations = mask_pooling(model_outputs, readout_masks)
        else:
            raise ValueError('Invalid readout function.')
        return text_representations

    def _lm_forward(self, tokens):
        """Run the encoder on tokens and apply the configured readout."""
        tokens = tokens.to(self.device)
        # readout_mask is consumed here; the encoder must not receive it.
        if 'readout_mask' in tokens:
            readout_mask = tokens.pop('readout_mask')
        else:
            readout_mask = None
        outputs = self.model(**tokens)
        return self.readout(tokens, outputs, readout_mask)

    def forward(self):
        # Subclasses define the task-specific forward pass.
        raise NotImplementedError

    def save_pretrained(self, modeldir):
        """Save the full state dict to modeldir/checkpoint.pt."""
        model_filename = os.path.join(modeldir, 'checkpoint.pt')
        torch.save(self.state_dict(), model_filename)

    def load_pretrained(self, modeldir):
        """Load the state dict from modeldir/checkpoint.pt.

        map_location remaps storages saved on another device (e.g. a GPU
        checkpoint loaded on a CPU-only host), which would otherwise fail.
        """
        model_filename = os.path.join(modeldir, 'checkpoint.pt')
        self.load_state_dict(torch.load(model_filename, map_location=self.device))

class MultiHeadLanguageModel(LanguageModel):
    """Language model with one linear classification head per task.

    Args:
        modelname: Hugging Face model identifier (see LanguageModel).
        device: Device string tokens are moved to.
        readout: Readout strategy name (see LanguageModel.readout).
        num_classes: List of class counts; head i maps hidden_size -> num_classes[i].
    """

    def __init__(self, 
                 modelname: str, 
                 device: str, 
                 readout: str, 
                 num_classes: List
        ):
        super().__init__(
            modelname,
            device, 
            readout
        )

        self.num_classes = num_classes
        # One linear head per task, indexed by head id.
        self.lns = nn.ModuleList([nn.Linear(self.hidden_size, num_class) for num_class in num_classes])

    def forward(self, input_tokens, input_head_indices, class_tokens, class_head_indices):
        """Encode inputs and route each example to its task head.

        Args:
            input_tokens: Tokenizer output for the input batch.
            input_head_indices: (batch,) tensor assigning each example to a head.
            class_tokens, class_head_indices: Accepted for interface
                compatibility; unused in this forward pass.

        Returns:
            Dict mapping head index (int) to that head's logits for the
            examples assigned to it.
        """
        text_representations = self._lm_forward(input_tokens)

        final_preds = {}
        # torch.unique only yields values that occur in input_head_indices, so
        # every selection below is non-empty; the original code guarded an
        # unreachable empty-selection case, removed here as dead code.
        for head_idx in torch.unique(input_head_indices).tolist():
            selected = text_representations[input_head_indices == head_idx]
            final_preds[head_idx] = self.lns[head_idx](selected)
        return final_preds