| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch XGLM model.""" |
|
|
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask |
| from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions ,SequenceClassifierOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
| from .configuration_xglm import XGLMConfig |
|
|
|
|
# Module logger obtained from the `transformers.utils.logging` wrapper imported above.
logger = logging.get_logger(__name__)

# Names interpolated into the generated code-sample docstrings (see `add_code_sample_docstrings` below).
_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
_CONFIG_FOR_DOC = "XGLMConfig"

# Checkpoint identifiers listed by this module.
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/xglm-564M"]
|
|
# Shared class-level docstring prepended to the XGLM model classes via `add_start_docstrings`.
XGLM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`XGLMConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
# Argument documentation attached to the `forward` methods via `add_start_docstrings_to_model_forward`.
XGLM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and, when cross-attention is used, 2
            additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Position ``p`` is looked up at row ``p + offset`` of the non-persistent ``weights`` buffer. The table is
    regenerated (larger) on the fly when a forward pass needs more rows than it currently has.
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        # Positions are shifted by this amount before the table lookup (see `forward`).
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """(Re)build the ``weights`` buffer with ``num_embeddings`` rows.

        If a table already exists, the new one is cast to its dtype and device, so growing the table keeps any
        earlier ``.to(...)`` applied to the module.
        """
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
        # Non-persistent: the table is recomputed deterministically, so it is excluded from the state dict.
        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".

        Returns a ``(num_embeddings, embedding_dim)`` tensor in the default dtype: the first half of each row holds
        sines, the second half cosines, an odd ``embedding_dim`` gets a trailing zero column, and the
        ``padding_idx`` row (if any) is zeroed. ``embedding_dim`` of 2 or 3 is unsupported (``half_dim - 1`` is 0).
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor = None, past_key_values_length: int = 0):
        """Look up the embeddings for ``position_ids``.

        Args:
            position_ids: 2-D tensor ``(batch, seq_len)`` of integer positions. It is not modified.
            past_key_values_length: number of cached positions preceding this chunk; only used to decide whether
                the table has to grow.

        Returns:
            Tensor of shape ``(batch, seq_len, embedding_dim)``, detached from autograd.
        """
        bsz, seq_len = position_ids.size()
        # Out-of-place shift: an in-place `+=` would rewrite the caller's tensor (shifting it again on every
        # reuse) and keep its strides, which breaks the flat lookup below for non-contiguous slices.
        shifted_ids = position_ids + self.offset

        # Growth is decided from shapes, not from the values in `position_ids`; this assumes positions stay below
        # `past_key_values_length + seq_len` — NOTE(review): confirm for caller-supplied position ids.
        max_pos = self.offset + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, shifted_ids.reshape(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
|
|
|
|
class XGLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Acts as self-attention by default and as cross-attention when ``key_value_states`` is passed to ``forward``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Queries are pre-multiplied by 1/sqrt(head_dim) instead of scaling the score matrix.
        self.scaling = self.head_dim**-0.5
        # When True, forward() returns the freshly computed (key, value) pair for caching.
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed_dim)`` into a contiguous ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side input of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: if given, keys/values are projected from it (cross-attention) instead of from
                `hidden_states`.
            past_key_value: cached `(key, value)` of shape `(bsz, num_heads, len, head_dim)`. Self-attention appends
                the new keys/values to it; cross-attention reuses it unchanged.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask: per-head multiplier of shape `(num_heads,)` applied to the attention probabilities.
            output_attentions: whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights, past_key_value)`: `attn_output` is `(bsz, tgt_len, embed_dim)`;
            `attn_weights` is `(bsz, num_heads, tgt_len, src_len)` (after the head mask, before dropout) or `None`;
            `past_key_value` is the `(key, value)` pair when `is_decoder` is set, otherwise the input value.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scaling

        if is_cross_attention and past_key_value is not None:
            # cross-attention: reuse cached encoder keys/values
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross-attention: project the encoder states
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # self-attention with cache: append the new positions along the sequence axis
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # plain self-attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Hand back the 4-D (key, value) tensors so the caller can cache them.
            past_key_value = (key_states, value_states)

        # Fold heads into the batch axis for the batched matmuls below.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            # Clamp so that adding a large negative mask cannot push scores below the dtype's finite minimum.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # The double reshape makes the returned tensor and the one used below share one autograd node
            # (presumably so gradients flow through the returned weights) — kept as-is.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Unfold heads and merge them back into the channel axis.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
class XGLMDecoderLayer(nn.Module):
    """One pre-LayerNorm XGLM decoder block: self-attention, optional cross-attention, then a feed-forward net.

    Each sub-block normalises its input first and adds its (dropped-out) output back onto the residual stream.
    """

    def __init__(self, config: XGLMConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = XGLMAttention(
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        # The cross-attention sub-block only exists when the config enables it.
        if config.add_cross_attention:
            self.encoder_attn = XGLMAttention(
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for self-attention heads in this layer of size
                `(config.attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in this layer of
                size `(config.attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached key/value states; the first two entries belong to
                self-attention and the last two (if present) to cross-attention.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                Whether to append this layer's key/value states to the returned tuple.

        Returns:
            `(hidden_states,)`, followed by `(self_attn_weights, cross_attn_weights)` when `output_attentions` is
            set, followed by `present_key_value` when `use_cache` is set.

        Raises:
            ValueError: if `encoder_hidden_states` is given but the layer was built without cross-attention.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention: the first two cached tensors are the self-attention (key, value).
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    "`encoder_hidden_states` were passed, but this layer has no cross-attention; "
                    "set `config.add_cross_attention=True`."
                )
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # The cross-attention (key, value) are the last two cached tensors.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # Append the cross-attention pair after the self-attention pair.
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
class XGLMPreTrainedModel(PreTrainedModel):
    """Common base for the XGLM models: config class, checkpoint prefix and weight initialisation."""

    config_class = XGLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["XGLMDecoderLayer"]

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, config.init_std); zero Linear biases and the padding row."""
        init_std = self.config.init_std
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear):
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
|
|
|
|
@add_start_docstrings(
    "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
    XGLM_START_DOCSTRING,
)
class XGLMModel(XGLMPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`XGLMDecoderLayer`]

    Args:
        config: XGLMConfig
        embed_tokens (nn.Embedding): optional input token embedding to use instead of creating a new one
    """

    def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        # Token embeddings are multiplied by sqrt(d_model) when `config.scale_embedding` is set.
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

        self.embed_positions = XGLMSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id,
        )
        self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # `PreTrainedModel` hook; presumably runs `_init_weights` and final setup — see the base class.
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Number of already-cached positions (sequence axis of the first layer's cached key tensor).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if position_ids is None:
            # Default positions continue after the cache; shape (1, seq_len) broadcasts over the batch.
            position_ids = torch.arange(
                past_key_values_length,
                input_shape[-1] + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
        hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache ="
                    " False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    # Report the size of the mask actually being checked (not always `head_mask`, which may be None).
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: during training, skip the whole layer with probability `self.layerdrop`.
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional call; past_key_value is None because use_cache was forced off above.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            # Layer output layout: (hidden, [self_attn, cross_attn], [present_key_value]).
            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    XGLM_START_DOCSTRING,
)
class XGLMForCausalLM(XGLMPreTrainedModel):
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = XGLMModel(config)
        # Bias-free projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # `PreTrainedModel` hook; presumably runs weight init / tying — see the base class.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            Labels are shifted inside the model; the last position is scored against `config.pad_token_id`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Shift labels left by one so position i is scored against token i + 1. The final position has no
            # successor and gets `pad_token_id`. NOTE(review): CrossEntropyLoss only ignores -100, so that padded
            # target still contributes to the loss — confirm this is intended.
            pad_id = self.config.pad_token_id
            shift_labels = labels.new_zeros(labels.shape)
            shift_labels[:, :-1] = labels[:, 1:].clone()
            # Without a configured pad token, use the ignore index instead of assigning None (which raises).
            shift_labels[:, -1] = pad_id if pad_id is not None else -100

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        """Assemble the model inputs for one generation step.

        With a cache, only the tokens not yet covered by `past_key_values` are fed. If no `attention_mask` is
        given, an all-ones mask spanning the cached positions plus the new tokens is created.
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            if input_ids.shape[1] > past_length:
                # Drop the prefix that is already in the cache.
                remove_prefix_length = past_length
            else:
                # input_ids is already short (at most the cache length): keep only its final token.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation (padded slots get a dummy position of 1)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            # NOTE(review): this also discards a caller-supplied `position_ids`; kept for compatibility.
            position_ids = None

        if attention_mask is None:
            # The model expands the 2-D mask against `past_length + new tokens` keys, so the default mask must
            # cover the cached positions too; a mask over the new tokens alone would not match the
            # `(bsz, 1, tgt_len, src_len)` shape the attention layers check for.
            attention_mask = input_ids.new_ones((input_ids.shape[0], past_length + input_ids.shape[1]))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select the cached states of the surviving beams (`beam_idx`) along the batch axis of every layer."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|
