| """ PyTorch GPTJiang model.""" |
|
|
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| import torch.nn.functional as F |
|
|
| from transformers.activations import ACT2FN |
| from transformers.file_utils import ( |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| replace_return_docstrings, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
| from .configuration_gpt_jiang import GPTJiangConfig |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = "GPTJiangConfig" |
| GPT_JIANG_PRETRAINED_MODEL_ARCHIVE_LIST = [] |
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Rescale each vector to unit root-mean-square along the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
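# Editorial sketch, not part of the model: under the definitions above, RMSNorm computes
# w * x / sqrt(mean(x**2) + eps). Kept as a comment so importing this module has no side
# effects; uncomment to run as a quick check.
#
#     norm = RMSNorm(dim=8)
#     x = torch.randn(2, 4, 8)
#     rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + norm.eps)
#     expected = (x.float() / rms).type_as(x) * norm.weight
#     torch.testing.assert_close(norm(x), expected)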
class GPTJiangPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTJiangConfig
    base_model_prefix = "gpt_jiang"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTJiangLayer"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, GatedLinear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # The gate projection bias is initialized to 1 (so gates start roughly "open").
                module.bias.data.fill_(1.0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GPTJiangModel):
            module.gradient_checkpointing = value


class GPTJiangAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        # Only the first `rotary_pct` fraction of each head dimension is meant to receive rotary embeddings.
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            config.max_position_embeddings,
            base=config.rotary_emb_base,
        )
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.causal_mask_cached = None

    def causal_mask(self, x, seq_len):
        # Lazily build (and grow) a cached lower-triangular mask of shape [1, 1, cache_size, cache_size].
        if self.causal_mask_cached is None or seq_len > self.causal_mask_cached.shape[2]:
            cache_size = max(self.max_position_embeddings, seq_len)
            self.causal_mask_cached = torch.ones(
                cache_size,
                cache_size,
                dtype=torch.bool,
            ).tril().view(1, 1, cache_size, cache_size)
        return self.causal_mask_cached[:, :, :seq_len, :seq_len].to(x.device)

    def forward(
        self,
        hidden_states,
        attention_mask,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        has_layer_past = layer_past is not None

        # Fused QKV projection: [batch, seq, hidden] -> [batch, seq, 3 * hidden]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq, 3 * hidden] -> [batch, seq, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # Split into query/key/value, each [batch, num_heads, seq, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # Apply rotary position embeddings. Note that query/key are rotated over their full head
        # dimension, so this implementation effectively assumes `rotary_pct == 1.0`.
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.rotary_emb(value, seq_len=seq_len)

        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=offset)

        # Append to the key/value cache when decoding incrementally.
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        query = query.type_as(hidden_states)
        key = key.type_as(hidden_states)
        value = value.type_as(hidden_states)

        if output_attentions:
            # The fused scaled_dot_product_attention kernel does not return attention weights,
            # so fall back to the explicit implementation when they are requested.
            attn_output, attn_weights = self._attn(
                query, key, value,
                attention_mask=attention_mask,
                head_mask=head_mask,
            )
        else:
            if layer_past is not None and attention_mask is None:
                # During cached decoding without an explicit mask, attend to every cached position.
                batch_size = query.size(0)
                attention_mask = torch.ones(
                    batch_size, seq_len, dtype=torch.bool, device=query.device
                )[:, None, None, :]

            if attention_mask is not None:
                # Combine the 1/0 padding mask with itself to get a [batch, 1, key_len, key_len] mask,
                # then intersect it with the causal mask for the current query block.
                attn_mask = attention_mask.transpose(2, 3) * attention_mask
                query_length = query.size(-2)
                key_length = key.size(-2)
                if query_length > 1:
                    causal_mask = self.causal_mask(query, seq_len)
                    causal_mask = causal_mask[:, :, -query_length:, :]
                    attn_mask = (attn_mask[:, :, -query_length:, :] * causal_mask).to(torch.bool)
                else:
                    attn_mask = attn_mask[:, :, -query_length:, :].to(torch.bool)

                attn_output = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask,
                    is_causal=False,
                )
            else:
                attn_output = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=None,
                    is_causal=True,
                )
            attn_weights = None

        # Reshape outputs: [batch, num_heads, seq, head_size] -> [batch, seq, hidden]
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
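    # Editorial sketch, not part of the model: exercising the attention block above in isolation
    # with a toy configuration. `SimpleNamespace` stands in for GPTJiangConfig purely for
    # illustration; the real config class defines the same fields. Kept as a comment so importing
    # this module has no side effects; uncomment to run.
    #
    #     from types import SimpleNamespace
    #     cfg = SimpleNamespace(
    #         max_position_embeddings=32, num_attention_heads=4, hidden_size=64,
    #         rotary_pct=1.0, rotary_emb_base=10000,
    #     )
    #     attn = GPTJiangAttention(cfg)
    #     x = torch.randn(2, 8, 64)                    # [batch, seq, hidden]
    #     out, present = attn(x, attention_mask=None, use_cache=True)
    #     assert out.shape == (2, 8, 64)
    #     assert present[0].shape == (2, 4, 8, 16)     # cached keys: [batch, heads, seq, head_size]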
    @classmethod
    def _calculate_attn_output_loss(cls, attn_output):
        # Auxiliary diversity measure: mean pairwise cosine similarity between the (flattened)
        # per-head attention outputs. Not called elsewhere in this file.
        bs, num_attention_heads, seq_len, attn_head_size = attn_output.size()
        attn_output_out = attn_output.view(bs, num_attention_heads, -1)
        attn_output_out_norm = attn_output_out / torch.max(
            attn_output_out.norm(dim=2, keepdim=True),
            1e-8 * torch.ones_like(attn_output_out),
        )
        sim = torch.bmm(attn_output_out_norm, attn_output_out_norm.permute(0, 2, 1))
        attn_output_loss = sim.sum() / sim.numel()
        return attn_output_loss

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # [batch, seq, hidden] -> [batch, seq, num_heads, head_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        # -> [batch, num_heads, seq, head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # [batch, num_heads, seq, head_size] -> [batch, seq, num_heads, head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [batch, seq, hidden]
        tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
        return tensor

    def create_upper_triangular_matrix(self, q, k):
        # Despite the name, this returns the *lower*-triangular causal mask: True where a query
        # position may attend to a key position (row index >= column index).
        size = max(q, k)
        row_indices = torch.arange(size).view(-1, 1).expand(size, size)
        col_indices = torch.arange(size).view(1, -1).expand(size, size)
        upper_triangular_matrix = torch.where(row_indices < col_indices, 0, 1)
        return upper_triangular_matrix[-q:, -k:].to(torch.bool)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Explicit attention implementation, used only when attention weights are requested.
        # query, key, value: [batch, num_heads, seq, head_size]
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        causal_mask = self.create_upper_triangular_matrix(
            query_length, key_length
        ).view(1, 1, query_length, key_length).to(query.device)

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        norm_factor = self.head_size ** 0.5
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(torch.tensor(1.0, dtype=query.dtype, device=query.device) / norm_factor),
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        # Mask out non-causal positions with the most negative representable value.
        mask_value = torch.finfo(attn_scores.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:
            # `attention_mask` holds 1 for tokens to keep and 0 for padding (it is not an additive
            # mask), so mask out the zero entries explicitly instead of adding the mask.
            attn_scores = attn_scores.masked_fill(
                ~attention_mask.to(torch.bool), torch.finfo(attn_scores.dtype).min
            )

        attn_weights = nn.functional.softmax(attn_scores.float(), dim=-1).type_as(value)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights
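# Editorial sketch, not part of the model: for a causal block with no padding, the explicit
# `_attn` path above should agree with the `F.scaled_dot_product_attention` call used in
# `forward` (same 1/sqrt(head_size) scaling and lower-triangular masking). The toy config is
# illustrative only; uncomment to run.
#
#     from types import SimpleNamespace
#     attn = GPTJiangAttention(SimpleNamespace(
#         max_position_embeddings=32, num_attention_heads=4, hidden_size=64,
#         rotary_pct=1.0, rotary_emb_base=10000,
#     ))
#     q = k = v = torch.randn(2, 4, 8, 16)
#     manual, _ = attn._attn(q, k, v)
#     fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#     torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)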
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute cos/sin tables for the maximum expected sequence length.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Concatenated-halves layout (rather than interleaving), consistent with `rotate_half` below.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        # `x` is only used for its device; grow the cached tables if `seq_len` exceeds them.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return (
            self.cos_cached[:, :, :seq_len, :].to(x.device),
            self.sin_cached[:, :, :seq_len, :].to(x.device),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    # Select the cos/sin entries for the positions covered by q/k (shifted by the cache offset).
    cos = cos[..., offset : q.shape[-2] + offset, :]
    sin = sin[..., offset : q.shape[-2] + offset, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
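# Editorial sketch, not part of the model: rotary embeddings rotate pairs of channels by a
# position-dependent angle, so they change directions but not norms, and query/key dot products
# then depend only on relative position. The quick norm-preservation check below uses the
# helpers above; illustration only, uncomment to run.
#
#     rope = RotaryEmbedding(dim=16, max_position_embeddings=32)
#     q = torch.randn(1, 2, 8, 16)
#     k = torch.randn(1, 2, 8, 16)
#     cos, sin = rope(q, seq_len=8)
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
#     torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1), rtol=1e-4, atol=1e-4)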
class GatedLinear(nn.Linear):
    """A plain `nn.Linear` subclass used to tag the gate projection so that `_init_weights`
    can give it its own initialization (bias filled with 1)."""


class GPTJiangMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
        self.gated = config.gated
        if config.gated:
            self.dense_h_to_4h_gate = GatedLinear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        if self.gated:
            # GLU-style MLP: act(W_up x) scaled elementwise by a linear gate W_gate x.
            hidden_states = self.act(self.dense_h_to_4h(hidden_states)) * self.dense_h_to_4h_gate(hidden_states)
        else:
            # Standard MLP: act(W_up x)
            hidden_states = self.act(self.dense_h_to_4h(hidden_states))
        hidden_states = self.dense_4h_to_h(hidden_states)
        return hidden_states
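# Editorial sketch, not part of the model: with `gated=True` the MLP computes
# down( act(up(x)) * gate(x) ), i.e. a GLU-variant feed-forward block. The toy config below is
# illustrative only (`SimpleNamespace` stands in for GPTJiangConfig); uncomment to run.
#
#     from types import SimpleNamespace
#     mlp = GPTJiangMLP(SimpleNamespace(hidden_size=8, intermediate_size=32, gated=True,
#                                       mlp_bias=False, hidden_act="silu"))
#     x = torch.randn(2, 5, 8)
#     expected = mlp.dense_4h_to_h(mlp.act(mlp.dense_h_to_4h(x)) * mlp.dense_h_to_4h_gate(x))
#     torch.testing.assert_close(mlp(x), expected)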
class GPTJiangLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = GPTJiangAttention(config)
        self.mlp = GPTJiangMLP(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        layer_past=None,
        output_attentions=False,
    ):
        attention_layer_outputs = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attention_layer_outputs[0]  # attention outputs: attn_output, present, (attn_weights)
        outputs = attention_layer_outputs[1:]

        if self.use_parallel_residual:
            # Parallel residual:
            #   x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # Sequential residual:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
            hidden_states = mlp_output + attn_output

        if use_cache:
            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
        else:
            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)

        return outputs
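# Editorial sketch, not part of the model: incremental decoding through a single layer. The
# present key/value pair returned with `use_cache=True` can be fed back as `layer_past` for the
# next token. Toy config only (`SimpleNamespace` stands in for GPTJiangConfig); uncomment to run.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         hidden_size=64, num_attention_heads=4, intermediate_size=256, hidden_act="gelu",
#         max_position_embeddings=32, rotary_pct=1.0, rotary_emb_base=10000,
#         layer_norm_eps=1e-5, use_parallel_residual=True, gated=False, mlp_bias=True,
#     )
#     layer = GPTJiangLayer(cfg)
#     x = torch.randn(1, 6, 64)
#     out, present = layer(x, use_cache=True)
#     next_out, next_present = layer(out[:, -1:, :], layer_past=present, use_cache=True)
#     assert next_present[0].shape == (1, 4, 7, 16)   # the cache grew by one position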
| GPT_JIANG_START_DOCSTRING = r""" |
| This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use |
| it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and |
| behavior. |
| |
| Parameters: |
| config ([`~GPTJiangConfig`]): Model configuration class with all the parameters of the model. |
| Initializing with a config file does not load the weights associated with the model, only the |
| configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
| GPT_JIANG_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `({0})`): |
| Indices of input sequence tokens in the vocabulary. |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): |
| Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: |
| |
| - 1 indicates the head is **not masked**, |
| - 0 indicates the head is **masked**. |
| |
| inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
| is useful if you want more control over how to convert *input_ids* indices into associated vectors than the |
| model's internal embedding lookup matrix. |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. |
| """ |
|
|
|
|
@add_start_docstrings(
    "The bare GPTJiang Model transformer outputting raw hidden-states without any specific head on top.",
    GPT_JIANG_START_DOCSTRING,
)
class GPTJiangModel(GPTJiangPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([GPTJiangLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_in

    def set_input_embeddings(self, value):
        self.embed_in = value

    @add_start_docstrings_to_model_forward(GPT_JIANG_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_key_values = tuple([None] * self.config.num_hidden_layers)

        # Prepare the attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # Make the 1/0 padding mask broadcastable to [batch, num_heads, query_len, key_len].
            # It is kept multiplicative (1 = attend, 0 = ignore); the attention layer combines it
            # with the causal mask itself, so no additive -inf conversion is done here.
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)

        # Prepare the head mask if needed: 1.0 in head_mask indicates we keep the head.
        # See `PreTrainedModel.get_head_mask` for the exact broadcasting.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.embed_in(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # `inputs` are (hidden_states, attention_mask, head_mask); append the
                        # remaining arguments expected positionally by GPTJiangLayer.forward.
                        return module(*inputs, use_cache, None, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                )
            else:
                outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.final_layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


@add_start_docstrings(
    """GPTJiang Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_JIANG_START_DOCSTRING
)
class GPTJiangForCausalLM(GPTJiangPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.gpt_kdf = GPTJiangModel(config)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.embed_out

    def set_output_embeddings(self, new_embeddings):
        self.embed_out = new_embeddings

    @add_start_docstrings_to_model_forward(GPT_JIANG_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer
        >>> import torch

        >>> # Replace the path below with an actual GPTJiang checkpoint.
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/gpt-jiang-checkpoint")
        >>> model = GPTJiangForCausalLM.from_pretrained("path/to/gpt-jiang-checkpoint")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.gpt_kdf(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.embed_out(hidden_states)

        lm_loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape

        # If no attention mask is provided, attend to every position.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # Once the cache is populated, only the last token needs to be forwarded.
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder the cached key/value tensors along the batch dimension to follow beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
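# Editorial smoke-test sketch, not part of the model: builds a tiny randomly initialized GPTJiang
# and runs one forward pass with labels. It assumes `GPTJiangConfig` accepts the listed fields as
# keyword arguments (they are exactly the attributes read by the modules above); adjust to the
# real config signature if it differs. Kept as a comment so importing this module has no side
# effects; uncomment to run from a context where the relative import above resolves.
#
#     config = GPTJiangConfig(
#         vocab_size=128, hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
#         intermediate_size=256, hidden_act="gelu", max_position_embeddings=32,
#         rotary_pct=1.0, rotary_emb_base=10000, layer_norm_eps=1e-5, initializer_range=0.02,
#         use_parallel_residual=True, gated=False, mlp_bias=True, use_cache=True,
#     )
#     model = GPTJiangForCausalLM(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 10))
#     out = model(input_ids, labels=input_ids)
#     print("logits:", out.logits.shape, "loss:", float(out.loss))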
|