| from typing import List, Union |
| import torch |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, BertTokenizer |
| from transformers.utils import is_remote_url, download_url |
| from pathlib import Path |
| from .configuration_vgcn import VGCNConfig |
| import pickle as pkl |
| import numpy as np |
| import scipy.sparse as sp |
|
|
|
|
|
|
|
|
def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, gcn_config: "VGCNConfig"):
    """Threshold, select and normalize vocabulary graphs as torch sparse tensors.

    Args:
        gcn_vocab_adj_tf: scipy sparse matrix (csr/csc/coo); entries that are not
            strictly greater than ``gcn_config.tf_threshold`` are dropped.
        gcn_vocab_adj: scipy sparse matrix (csr/csc/coo); entries that are not
            strictly greater than ``gcn_config.npmi_threshold`` are dropped.
        gcn_config: object exposing ``tf_threshold``, ``npmi_threshold`` and
            ``vocab_type`` (one of ``'pmi'``, ``'tf'``, ``'all'``).

    Returns:
        A list of float32 sparse COO tensors ``D^-1/2 A^T D^-1/2`` (D = row sums
        of A; equal to ``D^-1/2 A D^-1/2`` for symmetric A), one per selected
        graph: ``[npmi]`` for ``'pmi'``, ``[tf]`` for ``'tf'`` and
        ``[tf, npmi]`` for ``'all'``. The tensors do not require grad.

    Raises:
        ValueError: if ``gcn_config.vocab_type`` is not a supported value.

    The input matrices are copied before thresholding and are not modified.
    """

    def _sparse_scipy2torch(coo_sparse):
        # Cast values to float32 so the result can multiply default-dtype
        # (float32) dense weights; scipy data is often float64.
        indices = torch.from_numpy(
            np.vstack((coo_sparse.row, coo_sparse.col)).astype(np.int64))
        values = torch.from_numpy(coo_sparse.data.astype(np.float32))
        return torch.sparse_coo_tensor(
            indices, values, torch.Size(coo_sparse.shape), requires_grad=False)

    def _normalize_adj(adj):
        """Symmetrically normalize an adjacency matrix: D^-1/2 A^T D^-1/2."""
        degree = np.asarray(adj.sum(axis=1)).flatten()
        # Isolated nodes have degree 0 -> inf; they are zeroed just below.
        with np.errstate(divide='ignore'):
            d_inv_sqrt = np.power(degree, -0.5)
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
        return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)

    def _threshold(adj, threshold):
        # Work on a copy so the caller's matrix is left untouched.
        adj = adj.copy()
        adj.data *= (adj.data > threshold)
        adj.eliminate_zeros()
        return adj

    vocab_type = gcn_config.vocab_type
    if vocab_type not in ('pmi', 'tf', 'all'):
        raise ValueError(f"vocab_type must be 'pmi', 'tf' or 'all', got {gcn_config.vocab_type}")

    tf_adj = _threshold(gcn_vocab_adj_tf, gcn_config.tf_threshold)
    npmi_adj = _threshold(gcn_vocab_adj, gcn_config.npmi_threshold)

    selected = {
        'pmi': [npmi_adj],
        'tf': [tf_adj],
        'all': [tf_adj, npmi_adj],
    }[vocab_type]

    return [_sparse_scipy2torch(_normalize_adj(adj).tocoo()) for adj in selected]