| |
class XGLMForSequenceClassification(XGLMPreTrainedModel):
    """XGLM transformer with a linear sequence-classification head.

    ``self.score`` projects every hidden state to ``config.num_labels`` logits; the logits of the
    last non-padding token of each sequence are used as the sequence-level prediction.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = XGLMModel(config)
        # XGLMConfig names the hidden size `d_model`; it has no GPT-2 style `n_embd` attribute.
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
        # Kept for attribute compatibility with the GPT-2 style head this class was adapted from.
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing.
        self.post_init()

    @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        token_type_ids (`torch.LongTensor`, *optional*):
            Accepted for signature compatibility. It is forwarded to the underlying `XGLMModel` only when
            provided (the upstream `XGLMModel` does not take this argument).
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Raises:
            ValueError: if the batch has more than one sequence and `config.pad_token_id` is not set.
        """
        # Local import: the module header only brings `CrossEntropyLoss` into scope.
        from torch.nn import BCEWithLogitsLoss, MSELoss

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Only forward `token_type_ids` when the caller actually supplied one, so the default call
        # path never passes an argument the base model may not accept.
        extra_model_kwargs = {} if token_type_ids is None else {"token_type_ids": token_type_ids}

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **extra_model_kwargs,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None:
            last_token_idx = -1
        elif input_ids is not None:
            # Index of the right-most non-padding token in each row. Multiplying positions by the
            # non-pad mask and taking argmax works for left padding, right padding, and no padding.
            non_pad_mask = (input_ids != pad_token_id).int().to(logits.device)
            token_positions = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_token_idx = (token_positions * non_pad_mask).argmax(-1)
        else:
            last_token_idx = -1
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_token_idx]

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )