| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch T5 model. """ |
|
|
| import copy |
| import math |
| import os |
| import warnings |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import logging |
| from transformers.modeling_utils import ( |
| PreTrainedModel, |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
| from transformers.file_utils import ( |
| DUMMY_INPUTS, |
| DUMMY_MASK, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| replace_return_docstrings, |
| ) |
|
|
| from .glen_t5_config import T5Config |
| from .glen_t5_outputs import BaseModelOutputWithPast, Seq2SeqModelOutput |
|
|
# Module logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Config class name; by HF convention this is handed to docstring helpers
# (not referenced in this section of the file).
_CONFIG_FOR_DOC = "T5Config"

# Public T5 checkpoint identifiers, ordered from smallest to largest.
T5_PRETRAINED_MODEL_ARCHIVE_LIST = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
|
|
|
|
| |
| |
| |
| |
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load TensorFlow checkpoint variables into a PyTorch T5 model.

    Every variable in the checkpoint is read into memory. Optimizer state
    (Adam moments, ``global_step``, ``_slot_`` variables) is discarded. Each
    remaining variable name is walked scope by scope (``block_000`` becomes
    ``block[0]``) to locate the matching PyTorch parameter, whose data is
    then replaced. Variables whose scopes cannot be resolved are skipped and
    listed at the end as not copied.

    Args:
        model: PyTorch model whose parameters receive the weights.
        config: Model configuration (unused here; kept for the loader signature).
        tf_checkpoint_path: Path to the TensorFlow checkpoint.

    Returns:
        The same ``model`` instance, updated in place.

    Raises:
        ImportError: If TensorFlow or NumPy is not installed.
        AssertionError: If a checkpoint array's shape does not match the
            PyTorch parameter it maps to.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Read every checkpoint variable up front.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    # Optimizer bookkeeping scopes that have no PyTorch counterpart.
    optimizer_scopes = {
        "adam_v",
        "adam_m",
        "AdamWeightDecayOptimizer",
        "AdamWeightDecayOptimizer_1",
        "global_step",
    }
    for txt_name in names:
        name = txt_name.split("/")
        if any(n in optimizer_scopes for n in name) or "_slot_" in name[-1]:
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]
        for m_name in name:
            # "block_000" -> ["block", "000", ""]: attribute name plus index.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # Abandon this variable entirely. A plain `continue` would
                    # only move on to the next scope while `pointer` still
                    # refers to the parent module, so later scopes could crash
                    # or overwrite an unrelated parameter.
                    logger.info("Skipping {}".format("/".join(name)))
                    break
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        else:
            # Reached only when every scope resolved (no `break` above).
            if scope_names[0] not in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            if scope_names[0] != "embedding":
                # Non-embedding arrays are transposed to PyTorch's layout.
                logger.info(
                    "Transposing numpy weight of shape {} for {}".format(array.shape, name)
                )
                array = np.transpose(array)
            try:
                assert (
                    pointer.shape == array.shape
                ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array.astype(np.float32))
            tf_weights.pop(txt_name, None)

    logger.info(
        "Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys()))
    )
    return model
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class T5LayerNorm(nn.Module):
    """T5-style layer norm: RMS scaling of the last dimension, no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """Construct a layernorm module in the T5 style.

        No bias and no subtraction of mean.

        Args:
            hidden_size: Size of the last dimension being normalized.
            eps: Added to the mean square before the square root for stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Divide ``x`` by the RMS of its last dimension, then scale by ``weight``.

        The mean square is always accumulated in float32. When ``weight`` is
        half precision (float16 or bfloat16), the normalized values are cast
        back to that dtype. Otherwise the float32 intermediate would promote
        the output to float32 and break downstream half-precision layers.
        """
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x / torch.sqrt(variance + self.variance_epsilon)

        if self.weight.dtype in (torch.float16, torch.bfloat16):
            x = x.to(self.weight.dtype)
        return self.weight * x