|
|
|
|
|
|
class VCGNModelForTextClassification(PreTrainedModel):
    """Text classifier wrapping ``VGCN_Bert`` together with its BERT tokenizer.

    Owns a ``BertTokenizer`` and the vocabulary-graph adjacency matrices, turns
    raw text or token-id tensors into the inputs ``VGCN_Bert`` expects, and
    exposes ``forward`` / ``predict`` returning classification logits.
    """
    config_class = VGCNConfig

    def __init__(self, config, load_adjacency_matrix=True,):
        """Build the tokenizer and the underlying VGCN-BERT model.

        Args:
            config: a ``VGCNConfig``; reads ``bert_model``, ``gcn_adj_matrix``,
                ``vocab_type``, ``vocab_size`` and ``gcn_embedding_dim``.
            load_adjacency_matrix: if True, load the adjacency matrices from
                ``config.gcn_adj_matrix``. If False, create all-zero sparse
                placeholders of shape ``(vocab_size, vocab_size)`` (2 for
                ``vocab_type == 'all'``, else 1). Replace them later via
                ``set_adjacency_matrix`` or
                ``from_pretrained(..., reload_adjacency_matrix=True)``.
        """
        super().__init__(config)

        self.tokenizer = BertTokenizer.from_pretrained(config.bert_model)

        if load_adjacency_matrix:
            norm_gcn_vocab_adj_list = self.load_adj_matrix(config.gcn_adj_matrix)
        else:
            num_graphs = 2 if config.vocab_type == 'all' else 1
            norm_gcn_vocab_adj_list = [
                torch.sparse_coo_tensor(
                    torch.zeros(2, 1, dtype=torch.long),
                    torch.zeros(1),
                    (config.vocab_size, config.vocab_size),
                )
                for _ in range(num_graphs)
            ]

        self.model = VGCN_Bert(
            config,
            gcn_adj_matrix=norm_gcn_vocab_adj_list,
            gcn_adj_dim=config.vocab_size,
            gcn_adj_num=len(norm_gcn_vocab_adj_list),
            gcn_embedding_dim=config.gcn_embedding_dim,
        )

    @classmethod
    def from_pretrained(cls, *model_args, reload_adjacency_matrix=False, **kwargs):
        """Load a saved model, always constructing it with placeholder adjacency matrices.

        Args:
            reload_adjacency_matrix: if True, afterwards load the adjacency
                matrices from ``config.gcn_adj_matrix`` and install them.
            *model_args, **kwargs: forwarded to ``PreTrainedModel.from_pretrained``.
        """
        model = super().from_pretrained(*model_args, **kwargs, load_adjacency_matrix=False)

        if reload_adjacency_matrix:
            model.set_adjacency_matrix(model.load_adj_matrix(model.config.gcn_adj_matrix))

        return model

    def set_adjacency_matrix(self, adj_matrix: Union[List, np.ndarray, sp.csr_matrix, torch.Tensor]):
        """Replace the vocabulary-graph adjacency matrices used by the GCN.

        Args:
            adj_matrix: one matrix or a list/tuple of matrices. Each may be a
                ``torch.Tensor`` (dense or sparse, including ``nn.Parameter``),
                a ``np.ndarray`` or a ``scipy.sparse`` matrix. numpy/scipy inputs
                are converted to float32 tensors. All are moved to the model's
                device and frozen (``requires_grad=False``).

        Raises:
            ValueError: for an unsupported type, or when the number of matrices
                differs from the number of graphs the GCN was built with.
        """
        def to_tensor(m):
            if isinstance(m, torch.Tensor):
                return m.detach()
            if isinstance(m, np.ndarray):
                return torch.from_numpy(m).float()
            if sp.issparse(m):
                coo = m.tocoo()
                indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
                values = torch.from_numpy(coo.data.astype(np.float32))
                return torch.sparse_coo_tensor(indices, values, coo.shape)
            raise ValueError(
                "adjacency matrix must be a torch.Tensor, np.ndarray, scipy.sparse "
                f"matrix or a list of them, got {type(m)}")

        matrices = adj_matrix if isinstance(adj_matrix, (list, tuple)) else [adj_matrix]
        tensors = [to_tensor(m).to(self.device) for m in matrices]

        vocab_gcn = self.model.embeddings.vocab_gcn
        if len(tensors) != vocab_gcn.num_adj:
            raise ValueError(
                f"expected {vocab_gcn.num_adj} adjacency matrices, got {len(tensors)}")

        vocab_gcn.adj_matrix = torch.nn.ParameterList(
            [torch.nn.Parameter(x, requires_grad=False) for x in tensors])

    def load_adj_matrix(self, adj_matrix):
        """Load and normalize the adjacency matrices referenced by ``adj_matrix``.

        ``adj_matrix`` is tried, in order, as:
        1. a local path;
        2. a path relative to this module's directory;
        3. a remote URL, which is downloaded.

        The file is a pickled 3-tuple ``(tf_adj, npmi_adj, adj_config)``. As a
        side effect, ``self.tokenizer`` is reloaded from
        ``adj_config['bert_model']``.

        Returns:
            The list produced by ``get_torch_gcn`` for ``self.config``.

        Raises:
            FileNotFoundError: if ``adj_matrix`` cannot be resolved.
        """
        local_path = Path(adj_matrix)
        bundled_path = Path(__file__).parent / local_path
        if local_path.is_file():
            filename = local_path
        elif bundled_path.is_file():
            filename = bundled_path
        elif is_remote_url(adj_matrix):
            filename = download_url(adj_matrix)
        else:
            raise FileNotFoundError(f"adjacency matrix file not found: {adj_matrix}")

        # NOTE(security): unpickling can execute arbitrary code; only load
        # adjacency files from trusted sources.
        with open(filename, 'rb') as f:
            gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(f)

        self.tokenizer = BertTokenizer.from_pretrained(adj_config['bert_model'])
        return get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)

    def _prep_batch(self, batch: torch.Tensor):
        """Build ``VGCN_Bert`` inputs from a ``(batch, seq)`` tensor of token ids.

        Returns:
            ``(input_ids, gcn_swop_eye)``:
            - ``input_ids`` is ``batch`` right-padded with 0 by
              ``config.gcn_embedding_dim`` positions. Starting at each row's first
              pad position, ``gcn_embedding_dim`` slots are filled with the
              ``[SEP]`` id; they reserve room for the graph embeddings.
            - ``gcn_swop_eye`` is a float one-hot tensor of shape
              ``(batch, vocab, seq + gcn_embedding_dim)``. Its pad/CLS/SEP rows
              are zeroed.

        NOTE(review): this assumes ``tokenizer.pad_token_id == 0``, because the
        padding fills with 0 and ``forward`` builds the mask from ``input > 0``.
        """
        vocab_size = self.tokenizer.vocab_size

        # (batch, seq, vocab) -> (batch, vocab, seq)
        batch_gcn_swop_eye = F.one_hot(batch, vocab_size).float().to(self.device)
        batch_gcn_swop_eye = batch_gcn_swop_eye.transpose(1, 2)

        # Special tokens do not contribute to the vocabulary graph input.
        batch_gcn_swop_eye[:, self.tokenizer.pad_token_id, :] = 0
        batch_gcn_swop_eye[:, self.tokenizer.cls_token_id, :] = 0
        batch_gcn_swop_eye[:, self.tokenizer.sep_token_id, :] = 0

        batch_gcn_swop_eye = F.pad(batch_gcn_swop_eye, (0, self.config.gcn_embedding_dim, 0, 0, 0, 0), value=0)

        batch = F.pad(batch, (0, self.config.gcn_embedding_dim), 'constant', 0)

        # Build a mask covering [pos_start, pos_start + gcn_embedding_dim) per row
        # via the difference of two cumulative step functions.
        mask = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)
        mask2 = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)

        # First pad position per row (always exists after the padding above).
        pos_start = (batch == self.tokenizer.pad_token_id).int().argmax(1)

        mask[(torch.arange(batch.shape[0]), pos_start)] = 1
        mask2[(torch.arange(batch.shape[0]), pos_start + self.config.gcn_embedding_dim)] = 1

        mask = mask.cumsum(1)[:, :-1].bool()
        mask2 = mask2.cumsum(1)[:, :-1].bool()

        mask = mask & ~mask2

        batch.masked_fill_(mask, self.tokenizer.sep_token_id)

        return batch, batch_gcn_swop_eye

    def text_to_batch(self, text: Union[List[str], str]):
        """Tokenize text into a padded ``input_ids`` tensor on the model's device.

        ``max_length`` is reduced by ``gcn_embedding_dim`` so there is room for
        the graph-embedding slots that ``_prep_batch`` appends.
        """
        if isinstance(text, str):
            text = [text]
        encoded = self.tokenizer.batch_encode_plus(text, padding=True, truncation=True, return_tensors='pt', max_length=self.config.max_seq_len - self.config.gcn_embedding_dim)
        return encoded['input_ids'].to(self.device)

    def forward(self, input: Union[torch.Tensor, List[str], str], labels=None):
        """Compute classification logits (and the loss when ``labels`` is given).

        Args:
            input: a token-id tensor ``(batch, seq)`` or raw text (str or list of
                str), which is tokenized with ``text_to_batch``.
            labels: optional targets accepted by ``F.cross_entropy`` (tensor or
                sequence). They are moved to the logits' device.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is given.
        """
        if not isinstance(input, torch.Tensor):
            input = self.text_to_batch(input)

        input, batch_gcn_swop_eye = self._prep_batch(input)

        segment_ids = torch.zeros_like(input).int().to(self.device)
        input_mask = (input > 0).int().to(self.device)

        logits = self.model(batch_gcn_swop_eye, input, segment_ids, input_mask)
        if labels is not None:
            # torch.nn has no `cross_entropy`; the function lives in torch.nn.functional.
            labels = torch.as_tensor(labels, device=logits.device)
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    def predict(self, text: Union[List[str], str], as_dict=True):
        """Classify text without tracking gradients.

        Returns:
            When ``as_dict`` is True, a dict with ``logits``, ``label_id``
            (numpy array) and ``label`` (mapped through ``config.id2label``).
            Otherwise, the numpy array of arg-max label ids.
        """
        with torch.no_grad():
            logits = self.forward(text)['logits']
        if as_dict:
            label_id = torch.argmax(logits, dim=1).cpu().numpy()
            label = [self.config.id2label[l] for l in label_id]
            return {
                "logits": logits,
                "label_id": label_id,
                "label": label,
            }
        else:
            return torch.argmax(logits, dim=1).cpu().numpy()

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.init as init |
| import math |
|
|
| from transformers import BertModel |
| from transformers.models.bert.modeling_bert import BertEmbeddings, BertPooler,BertEncoder |
|
|
class VocabGraphConvolution(nn.Module):
    """Vocabulary GCN module.

    Computes ``fc_hc( sum_i X_dv @ dropout(A_i @ W_i) )`` over the ``num_adj``
    adjacency matrices ``A_i``. When ``add_linear_mapping_term`` is set, each
    term additionally includes ``dropout(X_dv @ W_i)``.

    Params:
        `adj_matrix`: the vocabulary adjacency matrices. Accepted forms:
            - an ``nn.ParameterList``, used as-is;
            - a single ``nn.Parameter``, wrapped into a one-element list;
            - a list of tensors, wrapped into a frozen ``nn.ParameterList``.
        `voc_dim`: The size of the vocabulary graph (rows of each ``W_i``).
        `num_adj`: The number of adjacency matrices of the vocabulary graph.
        `hid_dim`: The hidden dimension of ``A_i @ W_i``.
        `out_dim`: The output dimension of the final linear layer ``fc_hc``.
        `dropout_rate`: Dropout probability applied to ``A_i @ W_i`` and to
            the optional linear term.

    Inputs:
        `X_dv`: the features of a mini-batch of documents: either TF-IDF
            ``(batch, vocab)`` or word embeddings
            ``(batch, word_embedding_dim, vocab)``.
        `add_linear_mapping_term`: also add ``dropout(X_dv @ W_i)`` per graph.

    Outputs:
        The graph embedding representation, of shape ``(batch, out_dim)`` or
        ``(batch, word_embedding_dim, out_dim)``.
    """
    def __init__(self, adj_matrix, voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
        super(VocabGraphConvolution, self).__init__()
        if isinstance(adj_matrix, nn.ParameterList):
            self.adj_matrix = adj_matrix
        elif isinstance(adj_matrix, nn.Parameter):
            # Wrap it so that self.adj_matrix[i] selects a whole matrix, not a row.
            self.adj_matrix = nn.ParameterList([adj_matrix])
        elif isinstance(adj_matrix, list):
            self.adj_matrix = nn.ParameterList(
                [nn.Parameter(x, requires_grad=False) for x in adj_matrix])
        else:
            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")

        self.voc_dim = voc_dim
        self.num_adj = num_adj
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        # One (voc_dim, hid_dim) weight per adjacency matrix: W0_vh, W1_vh, ...
        for i in range(self.num_adj):
            setattr(self, 'W%d_vh' % i, nn.Parameter(torch.randn(voc_dim, hid_dim)))

        self.fc_hc = nn.Linear(hid_dim, out_dim)
        # Kept as an attribute; forward() does not currently apply it.
        self.act_func = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform initialize the per-graph ``W*_vh`` weights."""
        for n, p in self.named_parameters():
            if n.startswith('W'):
                init.kaiming_uniform_(p, a=math.sqrt(5))

    def forward(self, X_dv, add_linear_mapping_term=False):
        if self.num_adj < 1:
            raise ValueError("VocabGraphConvolution needs at least one adjacency matrix (num_adj >= 1)")

        fused_H = None
        for i in range(self.num_adj):
            W_vh = getattr(self, 'W%d_vh' % i)
            H_vh = self.dropout(self.adj_matrix[i].mm(W_vh))
            H_dh = X_dv.matmul(H_vh)

            if add_linear_mapping_term:
                H_dh = H_dh + self.dropout(X_dv.matmul(W_vh))

            # Out-of-place accumulation avoids mutating tensors saved for autograd.
            fused_H = H_dh if fused_H is None else fused_H + H_dh

        return self.fc_hc(fused_H)
|
|
|
|
class VGCNBertEmbeddings(BertEmbeddings):
    """Construct the embeddings from word, VGCN graph, position and token_type embeddings.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model
        `gcn_adj_matrix`: the vocabulary adjacency matrices passed to ``VocabGraphConvolution``
        `gcn_adj_dim`: The size of vocabulary graph
        `gcn_adj_num`: The number of the adjacency matrix of Vocabulary graph
        `gcn_embedding_dim`: The number of graph-embedding slots (output dimension of the VGCN); must be >= 0

    Inputs:
        `gcn_swop_eye`: the matrix that maps the token sequence to vocabulary
            (bag-of-words) order; presumably ``(batch, vocab, seq)``. TODO confirm.
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional tensor of shape [batch_size, sequence_length]
            with token type indices; defaults to zeros.
        `attention_mask`: an optional tensor of shape [batch_size, sequence_length]
            with 0/1 entries; defaults to all ones. Its row sums locate the
            positions that receive the graph embeddings.

    Outputs:
        The word embeddings fused with the VGCN embeddings, position embeddings
        and token_type embeddings, after LayerNorm and dropout.
    """
    def __init__(self, config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim):
        super(VGCNBertEmbeddings, self).__init__(config)
        if gcn_embedding_dim < 0:
            raise ValueError(f"gcn_embedding_dim must be >= 0, got {gcn_embedding_dim}")
        self.gcn_adj_matrix = gcn_adj_matrix
        self.gcn_embedding_dim = gcn_embedding_dim
        self.vocab_gcn = VocabGraphConvolution(gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, 128, gcn_embedding_dim)

    def forward(self, gcn_swop_eye, input_ids, token_type_ids=None, attention_mask=None):
        batch_size, seq_length = input_ids.shape
        device = input_ids.device
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)

        if self.gcn_embedding_dim > 0:
            # Project token embeddings into vocabulary order, then run the VGCN;
            # gcn_vocab_out[:, :, i] is the i-th graph embedding per example.
            vocab_input = gcn_swop_eye.matmul(words_embeddings).transpose(1, 2)
            gcn_vocab_out = self.vocab_gcn(vocab_input)

            # The graph embeddings overwrite gcn_embedding_dim consecutive
            # positions per row, starting at (mask_sum - 1 - gcn_embedding_dim).
            first_pos = attention_mask.sum(-1).long() - 1 - self.gcn_embedding_dim
            cols = first_pos.unsqueeze(1) + torch.arange(self.gcn_embedding_dim, device=device)
            rows = torch.arange(batch_size, device=device).unsqueeze(1)

            fused_words_embeddings = words_embeddings.clone()
            fused_words_embeddings[rows, cols] = gcn_vocab_out.transpose(1, 2)
        else:
            fused_words_embeddings = words_embeddings

        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = fused_words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
class VGCN_Bert(BertModel):
    """VGCN-BERT model for text classification. It inherits from Huggingface's BertModel.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new
            model. ``config.num_labels`` sets the classifier size and
            ``config.output_attentions`` whether attentions are returned.
        `gcn_adj_matrix`: the vocabulary adjacency matrices
        `gcn_adj_dim`: The size of vocabulary graph
        `gcn_adj_num`: The number of the adjacency matrix of Vocabulary graph
        `gcn_embedding_dim`: The output dimension after VGCN

    Inputs:
        `gcn_swop_eye`: The matrix that transforms the token sequence (sentence) into vocabulary (BoW) order
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional tensor of shape [batch_size, sequence_length]
            with token type indices; defaults to zeros.
        `attention_mask`: an optional tensor of shape [batch_size, sequence_length]
            with 0/1 entries; defaults to all ones.
        `output_hidden_states`: forwarded to the encoder. Hidden states are
            computed but not returned.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or
            [num_layers, num_heads] with values between 0 and 1, multiplied
            into the attention probabilities of each head.

    Outputs:
        The classification logits of shape [batch_size, num_labels]. When
        ``config.output_attentions`` is set, returns ``(attentions, logits)``.
    """
    def __init__(self, config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim):
        super(VGCN_Bert, self).__init__(config)
        self.embeddings = VGCNBertEmbeddings(config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.gcn_adj_matrix = gcn_adj_matrix
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.will_collect_cls_states = False
        self.all_cls_states = []
        self.output_attentions = config.output_attentions

    def forward(self, gcn_swop_eye, input_ids, token_type_ids=None, attention_mask=None, output_hidden_states=False, head_mask=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        embedding_output = self.embeddings(gcn_swop_eye, input_ids, token_type_ids, attention_mask)

        param_dtype = next(self.parameters()).dtype

        # (batch, seq) -> (batch, 1, 1, seq) so it broadcasts over heads and
        # query positions. Kept positions become 0; masked ones get a large
        # negative additive bias.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Reshape head_mask to (num_layers, batch, num_heads, seq, seq)-broadcastable form.
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                # Tensor.expand takes sizes; expand_as would require a tensor.
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=param_dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=self.output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooled_output = self.pooler(encoder_outputs.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if self.output_attentions:
            return encoder_outputs.attentions, logits

        return logits
|
|