File size: 5,638 Bytes
a7754d0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput
def get_range_vector(size: int, device: Union[int, str, torch.device]) -> torch.Tensor:
    """
    Return the 1-D ``torch.long`` tensor ``[0, 1, ..., size - 1]``.

    The tensor is allocated directly on ``device`` (callers pass e.g.
    ``sequence_output.device``), which avoids building it on the CPU and then
    copying it to the GPU.

    Args:
        size: Number of elements; ``0`` yields an empty tensor.
        device: Target device, in any form ``torch.arange`` accepts
            (a ``torch.device``, a device string, or a CUDA ordinal).

    Returns:
        A ``torch.long`` tensor of shape ``(size,)`` on ``device``.
    """
    return torch.arange(0, size, dtype=torch.long, device=device)
@dataclass
class Seq2LabelsOutput(ModelOutput):
    """
    Output container returned by ``Seq2LabelsModel.forward`` when ``return_dict`` is true.

    ``ModelOutput`` subclasses must be dataclasses: ``ModelOutput`` relies on the
    generated ``__init__``/``__post_init__`` to register only the non-``None``
    fields as keys, and recent Transformers releases raise ``TypeError`` for
    non-dataclass subclasses.
    """

    # Sum of the label and detection cross-entropy losses; None when no targets were given.
    loss: Optional[torch.FloatTensor] = None
    # Per-token logits from the label (classifier) head.
    logits: Optional[torch.FloatTensor] = None
    # Per-token logits from the detection head.
    detect_logits: Optional[torch.FloatTensor] = None
    # Encoder hidden states, when requested via output_hidden_states.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Encoder attention maps, when requested via output_attentions.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # One value per batch element.
    max_error_probability: Optional[torch.FloatTensor] = None
class Seq2LabelsModel(BertPreTrainedModel):
    """
    Transformer encoder with two token-level heads for sequence tagging.

    * ``classifier`` maps each token representation (after dropout) to
      ``config.vocab_size`` label logits.
    * ``detector`` maps each token representation (no dropout) to
      ``config.num_detect_classes`` detection logits.

    When both targets are supplied, the loss is the sum of the two
    cross-entropy losses. Label smoothing (``config.label_smoothing``) is applied
    to the label head only.
    """

    # Unexpected checkpoint keys matching "pooler" are ignored when loading weights.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            # Prefer the current config when it already carries the encoder
            # architecture fields, instead of calling AutoConfig.from_pretrained()
            # (which may need network access).
            if getattr(config, "hidden_size", None):
                from copy import deepcopy

                bert_config = deepcopy(config)
                # In this config ``vocab_size`` is the number of labels, not the
                # encoder tokenizer's vocabulary, so the encoder copy gets its own
                # size (``bert_vocab_size`` if present, else 38168).
                bert_config.vocab_size = getattr(config, "bert_vocab_size", 38168)
            else:
                bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Encoders without ``embeddings.word_embeddings`` expose
                # ``word_embedding``; reserve a few extra rows for them.
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            # Grow the embedding matrix by one row for the extra special token.
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = (
            config.predictor_dropout if config.predictor_dropout is not None else 0.0
        )
        self.dropout = nn.Dropout(predictor_dropout)
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        Encode the input and compute per-token label and detection logits.

        input_offsets (`torch.LongTensor`, *optional*):
            Per-example positions into the encoder output. When given, only those
            token vectors are kept before the heads are applied. Presumably of
            shape `(batch_size, n)` so that it broadcasts against the
            `(batch_size, 1)` row index built here; confirm against the caller.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for the label head. Indices should be in
            `[0, ..., config.vocab_size - 1]` (the classifier's output width).
        d_tags (`torch.LongTensor`, *optional*):
            Targets for the detection head. Indices should be in
            `[0, ..., config.num_detect_classes - 1]`.

        A loss is computed only when both `labels` and `d_tags` are given.
        With `return_dict=False`, the result is the tuple
        `(loss?, logits, detect_logits, *outputs[2:])`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        if input_offsets is not None:
            # Row index of shape (batch_size, 1); advanced indexing pairs it with
            # input_offsets, so output[i, j] = sequence_output[i, input_offsets[i, j]].
            range_vector = get_range_vector(
                input_offsets.size(0), device=sequence_output.device
            ).unsqueeze(1)
            sequence_output = sequence_output[range_vector, input_offsets]

        # Dropout feeds the label head only; the detection head sees the raw output.
        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            # Flatten using each head's actual output width. The classifier is
            # sized by config.vocab_size, which need not equal config.num_labels,
            # so reshaping by num_labels could mis-align classes and rows.
            loss_labels = loss_labels_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            loss_d = loss_d_fct(logits_d.view(-1, logits_d.size(-1)), d_tags.view(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Placeholder: a constant 1.0 per batch element. It is not derived
            # from the detection logits.
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )
|