|
|
|
|
class T5DenseReluDense(nn.Module):
    """Bias-free feed-forward MLP: Linear(d_model->d_ff), ReLU, dropout, Linear(d_ff->d_model)."""

    def __init__(self, config):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Project up, apply ReLU and dropout, then project back down."""
        activated = F.relu(self.wi(hidden_states))
        return self.wo(self.dropout(activated))
|
|
|
|
class T5LayerFF(nn.Module):
    """Pre-norm feed-forward sublayer wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.DenseReluDense = T5DenseReluDense(config)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Return ``hidden_states + dropout(DenseReluDense(layer_norm(hidden_states)))``."""
        ff_out = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(ff_out)
|
|
|
|
class T5Attention(nn.Module):
    """Multi-head attention with optional T5 relative position bias.

    Serves as self-attention (``kv is None``) or as attention over another
    sequence (``kv`` given). Only an instance built with
    ``has_relative_attention_bias=True`` owns a bias embedding; any other
    instance must receive ``position_bias`` from its caller.
    """

    def __init__(
        self,
        config: "T5Config",
        has_relative_attention_bias=False,
        is_bidirectional=False,
    ):
        super().__init__()
        self.is_bidirectional = is_bidirectional
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.d_kv

        # Bias-free projections. Scores are not scaled by 1/sqrt(d_kv) in
        # forward; the q initialization in `_init_weights` accounts for it.
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # One learned scalar per (bucket, head).
            self.relative_attention_bias = nn.Embedding(
                self.relative_attention_num_buckets, self.n_heads
            )
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v/o projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.d_kv, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.d_kv * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions. All relative positions >=max_distance
        map to the same bucket. All relative positions <=-max_distance map to the
        same bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on.

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a long Tensor with the same shape as relative_position, containing
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Half of the buckets encode the sign of the offset.
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # Half of the remaining buckets are exact offsets.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half are logarithmically wider bins up to max_distance.
        # log(0) yields -inf here, but those entries are masked by `is_small`.
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, device=None):
        """Compute binned relative position bias.

        Args:
            qlen: Number of query positions.
            klen: Number of key positions.
            device: Device on which to build the position indices. Defaults to
                the device of the bias embedding, which avoids a host-to-device
                copy on every call.

        Returns:
            Tensor of shape ``(1, n_heads, qlen, klen)``.
        """
        bias_device = self.relative_attention_bias.weight.device
        if device is None:
            device = bias_device
        context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.is_bidirectional,
            num_buckets=self.relative_attention_num_buckets,
        )
        # No-op unless the caller asked for a different device.
        rp_bucket = rp_bucket.to(bias_device)
        values = self.relative_attention_bias(rp_bucket)  # (qlen, klen, n_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # (1, n_heads, qlen, klen)
        return values

    def forward(
        self,
        input,
        mask=None,
        kv=None,
        position_bias=None,
        past_key_value=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Returns:
            Tuple ``(context, present_key_value_state)``, followed by the attention
            weights when ``output_attentions`` is true, followed by the position
            bias (with ``mask`` already added, if this layer computed it) when this
            layer has relative attention bias weights. ``present_key_value_state``
            is ``(k, v)`` for a decoder with ``use_cache=True``, else ``None``.
        """
        bs, qlen, dim = input.size()

        if past_key_value is not None:
            assert self.is_decoder is True, "Encoder cannot cache past key value states"
            assert (
                len(past_key_value) == 2
            ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value)
            )
            real_qlen = (
                qlen + past_key_value[0].shape[2]
                if query_length is None
                else query_length
            )
        else:
            real_qlen = qlen

        if kv is None:
            klen = real_qlen
        else:
            klen = kv.size(1)

        def shape(x):
            """(bs, seq, inner_dim) -> (bs, n_heads, seq, d_kv)."""
            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)

        def unshape(x):
            """(bs, n_heads, seq, d_kv) -> (bs, seq, inner_dim)."""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)

        q = shape(self.q(input))

        if kv is None:
            k = shape(self.k(input))
            v = shape(self.v(input))
        elif past_key_value is None:
            k = v = kv
            k = shape(self.k(k))
            v = shape(self.v(v))

        if past_key_value is not None:
            if kv is None:
                # Self-attention: append the new positions to the cache.
                k_, v_ = past_key_value
                k = torch.cat([k_, k], dim=2)
                v = torch.cat([v_, v], dim=2)
            else:
                # Cross-attention: the cached encoder projections are reused as-is.
                k, v = past_key_value

        if self.is_decoder and use_cache is True:
            present_key_value_state = ((k, v),)
        else:
            present_key_value_state = (None,)

        # Equivalent of einsum("bnqd,bnkd->bnqk", q, k).
        scores = torch.matmul(q, k.transpose(3, 2))

        if position_bias is None:
            if not self.has_relative_attention_bias:
                raise ValueError(
                    "No position_bias provided and no weights to compute position_bias"
                )
            position_bias = self.compute_bias(real_qlen, klen)

            # With cached keys/values only the newest query rows are needed.
            if past_key_value is not None:
                position_bias = position_bias[:, :, -qlen:, :]

            # The mask is folded into the bias once, so a bias passed on to
            # later layers already carries it.
            if mask is not None:
                position_bias = position_bias + mask

        scores += position_bias
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)
        context = unshape(context)

        context = self.o(context)

        outputs = (context,) + present_key_value_state

        if output_attentions:
            outputs = outputs + (weights,)
        if self.has_relative_attention_bias:
            outputs = outputs + (position_bias,)
        return outputs
|
|
|
|
class T5LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sublayer wrapped in a residual connection."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        # Relative-position buckets are bidirectional for the encoder and
        # one-directional for the decoder.
        self.SelfAttention = T5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            is_bidirectional=not config.is_decoder,
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Return ``(hidden_states + dropout(attention), *extras)``.

        ``extras`` is everything the attention module returned after its
        context output (present key/value state, optional weights, optional
        position bias).
        """
        normed = self.layer_norm(hidden_states)
        context, *extras = self.SelfAttention(
            normed,
            mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        residual_out = hidden_states + self.dropout(context)
        return (residual_out,) + tuple(extras)
|
|
|
|
class T5LayerCrossAttention(nn.Module):
    """Pre-norm attention over an external sequence (``kv``) with a residual connection."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        # Buckets over the attended-to sequence are always bidirectional.
        self.EncDecAttention = T5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            is_bidirectional=True,
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        kv,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        """Return ``(hidden_states + dropout(attention over kv), *extras)``.

        ``extras`` is everything the attention module returned after its
        context output.
        """
        normed = self.layer_norm(hidden_states)
        context, *extras = self.EncDecAttention(
            normed,
            mask=attention_mask,
            kv=kv,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        residual_out = hidden_states + self.dropout(context)
        return (residual_out,) + tuple(extras)
|
|
|
|
class T5Block(nn.Module):
    """One T5 layer: self-attention, cross-attention (decoder only), feed-forward.

    ``self.layer`` holds the sublayers in that order, so ``self.layer[-1]`` is
    always the feed-forward sublayer.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias
            )
        )
        if self.is_decoder:
            self.layer.append(
                T5LayerCrossAttention(
                    config, has_relative_attention_bias=has_relative_attention_bias
                )
            )

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Apply self-attention, optional cross-attention and the feed-forward sublayer.

        Args:
            past_key_value: Decoder-only cache for this block. It holds 2 tensors
                (self-attention key/value) or, when ``encoder_hidden_states`` is
                given, 4 (followed by the cross-attention key/value).

        Returns:
            Tuple ``(hidden_states, present_key_value_state, *attention_outputs)``.
            ``attention_outputs`` holds, where produced, the self-attention
            weights and position bias followed by the cross-attention weights
            and position bias.
        """
        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
            # The message is only formatted when the assertion fails.
            assert len(past_key_value) == expected_num_past_key_values, (
                "There should be {} past states. 2 (key / value) for self attention.{} "
                "Got {} past key / value states".format(
                    expected_num_past_key_values,
                    " 2 (key / value) for cross attention."
                    if expected_num_past_key_values == 4
                    else "",
                    len(past_key_value),
                )
            )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Self-attention weights / position bias, when returned.
        attention_outputs = self_attention_outputs[2:]

        if self.is_decoder and encoder_hidden_states is not None:
            # With a cache, the cross-attention query length cannot be derived
            # from its own inputs; take it from the full self-attention keys.
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                kv=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                head_mask=head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]
            # Concatenate self- and cross-attention key/value states.
            if present_key_value_state is not None:
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs[1]
                )

            # Cross-attention weights / position bias, when returned.
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Feed-forward sublayer.
        hidden_states = self.layer[-1](hidden_states)
        outputs = (hidden_states,)

        outputs = outputs + (present_key_value_state,) + attention_outputs
        return outputs
|
|
|
|
class T5PreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        """Small example batch built from transformers' DUMMY_INPUTS / DUMMY_MASK."""
        input_ids = torch.tensor(DUMMY_INPUTS)
        return {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": torch.tensor(DUMMY_MASK),
        }

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in the Mesh-TF T5 scheme."""
        # Deferred import; presumably avoids a circular import with the module
        # that defines the GLEN generation model.
        from tevatron.modeling import (
            T5ForConditionalGeneration_GLEN as T5ForConditionalGeneration,
        )

        factor = self.config.initializer_factor

        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
            return

        if isinstance(module, (T5Model, T5ForConditionalGeneration)):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            return

        if isinstance(module, T5DenseReluDense):
            # Each projection gets std = factor / sqrt(its input width).
            for linear, fan_in in (
                (module.wi, self.config.d_model),
                (module.wo, self.config.d_ff),
            ):
                linear.weight.data.normal_(mean=0.0, std=factor * (fan_in ** -0.5))
                if hasattr(linear, "bias") and linear.bias is not None:
                    linear.bias.data.zero_()
            return

        if isinstance(module, T5Attention):
            d_model = self.config.d_model
            d_kv = self.config.d_kv
            n_heads = self.config.num_heads
            # q also absorbs the 1/sqrt(d_kv) factor that forward() never applies.
            module.q.weight.data.normal_(
                mean=0.0, std=factor * ((d_model * d_kv) ** -0.5)
            )
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(
                mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5)
            )
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(
                    mean=0.0, std=factor * ((d_model) ** -0.5)
                )

    def _shift_right(self, input_ids):
        """Shift ``input_ids`` one step right, prepending the decoder start token.

        Any ``-100`` values (the default ``ignore_index`` of torch's
        cross-entropy loss) are replaced by the pad token id.
        """
        start_id = self.config.decoder_start_token_id
        pad_id = self.config.pad_token_id

        assert (
            start_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_id

        assert pad_id is not None, "self.model.config.pad_token_id has to be defined."
        shifted.masked_fill_(shifted == -100, pad_id)

        assert torch.all(
            shifted >= 0
        ).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted
|
|
|
|
class T5Stack(T5PreTrainedModel):
    """Stack of ``T5Block`` layers acting as the T5 encoder or decoder.

    The role comes from ``config.is_decoder``. Only the first block is built
    with relative-attention-bias weights; the position bias it computes is
    passed on to every later block in ``forward``.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        # Token embedding module supplied by the caller. If None, `forward`
        # requires `inputs_embeds`.
        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        # Block 0 owns the relative position bias embedding.
        self.block = nn.ModuleList(
            [
                T5Block(config, has_relative_attention_bias=bool(i == 0))
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def get_output_embeddings(self):
        """Return the token embedding module (the same object as the input embeddings)."""
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed the inputs, run every block, then apply the final layer norm and dropout.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Flags
        left as ``None`` fall back to the corresponding ``self.config`` values.

        Returns:
            If ``return_dict`` is false, a tuple of the non-``None`` entries of
            ``(hidden_states, present_key_value_states, all_hidden_states,
            all_attentions)``; otherwise a ``BaseModelOutputWithPast``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds``
                are given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Flatten any leading dimensions into a single batch dimension.
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds"
            )

        if inputs_embeds is None:
            assert (
                self.embed_tokens is not None
            ), "You have to intialize the model with valid token embeddings"
            # NOTE(review): GLEN-specific branch. `input_ids` was reshaped to
            # (batch, seq) above, so `len(input_ids) == 2` holds whenever the
            # batch size is 2. Only the first row is then embedded, which
            # gives a 2-D `inputs_embeds`. Confirm this is intended (e.g. a
            # custom `embed_tokens`), because the attention layers unpack a
            # 3-D input.
            if self.training and self.is_decoder and len(input_ids) == 2:
                inputs_embeds = self.embed_tokens(input_ids[0])
            else:
                inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # The mask must cover cached (past) positions as well as the new ones.
        mask_seq_length = (
            past_key_values[0][0].shape[2] + seq_length
            if past_key_values is not None
            else seq_length
        )

        if use_cache is True:
            assert (
                self.is_decoder
            ), ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format(
                self
            )

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(
                inputs_embeds.device
            )
        if (
            self.is_decoder
            and encoder_attention_mask is None
            and encoder_hidden_states is not None
        ):
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size,
                encoder_seq_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )

        # One cache slot per block; None means "no cache".
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Additive mask from the base-class helper (presumably also causal for
        # decoders; see transformers' `get_extended_attention_mask`).
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, inputs_embeds.device
        )

        if self.is_decoder and encoder_attention_mask is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Filled in from block 0's outputs and reused by later blocks.
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            # layer_outputs: (hidden, present_kv, [self-attn weights],
            # [self-attn bias], [cross-attn weights], [cross-attn bias]); the
            # bias entries only come from block 0, which owns the bias weights.
            hidden_states, present_key_value_state = layer_outputs[:2]

            if i == 0:
                # Block 0 computes the position biases; later blocks reuse them.
                position_bias = layer_outputs[3 if output_attentions else 2]
                if self.is_decoder and encoder_hidden_states is not None:
                    encoder_decoder_position_bias = layer_outputs[
                        5 if output_attentions else 3
                    ]

            if use_cache:
                present_key_value_states = present_key_value_states + (
                    present_key_value_state,
                )

            if output_attentions:
                # Only the self-attention weights are collected.
                all_attentions = all_attentions + (
                    layer_outputs[2],
                )

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Also record the output of the final layer.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
# Class-level documentation prepended to T5 model classes via `add_start_docstrings`.
T5_START_DOCSTRING = r"""

    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
    It's an encoder-decoder transformer pre-trained in a text-to-text denoising generative setting.

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
|
|
# Shared documentation of the model forward() arguments (presumably attached
# later in the file via `add_start_docstrings_to_model_forward`).
T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right
            and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`.
            See :meth:`transformers.PreTrainedTokenizer.encode` and
            :meth:`transformers.PreTrainedTokenizer.__call__` for detail.

            To know more on how to prepare :obj:`input_ids` for pretraining take a look at
            `T5 Training <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
            :obj:`decoder_input_ids` generation.
            If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
            :obj:`past_key_values`).

            To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at
            `T5 Training <./t5.html#training>`__. If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both
            unset, :obj:`decoder_input_ids` takes the value of :obj:`input_ids`.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`)
            :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
            representation.
            If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` have to be input
            (see :obj:`past_key_values`).
            This is useful if you want more control over how to convert :obj:`decoder_input_ids` indices into
            associated vectors than the model's internal embedding lookup matrix.

            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both
            unset, :obj:`decoder_inputs_embeds` takes the value of :obj:`inputs_embeds`.

        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).

        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
|
|
|
|
@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states "
    "without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    """T5 encoder-decoder returning raw hidden states (no task head).

    A single ``nn.Embedding`` (``self.shared``) is handed to the encoder stack
    and to every decoder stack. When ``multiple_decoder`` is enabled,
    ``decoder_num`` decoder stacks are kept in ``self.decoder_list`` instead of
    a single ``self.decoder``.

    NOTE(review): ``forward``, ``get_decoder`` and ``_prune_heads`` only use
    ``self.encoder`` / ``self.decoder``; ``self.decoder`` is not created in
    multi-decoder mode — confirm how that mode is meant to be driven.
    """

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Encoder: a plain (not encoder-decoder) stack that never returns a cache.
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Decoder(s): decoder stacks with ``num_decoder_layers`` layers.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers

        # These flags used to be read from ``self`` without ever being assigned
        # in this class. Keep any value already visible on ``self`` (e.g. a
        # class attribute); otherwise take it from the config. A missing
        # ``multiple_decoder`` flag means the usual single decoder.
        if not hasattr(self, "multiple_decoder"):
            self.multiple_decoder = getattr(config, "multiple_decoder", False)
        if self.multiple_decoder:
            if not hasattr(self, "decoder_num"):
                self.decoder_num = config.decoder_num
            # nn.ModuleList (not a plain list) registers the stacks as
            # submodules. Without this, .parameters(), .to(), state_dict() and
            # module-traversal based weight init would not see them.
            self.decoder_list = nn.ModuleList(
                [T5Stack(decoder_config, self.shared) for _ in range(self.decoder_num)]
            )
        else:
            self.decoder = T5Stack(decoder_config, self.shared)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the embedding module shared by all stacks."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Install ``new_embeddings`` as the shared embedding of every stack."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        # In multi-decoder mode there is no ``self.decoder``; update each stack.
        decoders = self.decoder_list if self.multiple_decoder else [self.decoder]
        for decoder in decoders:
            decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        """Return the encoder stack."""
        return self.encoder

    def get_decoder(self):
        """Return the (single) decoder stack."""
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        # NOTE(review): this assumes the encoder exposes ``.layer[i].attention``;
        # upstream T5Stack stores its layers in ``.block`` — confirm against the
        # T5Stack defined in this file.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids=None,
        input_mask=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        only_encoder=False,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5Model.from_pretrained('t5-small')

            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)

            >>> last_hidden_states = outputs.last_hidden_state
        """
        # Deprecated aliases of ``past_key_values``.
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if kwargs:
            raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}.")

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Encode unless the caller supplied precomputed encoder outputs.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if only_encoder:
            return encoder_outputs

        # Index 0 is the encoder's last hidden state for both tuple and
        # ModelOutput results.
        hidden_states = encoder_outputs[0]

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Caller-supplied encoder outputs may be a ModelOutput; tuple
            # concatenation needs its values as a plain tuple.
            if hasattr(encoder_outputs, "to_tuple"):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + tuple(encoder_outputs)

        if hasattr(encoder_outputs, "last_hidden_state"):
            encoder_last_hidden_state = encoder_outputs.last_hidden_state
            encoder_hidden_states = encoder_outputs.hidden_states
            encoder_attentions = encoder_outputs.attentions
        else:
            # Plain-tuple encoder outputs from the caller: fields are positional
            # and unset optional entries are omitted, so this is a best-effort
            # mapping (same heuristic as upstream transformers).
            encoder_last_hidden_state = encoder_outputs[0]
            encoder_hidden_states = (
                encoder_outputs[1] if len(encoder_outputs) > 1 else None
            )
            encoder_attentions = encoder_outputs[2] if len(encoder_outputs) > 2 else None

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_last_hidden_state,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attentions=encoder_attentions,
        )
