from copy import deepcopy
import math

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from configuration_tsp import TSPConfig


class TSPPreTrainedModel(PreTrainedModel):
    config_class = TSPConfig
    base_model_prefix = "backbone"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class TSPModelForPreTraining(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = TSPModel(config)
        if config.use_electra:
            # ELECTRA-style setup: a smaller generator backbone for MLM and the
            # main backbone as the replaced-token-detection (RTD) discriminator,
            # with the two backbones sharing embeddings.
            mlm_config = deepcopy(config)
            mlm_config.hidden_size //= config.electra_generator_size_divisor
            mlm_config.intermediate_size //= config.electra_generator_size_divisor
            mlm_config.num_attention_heads //= config.electra_generator_size_divisor
            self.mlm_backbone = TSPModel(mlm_config)
            self.mlm_head = MaskedLMHead(
                mlm_config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
            self.rtd_backbone = self.backbone
            self.rtd_backbone.embeddings = self.mlm_backbone.embeddings
            self.rtd_head = ReplacedTokenDiscriminationHead(config)
        else:
            self.mlm_backbone = self.backbone
            self.mlm_head = MaskedLMHead(
                config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
        self.tsp_head = TextStructurePredictionHead(config)
        self.apply(self._init_weights)

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Refer to the implementation of the text structure prediction task for how to use the model."
        )


class MaskedLMHead(nn.Module):
    def __init__(self, config, word_embeddings=None):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.predictor = nn.Linear(config.embedding_size, config.vocab_size)
        if word_embeddings is not None:
            # Tie the output projection to the input word embeddings.
            self.predictor.weight = word_embeddings.weight

    def forward(
        self,
        x,  # (B, L, hidden_size)
        is_selected=None,  # (B, L) boolean mask of positions to predict
    ):
        if is_selected is not None:
            # Only compute logits for the selected (masked) positions.
            x = x[is_selected]
        x = self.linear(x)
        x = F.gelu(x)
        x = self.norm(x)
        return self.predictor(x)


class ReplacedTokenDiscriminationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.predictor = nn.Linear(config.hidden_size, 1)

    def forward(self, x):
        x = self.linear(x)
        x = F.gelu(x)
        x = self.predictor(x)
        return x.squeeze(-1)


class TextStructurePredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
        self.norm = nn.LayerNorm(config.hidden_size * 2)
        self.linear2 = nn.Linear(config.hidden_size * 2, 6)

    def forward(
        self, x,
    ):
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.norm(x)
        return self.linear2(x)


class TSPModelForTokenClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = TokenClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(hidden_states)


class TokenClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = self.classifier(x)
        return x


class TSPModelForSequenceClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SequenceClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(hidden_states)


class SequenceClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self, x,
    ):
        # Classify from the hidden state of the first token.
        x = x[:, 0, :]
        x = self.dropout(x)
        return self.classifier(x)


class TSPModelForQuestionAnswering(TSPPreTrainedModel):
    def __init__(self, config, beam_size, predict_answerability):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SquadHead(
            config, beam_size=beam_size, predict_answerability=predict_answerability
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        answer_start_position=None,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(
            hidden_states,
            token_type_ids=token_type_ids,
            answer_start_position=answer_start_position,
        )


class SquadHead(nn.Module):
    def __init__(
        self, config, beam_size, predict_answerability,
    ):
        super().__init__()
        self.beam_size = beam_size
        self.predict_answerability = predict_answerability

        # Predicts the answer start position from each token's hidden state.
        self.start_predictor = nn.Linear(config.hidden_size, 1)

        # Predicts the end position from each token's hidden state concatenated
        # with the hidden state of a candidate start position.
        self.end_predictor = nn.Sequential(
            nn.Linear(config.hidden_size * 2, 512), nn.GELU(), nn.Linear(512, 1),
        )

        # Optionally predicts whether the question is answerable at all.
        if predict_answerability:
            self.answerability_predictor = nn.Sequential(
                nn.Linear(config.hidden_size * 2, 512), nn.GELU(), nn.Linear(512, 1),
            )
        else:
            self.answerability_predictor = None

    def forward(
        self,
        hidden_states,
        token_type_ids,
        answer_start_position=None,
    ):
        # Mask of positions eligible to contain the answer: tokens of the second
        # segment (token_type_id == 1) except its final [SEP], plus the first token.
        answer_mask = token_type_ids
        last_sep = answer_mask.cumsum(dim=1) == answer_mask.sum(
            dim=1, keepdim=True
        )
        answer_mask = answer_mask * ~last_sep
        answer_mask[:, 0] = 1
        answer_mask = answer_mask.bool()

        # Start-position logits and the hidden states of candidate start positions.
        start_logits, start_top_hidden_states = self._calculate_start(
            hidden_states, answer_mask, answer_start_position
        )

        # End-position logits conditioned on the candidate start positions.
        end_logits = self._calculate_end_logits(
            hidden_states, start_top_hidden_states, answer_mask,
        )

        # Optional answerability logits.
        answerability_logits = None
        if self.answerability_predictor is not None:
            answerability_logits = self._calculate_answerability_logits(
                hidden_states, start_logits
            )

        return start_logits, end_logits, answerability_logits

    def _calculate_start(self, hidden_states, answer_mask, start_positions):
        start_logits = self.start_predictor(hidden_states).squeeze(-1)
        start_logits = start_logits.masked_fill(~answer_mask, -float("inf"))
        start_top_indices, start_top_hidden_states = None, None
        if self.training:
            # Condition the end prediction on the gold start positions.
            start_top_indices = start_positions
        else:
            # Condition the end prediction on the top-k predicted start positions.
            k = self.beam_size
            _, start_top_indices = start_logits.topk(k=k, dim=-1)
        start_top_hidden_states = torch.stack(
            [
                hiddens.index_select(dim=0, index=index)
                for hiddens, index in zip(hidden_states, start_top_indices)
            ]
        )
        return start_logits, start_top_hidden_states

    def _calculate_end_logits(
        self, hidden_states, start_top_hidden_states, answer_mask
    ):
        B, L, D = hidden_states.shape
        start_tophiddens = start_top_hidden_states.view(B, -1, 1, D).expand(
            -1, -1, L, -1
        )
        end_hidden_states = torch.cat(
            [
                start_tophiddens,
                hidden_states.view(B, 1, L, D).expand_as(start_tophiddens),
            ],
            dim=-1,
        )
        end_logits = self.end_predictor(end_hidden_states).squeeze(-1)
        end_logits = end_logits.masked_fill(
            ~answer_mask.view(B, 1, L), -float("inf")
        )
        end_logits = end_logits.squeeze(1)

        return end_logits

    def _calculate_answerability_logits(self, hidden_states, start_logits):
        answerability_hidden_states = hidden_states[:, 0, :]
        start_probs = start_logits.softmax(dim=-1).unsqueeze(-1)
        start_features = (start_probs * hidden_states).sum(dim=1)
        answerability_hidden_states = torch.cat(
            [answerability_hidden_states, start_features], dim=-1
        )
        answerability_logits = self.answerability_predictor(
            answerability_hidden_states
        )
        return answerability_logits.squeeze(-1)


class TSPModel(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )
        self.layers = nn.ModuleList(
            EncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        x = self.embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids
        )
        if hasattr(self, "embeddings_project"):
            x = self.embeddings_project(x)

        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask=attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )

        for layer in self.layers:
            x = layer(x, attention_mask=extended_attention_mask)

        return x


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        if config.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_sequence_length, config.embedding_size
            )
        self.token_type_embeddings = nn.Embedding(2, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        input_ids,
        token_type_ids,
    ):
        B, L = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        embeddings += self.token_type_embeddings(token_type_ids)
        if hasattr(self, "position_embeddings"):
            embeddings += self.position_embeddings.weight[None, :L, :]
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn_block = BlockWrapper(config, MultiHeadSelfAttention)
        self.transition_block = BlockWrapper(config, FeedForwardNetwork)

    def forward(
        self,
        x,
        attention_mask,
    ):
        x = self.self_attn_block(x, attention_mask=attention_mask)
        x = self.transition_block(x)
        return x


class BlockWrapper(nn.Module):
    def __init__(self, config, sublayer_cls):
        super().__init__()
        self.sublayer = sublayer_cls(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x, **kwargs):
        # Sublayer with dropout, residual connection, and post-LayerNorm.
        original_x = x
        x = self.sublayer(x, **kwargs)
        x = self.dropout(x)
        x = original_x + x
        x = self.norm(x)
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mix_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.attention = Attention(config)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.H = config.num_attention_heads
        self.d = config.hidden_size // self.H
        if config.position_embedding_type == "rotary":
            self.rotary_position_embeds = RotaryEmbedding(self.d)

    def forward(
        self,
        x,
        attention_mask,
    ):
        B, L, D, H, d = *x.shape, self.H, self.d
        # Project to queries, keys, and values in one shot, then split per head.
        query, key, value = (
            self.mix_proj(x).view(B, L, H, 3 * d).transpose(1, 2).split(d, dim=-1)
        )
        if hasattr(self, "rotary_position_embeds"):
            query, key = self.rotary_position_embeds(query, key)
        output = self.attention(query, key, value, attention_mask)
        output = self.o_proj(output.transpose(1, 2).reshape(B, L, D))
        return output


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        query,
        key,
        value,
        attention_mask,
    ):
        B, H, L, d = key.shape
        attention_score = query.matmul(key.transpose(-2, -1))
        attention_score = attention_score / math.sqrt(d)
        attention_score += attention_mask
        attention_probs = attention_score.softmax(dim=-1)
        attention_probs = self.dropout(attention_probs)
        output = attention_probs.matmul(value)
        return output


class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.linear2(x)
        return x


class RotaryEmbedding(nn.Module):
    # Cos/sin tables are cached at class level and shared across all layers.
    seq_len_cached = 0
    cos_cached = None
    sin_cached = None

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def _forward(self, x):
        # x: (B, H, L, d)
        seq_len = x.shape[2]
        if seq_len > RotaryEmbedding.seq_len_cached:
            # Grow the cached cos/sin tables to cover the longer sequence.
            RotaryEmbedding.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = t.view(-1, 1) @ self.inv_freq.view(1, -1)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            RotaryEmbedding.cos_cached = emb.cos()[None, None, :, :]
            RotaryEmbedding.sin_cached = emb.sin()[None, None, :, :]

        if seq_len == RotaryEmbedding.seq_len_cached:
            cos, sin = RotaryEmbedding.cos_cached, RotaryEmbedding.sin_cached
        else:
            cos, sin = (
                RotaryEmbedding.cos_cached[:, :, :seq_len, :],
                RotaryEmbedding.sin_cached[:, :, :seq_len, :],
            )

        # "Rotate half": split the feature dimension in two and swap the halves
        # with a sign flip, then combine with the cos/sin tables.
        sections = [x.shape[-1] // 2, x.shape[-1] - x.shape[-1] // 2]
        x1, x2 = x.split(sections, dim=-1)
        half_rotated_x = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (half_rotated_x * sin)

    def forward(
        self, query, key,
    ):
        return self._forward(query), self._forward(key)
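

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It uses a bare
# `transformers.PretrainedConfig` with hypothetical field values as a stand-in
# for TSPConfig, setting only the attributes the classes above actually read.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        vocab_size=100,
        embedding_size=32,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        dropout_prob=0.1,
        max_sequence_length=128,
        pad_token_id=0,
        position_embedding_type="absolute",
    )
    model = TSPModelForSequenceClassification(config, num_classes=2)
    model.eval()

    input_ids = torch.randint(0, 100, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    token_type_ids = torch.zeros(2, 16, dtype=torch.long)
    with torch.no_grad():
        logits = model(input_ids, attention_mask, token_type_ids)
    print(logits.shape)  # expected: torch.Size([2, 2])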