|
|
|
|
class HierarchicT5Stack(T5PreTrainedModel):
    """Decoder built from ``depth`` independent ``T5Stack``s, one per position.

    Output position ``i`` is produced by ``self.stacks[i]``. All stacks share
    the same ``embed_tokens`` module.
    """

    def __init__(self, config, embed_tokens=None, depth=1):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.depth = depth

        # One full T5Stack per output position, all sharing ``embed_tokens``.
        self.stacks = nn.ModuleList(
            [T5Stack(config, embed_tokens) for _ in range(depth)]
        )

        for stack in self.stacks:
            stack.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the stack(s) responsible for the requested positions.

        Training / full-sequence mode: there is more than one input position
        and no cache. Stacks ``0 .. seq_len-1`` each run on the full input.
        The last stack's output is returned, after position ``i`` of its last
        hidden state has been overwritten in place with position ``i`` from
        stack ``i``.

        Incremental mode: a single stack, ``self.stacks[cur_depth]``, runs.
        ``cur_depth`` is ``past_key_values[0][0].shape[2]``, or 0 without a
        cache. It presumably counts cached positions (HF T5 cache layout
        ``(batch, heads, cached_len, head_dim)``) — TODO confirm.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is
                given, or the full-sequence length exceeds ``depth``.
            IndexError: if ``cur_depth`` has no corresponding stack.
        """
        stack_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The sequence length may come from token ids or from embeddings.
        if input_ids is not None:
            seq_len = input_ids.shape[1]
        elif inputs_embeds is not None:
            seq_len = inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        is_train = seq_len > 1 and past_key_values is None
        if is_train:
            if seq_len > self.depth:
                raise ValueError(
                    f"Sequence length {seq_len} exceeds the number of stacks "
                    f"(depth={self.depth})."
                )
            outputs = [self.stacks[i](**stack_kwargs) for i in range(seq_len)]
            final_output = outputs[seq_len - 1]
            # Position i comes from stack i; written in place into the last
            # stack's hidden-state tensor.
            for i in range(seq_len - 1):
                final_output[0][:, i, :] = outputs[i][0][:, i, :]
            return final_output

        cur_depth = 0
        if past_key_values is not None:
            cur_depth = past_key_values[0][0].shape[2]
        if cur_depth >= self.depth:
            raise IndexError(
                f"No stack for position {cur_depth}: only depth={self.depth} "
                f"stacks are available."
            )
        return self.stacks[cur_depth](**stack_kwargs)
|
|