| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """PyTorch MixtureOfTokens model.""" |
|
|
| import math |
| import warnings |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.cuda.amp import autocast |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from torch.nn.init import trunc_normal_ |
|
|
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPastAndCrossAttentions, |
| CausalLMOutputWithCrossAttentions, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutputWithPast, |
| TokenClassifierOutput, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer |
| from transformers.utils import ( |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| ) |
| from transformers.utils.model_parallel_utils import assert_device_map, get_device_map |
| from .configuration_mot import MoTConfig |
|
|
|
|
# Module-level logger from the transformers logging utilities (used e.g. by `MoTMLP.pad`).
logger = logging.get_logger(__name__)


# Checkpoint id and config class name consumed by `add_code_sample_docstrings` when the
# usage examples for this model's `forward` docs are generated.
_CHECKPOINT_FOR_DOC = "jaszczur/mixture_of_tokens"
_CONFIG_FOR_DOC = "MoTConfig"
|
|
|
|
def with_batch_size_alignment(forward_fn):
    """Decorate a ``forward(self, x)`` method so its grouped axis is a multiple of ``self.group_size``.

    ``x`` is assumed to be ordered ``(batch, seq_len, dmodel)``. When ``x.size(self.sparsity_dim)`` is not a
    multiple of ``self.group_size``, ``x`` is padded along that axis with ``self.pad`` (which pads dim 0), the
    wrapped method runs on the padded tensor, and the padding is sliced off the result. Aligned inputs are passed
    through untouched.

    The owner of the decorated method must provide ``sparsity_dim``, ``group_size`` and ``pad(x)``.
    """
    from functools import wraps

    @wraps(forward_fn)
    def _forward(self, x):
        size = x.size(self.sparsity_dim)
        if size % self.group_size == 0:
            return forward_fn(self, x)

        if self.sparsity_dim == 1:
            # `pad` extends dim 0, so temporarily move the sequence axis there.
            x = self.pad(x.transpose(0, 1)).transpose(0, 1)
            return forward_fn(self, x)[:, :size, :]

        x = self.pad(x)
        return forward_fn(self, x)[:size, :, :]

    return _forward
|
|
|
|
| |
class MoTAttention(nn.Module):
    """Multi-head attention built from `Conv1D` projections, usable as causal self-attention or as cross-attention.

    Self-attention projects queries, keys and values with one fused `c_attn`; cross-attention projects queries with
    `q_attn` and keys/values of the encoder states with `c_attn`. The output goes through `c_proj` and residual
    dropout. A boolean lower-triangular `bias` buffer is used as the causal mask for self-attention only.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        """
        Args:
            config: model configuration; reads `max_position_embeddings`, `hidden_size`, `num_attention_heads`,
                `scale_attn_weights`, `scale_attn_by_inverse_layer_idx`, `reorder_and_upcast_attn`,
                `attn_pdrop` and `resid_pdrop`.
            is_cross_attention: build the query/key-value projections for cross-attention and skip the causal mask.
            layer_idx: layer index; only read when `scale_attn_by_inverse_layer_idx` is enabled.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # Boolean causal mask of shape (1, 1, max_positions, max_positions); non-persistent, so it is not
        # written to the state dict.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        # NOTE(review): not read by any code visible in this file; presumably kept for checkpoint/API compatibility.
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Width of each of the q/k/v slices produced by `c_attn`; shrinks when heads are pruned.
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling and the optional float32 upcast/reorder attention path.
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Keys/values (2 * embed_dim) from the encoder states; queries from a separate projection.
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # Fused query/key/value projection (3 * embed_dim).
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # NOTE(review): not read by any code visible in this file.
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads by slicing the `c_attn` and `c_proj` weights.

        Assumes the fused 3-way (q, k, v) layout of `c_attn`, i.e. a self-attention layer.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # The same head columns must be kept in each of the q, k and v slices of the fused projection.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune the Conv1D layers.
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper-parameters to match the pruned layers.
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention in the inputs' dtype.

        Args:
            query, key, value: head-split tensors as produced by `_split_heads`, i.e.
                (batch, num_heads, seq_len, head_dim).
            attention_mask: optional additive mask, added to the attention logits before softmax.
            head_mask: optional multiplicative mask applied to the attention probabilities.

        Returns:
            `(attn_output, attn_weights)`.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            # Divide by sqrt(head_dim); a 0-dim tensor keeps dtype and device aligned with the logits.
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling.
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # Causal mask only for self-attention. The slice aligns the mask with the last `query_length`
            # positions, so it also works when cached keys make key_length > query_length.
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # The fill value must be a tensor of the same dtype and device as the logits for `torch.where`.
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the additive attention mask.
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Cast back to the value dtype (a no-op when they already match) before dropout and the value product.
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if requested.
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Variant of `_attn` that computes logits and softmax in float32 with autocast disabled.

        The scale factors are folded into `baddbmm`'s `alpha` instead of dividing the logits afterwards.

        Raises:
            RuntimeError: if the softmax output is not float32.
        """
        # Shapes follow `_split_heads`: (batch, num_heads, seq_len, head_dim).
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocated float32 output for `baddbmm` (its content is ignored because beta=0).
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Combined scale: 1/sqrt(head_dim) and/or 1/(layer_idx + 1).
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (autocast off) and compute scale * Q @ K^T in float32.
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # Causal mask only for self-attention (see `_attn`).
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # The fill value must be a tensor of the same dtype and device as the logits for `torch.where`.
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the additive attention mask.
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Sanity check that the softmax really ran in float32, then downcast to the value dtype.
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if requested.
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads.

        (batch, seq_len, num_heads * attn_head_size) -> (batch, num_heads, seq_len, attn_head_size)
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size.

        (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads * attn_head_size)
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run attention over `hidden_states` (or cross-attend to `encoder_hidden_states`).

        Returns:
            `(attn_output, present)` plus `attn_weights` when `output_attentions` is set. `present` is the
            `(key, value)` pair (including any `layer_past`) when `use_cache is True`, otherwise `None`.

        Raises:
            ValueError: if `encoder_hidden_states` is given to a layer built without `is_cross_attention=True`.
        """
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `MoTAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            # Cross-attention masks the encoder positions, so the encoder mask replaces the self-attention mask.
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence axis.
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
|
|
|
|
class MoTMLP(nn.Module):
    r"""
    Implementation of the Mixture of Tokens Sparse MLP module.

    Tokens are split into groups of `config.group_size` along `sparsity_dim` (0: tokens from different sequences of
    the batch, 1: tokens within one sequence). For every group, a controller matrix produces per-expert merge
    weights. Each expert processes one weighted mixture of the group's tokens, and the expert outputs are
    redistributed back to the group's tokens with emit weights.

    Args:
        inner_dim: total hidden width of all experts; `config.n_inner` is used when it is `None`. It is rounded up
            (with a warning) to a multiple of the number of experts or of `expert_size`.
        config: model configuration (reads `n_embd`, `n_inner`, `n_expert`, `expert_size`, `group_size`,
            `temperature`, `activation_function`, `init_scale`, `emit_softmax_over_experts`,
            `use_discrete_routing` and `resid_pdrop`).
        sparsity_dim: input axis along which tokens are grouped (0 or 1).
        init_type: `"kaiming_uniform"` or `"truncated_normal"`.

    Raises:
        ValueError: if neither `config.n_expert` nor `config.expert_size` is set, or `init_type` is unknown.
    """

    def __init__(
        self,
        inner_dim: Optional[int],
        config: MoTConfig,
        sparsity_dim: int = 0,
        init_type: str = "kaiming_uniform",
    ):
        super().__init__()

        self.d_model: int = config.n_embd
        self.d_ff: int = config.n_inner if inner_dim is None else inner_dim
        self.n_expert: int = config.n_expert
        self.group_size: int = config.group_size
        self.sparsity_dim: int = sparsity_dim
        self.expert_size: int = config.expert_size
        self.temperature: float = config.temperature
        self.act = ACT2FN[config.activation_function]

        self.init_type: str = init_type
        self.init_scale: float = config.init_scale

        self.emit_softmax_over_experts: bool = config.emit_softmax_over_experts
        self.use_discrete_routing: bool = config.use_discrete_routing

        # Derive whichever of n_expert / expert_size is missing; n_expert wins when both are set.
        # d_ff is rounded up so that it splits evenly into experts.
        if self.n_expert is not None:
            if self.d_ff % self.n_expert:
                self.d_ff += self.n_expert - (self.d_ff % self.n_expert)
                warnings.warn("d_ff should be divisible by n_expert, padding d_ff to be divisible by n_expert")
            self.expert_size = self.d_ff // self.n_expert
        elif self.expert_size is not None:
            if self.d_ff % self.expert_size:
                self.d_ff += self.expert_size - (self.d_ff % self.expert_size)
                warnings.warn("d_ff should be divisible by expert_size, padding d_ff to be divisible by expert_size")
            self.n_expert = self.d_ff // self.expert_size
        else:
            raise ValueError("Either expert_size or n_expert should be provided")

        # Per-expert first layer: (n_expert, d_model, expert_size).
        self.lin1 = nn.Parameter(
            self.get_init_weight(
                (self.n_expert, self.d_model, self.expert_size),
                fan_in=self.d_model,
                init_type=self.init_type,
                scale=self.init_scale,
            )
        )

        # Per-expert second layer: (n_expert, expert_size, d_model).
        self.lin2 = nn.Parameter(
            self.get_init_weight(
                (self.n_expert, self.expert_size, self.d_model),
                fan_in=self.expert_size,
                init_type=self.init_type,
                scale=self.init_scale,
            )
        )

        # Controller producing one routing logit per expert for every token: (d_model, n_expert).
        self.controller = nn.Parameter(
            self.get_init_weight(
                (self.d_model, self.n_expert),
                fan_in=self.d_model,
                init_type=self.init_type,
                scale=self.init_scale,
            )
        )
        self.dropout = nn.Dropout(config.resid_pdrop)

    @staticmethod
    def argmax_one_hot(x: torch.Tensor, dim: int) -> torch.Tensor:
        """Return a tensor of x's dtype with 1 where `x` equals its maximum along `dim` and 0 elsewhere.

        Ties produce several ones along `dim`.
        """
        max_values, _ = x.max(dim=dim, keepdim=True)
        return (x == max_values).to(x.dtype)

    def get_init_weight(self, shape, fan_in, init_type, scale, dtype=torch.float32):
        """Create an initialized weight tensor of `shape` using the scheme named by `init_type`."""
        if init_type == "kaiming_uniform":
            return self.init_kaiming_uniform(shape=shape, fan_in=fan_in, scale=scale, dtype=dtype)
        elif init_type == "truncated_normal":
            return self.init_truncated_normal(shape=shape, fan_in=fan_in, scale=scale, dtype=dtype)
        else:
            raise ValueError(f"Unknown init_type: {init_type}")

    @staticmethod
    def init_kaiming_uniform(shape, fan_in, scale, dtype=torch.float32):
        """Uniform init on [-r, r] with r = scale * sqrt(3 / fan_in), i.e. variance scale**2 / fan_in."""
        range_ = scale * (3 / fan_in) ** 0.5
        return torch.zeros(shape, dtype=dtype).uniform_(-range_, range_)

    @staticmethod
    def init_truncated_normal(shape, fan_in, scale, dtype=torch.float32):
        """Truncated-normal init with std = sqrt(scale / fan_in)."""
        std = (scale / fan_in) ** 0.5
        # NOTE(review): `trunc_normal_` takes absolute bounds `a`/`b`. They are +-2*scale here, not +-2*std, so for
        # typical fan_in the truncation almost never triggers. Left unchanged to preserve the existing
        # initialization; confirm whether +-2*std was intended.
        low = -2 * scale
        high = 2 * scale
        t = torch.zeros(shape, dtype=dtype)
        return trunc_normal_(t, mean=0.0, std=std, a=low, b=high)

    @staticmethod
    def stable_softmax_temperature(x: torch.Tensor, temperature: float, dim: int = -1) -> torch.Tensor:
        """Softmax of `x / temperature` along `dim`."""
        return F.softmax(x / temperature, dim=dim)

    def pad(self, x):
        """Zero-pad `x` along dim 0 up to the next multiple of `group_size`.

        The padding has the dtype and device of `x`, and works for inputs of any rank.
        """
        size = x.size(0)
        # Exact integer ceil-division (a float round-trip can mis-round large sizes).
        new_batch_size = -(-size // self.group_size) * self.group_size
        padding_size = new_batch_size - size
        logger.debug("Padding batch size from %d to %d", size, new_batch_size)

        padding = x.new_zeros((padding_size, *x.shape[1:]))
        return torch.cat([x, padding], dim=0)

    @with_batch_size_alignment
    def forward(self, x):
        """Apply the Mixture-of-Tokens MLP to `x` of shape (batch, seq_len, dmodel); the output has the same shape."""
        x = self.group_tokens(x)
        merge_weights, emit_weights = self.calculate_mixed_tokens_with_weights(x)
        x = self.merge_map_emit(x, merge_weights, emit_weights)
        x = self.redistribute_tokens(x)
        x = self.dropout(x)
        return x

    def group_tokens(self, x):
        """
        Reshape code so the axis to split into groups is on position 1, and then group over said axis.
        e.g.:
        - if we group tokens from different sequences in a batch (sparsity = 0), we need to put the batch dimension to position 1.
        - if we group tokens within one sequence, the dimension to split into groups is already on position 1, hence we leave it as is.

        free_dimension is the dimension on position 0 after reshape
        split_dimension is the dimension on position 1 - the one to split into groups

        :param x: normal input tensor of shape (batch, seq_len, dmodel)
        :return: x of shape (free_dimension, split_dimension // group_size, group_size , dmodel)
        """
        assert len(x.shape) == 3, "incorrect shape of a tensor, expected a 3D tensor"
        assert (
            x.size(-1) == self.d_model
        ), f"expected the last dimension of input tensor to be d_model = {self.d_model}"

        if self.sparsity_dim == 0:
            x = x.transpose(0, 1)
        elif self.sparsity_dim != 1:
            raise NotImplementedError

        # Position 1 is the axis that gets split into groups.
        split_dimension = x.size(1)
        assert (
            split_dimension % self.group_size == 0
        ), f"split dimension = {split_dimension} should be divisible by group size = {self.group_size}"

        # Splitting a single axis is valid for `view` even on the transposed (non-contiguous) tensor.
        x = x.view(x.size(0), -1, self.group_size, self.d_model)
        return x

    def redistribute_tokens(self, x):
        """
        An inverse operation to group_tokens.

        :param x: tensor of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
        :return: tensor of shape (batch, seq_len, dmodel)
        """
        assert len(x.shape) == 4, "incorrect shape of a tensor, expected a 4D tensor"

        x = x.view(x.size(0), -1, self.d_model)
        if self.sparsity_dim == 0:
            x = x.transpose(0, 1)
        elif self.sparsity_dim != 1:
            raise NotImplementedError

        return x

    def calculate_mixed_tokens_with_weights(self, x):
        """
        This function calculates merge and emit weights based on the input tensor, using a controller matrix.
        The merge weights determine the aggregation of tokens within a group, and emit weights govern the redistribution
        of the aggregated token back to the original tokens. Temperature scaling is applied to the logits, and optional
        discrete routing can be used to obtain one-hot representations of the weights.

        :param x: grouped input of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
        :return: (merge_weights, emit_weights), each of shape (..., group_size, n_expert)
        """
        # Routing logits: (..., group_size, n_expert).
        merge_logits = torch.matmul(x, self.controller)

        temp_merge = self.temperature
        temp_emit = self.temperature

        # Merge weights are normalized over the tokens of a group; emit weights over experts or over the group.
        merge_softmax_dim = -2
        emit_softmax_dim = -1 if self.emit_softmax_over_experts else -2

        merge_weights = self.stable_softmax_temperature(merge_logits, temp_merge, dim=merge_softmax_dim)

        # With a shared non-learnable temperature and the same softmax axis, emit weights equal the merge weights,
        # so the softmax is only recomputed when that is not the case.
        if isinstance(temp_merge, torch.nn.Parameter) or self.emit_softmax_over_experts:
            emit_weights = self.stable_softmax_temperature(merge_logits, temp_emit, dim=emit_softmax_dim)
        else:
            emit_weights = merge_weights

        if self.use_discrete_routing:
            merge_weights = self.argmax_one_hot(merge_weights, dim=merge_softmax_dim)
            emit_weights = self.argmax_one_hot(emit_weights, dim=emit_softmax_dim)
        return merge_weights, emit_weights

    def merge_map_emit(self, x, merge_weights, emit_weights):
        """
        :param x: input reshaped to (free_dimension, split_dimension // group_size, group_size, dmodel)
        :param merge_weights: weights for merging tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
        :param emit_weights: weights for emitting tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
        :return: tensor of token updates of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
        """
        # Merge: one mixed token per (group, expert) -> (free, n_groups, n_expert, dmodel).
        x = torch.matmul(
            merge_weights.transpose(-1, -2),
            x,
        )
        # Map: lay the mixed tokens out expert-major -> (n_expert, free * n_groups, dmodel) and run every expert's
        # two-layer MLP with batched matmuls.
        x = torch.bmm(x.view(-1, self.n_expert, self.d_model).transpose(0, 1), self.lin1)
        x = self.act(x)
        x = torch.bmm(x, self.lin2)

        # Emit: restore (free, n_groups, n_expert, dmodel) and redistribute the expert outputs to the group's tokens
        # -> (free, n_groups, group_size, dmodel).
        x = torch.matmul(
            emit_weights,
            x.view(x.size(0), emit_weights.size(0), -1, self.d_model).permute(1, 2, 0, 3),
        )

        return x
|
|
|
|
# Maps `config._attn_implementation` to the attention class used by `MoTBlock`; only the eager implementation exists.
MoT_ATTENTION_CLASSES = dict(eager=MoTAttention)
|
|
|
|
| |
class MoTBlock(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention, optional cross-attention and a Mixture-of-Tokens MLP.

    Each sub-layer is wrapped in a residual connection.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        width = config.hidden_size
        mlp_width = 4 * width if config.n_inner is None else config.n_inner
        attn_cls = MoT_ATTENTION_CLASSES[config._attn_implementation]

        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = attn_cls(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = attn_cls(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

        self.mlp = MoTMLP(mlp_width, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run the layer.

        Returns:
            `(hidden_states, present, attentions...)` when `use_cache` is truthy, otherwise
            `(hidden_states, attentions...)`. The attention entries are present only with `output_attentions`.
        """
        self_attn = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # Everything after the attention output itself: (present, [attn_weights]).
        extras = self_attn[1:]
        hidden_states = self_attn[0] + hidden_states

        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            cross_attn = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = hidden_states + cross_attn[0]
            # Skip the cross-attention output and its cache slot; keep only its attention weights, if any.
            extras = extras + cross_attn[2:]

        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))

        # Without caching, drop the `present` slot.
        tail = extras if use_cache else extras[1:]
        return (hidden_states,) + tail
|
|
|
|
| |
class MoTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class for MoT models: handles weight initialization and provides a simple interface for
    downloading and loading pretrained models.
    """

    config_class = MoTConfig
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["MoTBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of `module` in place.

        Linear/Conv1D and Embedding weights are drawn from N(0, initializer_range); biases and Embedding padding rows
        are zeroed; LayerNorm is reset to identity. Any `c_proj.weight` owned by `module` is redrawn with std scaled
        by 1 / sqrt(2 * n_layer). Parameters of other module types (e.g. the MoT expert tensors) are not touched
        here.
        """
        if isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            weight = module.weight.data
            weight.normal_(mean=0.0, std=self.config.initializer_range)
            pad_idx = module.padding_idx
            if pad_idx is not None:
                weight[pad_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Scaled initialization of the attention output projections.
        for param_name, param in module.named_parameters():
            if param_name != "c_proj.weight":
                continue
            scaled_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
            param.data.normal_(mean=0.0, std=scaled_std)
|
|
|
|
MOT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MoTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


MOT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layer`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. The model has one attention module per
            hidden layer, i.e. `config.num_hidden_layers` attention modules in total.

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs for a checkpoint with 48 attention modules
    # (`config.num_hidden_layers == 48`):
    model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to CPU from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine, for a checkpoint with 36 attention modules:
    model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on CPU and frees cached GPU memory by calling torch.cuda.empty_cache()
    ```
"""
|
|
|
|
| @add_start_docstrings( |
| "The bare MOT Model transformer outputting raw hidden-states without any specific head on top.", |
| MOT_START_DOCSTRING, |
| ) |
| class MoTModel(MoTPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.embed_dim = config.hidden_size |
|
|
| self.wte = nn.Embedding(config.vocab_size, self.embed_dim) |
| self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) |
|
|
| self.drop = nn.Dropout(config.embd_pdrop) |
| self.h = nn.ModuleList([MoTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) |
| self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) |
|
|
| |
| self.model_parallel = False |
| self.device_map = None |
| self.gradient_checkpointing = False |
|
|
| |
| self.post_init() |
|
|
| @add_start_docstrings(PARALLELIZE_DOCSTRING) |
| def parallelize(self, device_map=None): |
| |
| warnings.warn( |
| "`MoTModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" |
| " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" |
| " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," |
| " ...}", |
| FutureWarning, |
| ) |
| self.device_map = ( |
| get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map |
| ) |
| assert_device_map(self.device_map, len(self.h)) |
| self.model_parallel = True |
| self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) |
| self.last_device = "cuda:" + str(max(self.device_map.keys())) |
| self.wte = self.wte.to(self.first_device) |
| self.wpe = self.wpe.to(self.first_device) |
| |
| for k, v in self.device_map.items(): |
| for block in v: |
| cuda_device = "cuda:" + str(k) |
| self.h[block] = self.h[block].to(cuda_device) |
| |
| self.ln_f = self.ln_f.to(self.last_device) |
|
|
| @add_start_docstrings(DEPARALLELIZE_DOCSTRING) |
| def deparallelize(self): |
| warnings.warn( |
| "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", |
| FutureWarning, |
| ) |
| self.model_parallel = False |
| self.device_map = None |
| self.first_device = "cpu" |
| self.last_device = "cpu" |
| self.wte = self.wte.to("cpu") |
| self.wpe = self.wpe.to("cpu") |
| for index in range(len(self.h)): |
| self.h[index] = self.h[index].to("cpu") |
| self.ln_f = self.ln_f.to("cpu") |
| torch.cuda.empty_cache() |
|
|
| def get_input_embeddings(self): |
| return self.wte |
|
|
| def set_input_embeddings(self, new_embeddings): |
| self.wte = new_embeddings |
|
|
| def _prune_heads(self, heads_to_prune): |
| """ |
| Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} |
| """ |
| for layer, heads in heads_to_prune.items(): |
| self.h[layer].attn.prune_heads(heads) |
|
|
    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # Bare MoT transformer forward pass: embeddings (wte + wpe [+ wte(token_type_ids)])
        # -> dropout -> every block in `self.h` -> final layer norm `self.ln_f`.
        # Public argument docs are injected by the decorators above, so notes are kept as
        # comments (a docstring here would be spliced into the generated API docs).
        #
        # Returns a BaseModelOutputWithPastAndCrossAttentions when `return_dict` is truthy;
        # otherwise a tuple of the non-None entries of
        # (last hidden state, presents, all hidden states, self-attentions, cross-attentions).
        # Raises ValueError if both or neither of input_ids / inputs_embeds are given, or if
        # an attention mask is passed while batch_size <= 0.

        # Per-call flags fall back to the model config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            # Fold any leading dims into one batch dim; the last dim is the sequence.
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        # past_length = number of positions already cached, read from dim -2 of the first
        # layer's first cached tensor. Without a cache, every block gets layer_past=None.
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            # Default positions continue right after the cached prefix; shape (1, seq_len).
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # Convert the 2D padding mask into an additive mask of shape (batch, 1, 1, key_len):
        # kept positions (1) become 0.0 and masked positions (0) become the dtype's minimum.
        # Presumably broadcast over heads and query positions inside attention — confirm in MoTAttention.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # Cross-attention mask: defaults to attending to every encoder position, then is
        # inverted into additive form by the PreTrainedModel helper.
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Expand head_mask to one entry per layer (helper from PreTrainedModel).
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Token-type ids are embedded with the same table as the tokens (self.wte).
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        # Used at the end to restore the caller's original leading dims.
        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel: make this block's inputs live on the device of hidden_states.
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            # Hidden states are recorded *before* each block (the post-ln_f state is added after the loop).
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional call; layer_past is passed as None (caching is disabled above).
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # Block output layout assumed here: (hidden, [present], [self_attn], [cross_attn]);
            # the present slot only exists when use_cache is on, hence the shifted indices.
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model parallel: after the last block assigned to device k, hop to cuda:(k+1).
            # NOTE(review): assumes device ids in device_map are consecutive integers.
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add the final (post-ln_f) hidden state.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    The MOT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    MOT_START_DOCSTRING,
)
class MoTLMHeadModel(MoTPreTrainedModel):
    # Causal LM: a MoTModel backbone (`self.transformer`) followed by a bias-free linear
    # projection n_embd -> vocab_size (`self.lm_head`). `_tied_weights_keys` declares
    # lm_head.weight as tied; the actual tying presumably happens in the PreTrainedModel
    # machinery (post_init / tie_weights) — confirm.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = MoTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model-parallel bookkeeping; only changed by the deprecated parallelize()/deparallelize().
        self.model_parallel = False
        self.device_map = None

        # Weight initialization / final setup from the PreTrainedModel base.
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Deprecated: distribute the backbone's blocks over GPUs via MoTModel.parallelize
        # and put the LM head on the backbone's first device.
        warnings.warn(
            "`MoTLMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
            " 0, 'transformer.h.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # Deprecated: move the backbone and LM head back to the CPU and free cached CUDA memory.
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        """Return the output projection module (``self.lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection module (``self.lm_head``) with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """
        Build the keyword arguments for one generation step's call to ``forward``.

        When ``past_key_values`` is non-empty, the already-cached prefix is stripped from
        ``input_ids`` (and ``token_type_ids``), so only the new tokens are fed to the model.
        ``inputs_embeds`` is only used when there is no cache.

        Returns:
            dict with either ``input_ids`` or ``inputs_embeds``, plus ``past_key_values``,
            ``use_cache``, ``position_ids``, ``attention_mask`` and ``token_type_ids``.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        # Drop the tokens already covered by the cache.
        if past_key_values:
            # Cached length, read from dim 2 of the first layer's first cached tensor.
            past_length = past_key_values[0][0].shape[2]

            # Some callers already pass only the new token(s); otherwise strip the cached prefix.
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Fallback: keep only the final token.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask (cumsum - 1) so left-padded rows start at 0;
            # padded slots get position 1 (their value is irrelevant once masked).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            # NOTE(review): this also discards caller-supplied position_ids, leaving the
            # backbone to use its default arange positions — confirm this is intended.
            position_ids = None

        # inputs_embeds are only consumed on the first step (before a cache exists).
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        return model_inputs

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Model parallel: bring the final hidden states to the LM head's device.
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Keep labels on the logits' device (needed under model parallelism).
            labels = labels.to(lm_logits.device)
            # Next-token objective: logits at position t are scored against the label at t+1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten batch/time; CrossEntropyLoss ignores label -100 by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        # Gather rows along dim 0 of every cached tensor, moving beam_idx to each tensor's device.
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
|
|
|
|
@add_start_docstrings(
    """
    The MOT Model transformer with a sequence classification head on top (linear layer).

    [`MoTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MOT_START_DOCSTRING,
)
class MoTForSequenceClassification(MoTPreTrainedModel):
    # MoTModel backbone + bias-free linear `score` head (n_embd -> num_labels), read out at
    # the last non-padding token of every row.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MoTModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Model-parallel bookkeeping (kept for API parity with the other heads).
        self.model_parallel = False
        self.device_map = None

        # Weight initialization / final setup from the PreTrainedModel base.
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-position scores; only one position per row is kept below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, _ = input_ids.shape[:2]
        else:
            batch_size, _ = inputs_embeds.shape[:2]

        # Without a pad token the last real token of each row cannot be located, so only a
        # single sequence is supported. Raise (not assert) so the check survives `python -O`.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the first pad token minus one. The modulo maps -1 (no pad found,
                # or pad at position 0 as with left padding) to the last index; this avoids
                # negative indexing (kept for ONNX-export friendliness).
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Keep labels on the logits' device (needed under model parallelism).
            labels = labels.to(pooled_logits.device)
            # Infer the problem type once from num_labels and the label dtype; it is cached on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    MOT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    MOT_START_DOCSTRING,
)
class MoTForTokenClassification(MoTPreTrainedModel):
    # MoTModel backbone + dropout + linear `classifier` (hidden_size -> num_labels) applied
    # at every position.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = MoTModel(config)
        # Head dropout: `classifier_dropout` if set, else `hidden_dropout` if set, else 0.1.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Model-parallel bookkeeping (kept for API parity with the other heads).
        self.model_parallel = False
        self.device_map = None

        # Weight initialization / final setup from the PreTrainedModel base.
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    # NOTE(review): expected_output / expected_loss look copied from another model's
    # doc example; verify them against _CHECKPOINT_FOR_DOC before running doctests.
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_loss=0.25,
        expected_output=[
            "Lead",
            "Lead",
            "Lead",
            "Position",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
        ],
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`; a Cross-Entropy loss is computed over every position, and positions labelled
            `-100` are ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the backbone so fields can be selected by name:
        # its tuple form only contains `presents` when caching is on, so positional slicing
        # (the former `[2:]`) dropped hidden states/attentions whenever use_cache was off.
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = transformer_outputs.last_hidden_state
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Same layout as the previous `transformer_outputs[2:]` with caching on:
            # the non-None values among hidden states, attentions and cross-attentions.
            extras = tuple(
                value
                for value in (
                    transformer_outputs.hidden_states,
                    transformer_outputs.attentions,
                    transformer_outputs.cross_attentions,
                )
                if value is not None
            )
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    The MixtureOfTokens transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MOT_START_DOCSTRING,
)
class MoTForQuestionAnswering(MoTPreTrainedModel):
    # MoTModel backbone + linear `qa_outputs` (hidden_size -> 2) producing per-position
    # span-start and span-end logits.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MoTModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Model-parallel bookkeeping (kept for API parity with the other heads).
        self.model_parallel = False
        self.device_map = None

        # Weight initialization / final setup from the PreTrainedModel base.
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the backbone so fields can be selected by name:
        # its tuple form only contains `presents` when caching is on, so positional slicing
        # (the former `[2:]`) dropped hidden states/attentions whenever use_cache was off.
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Accept (batch, 1) targets as well as (batch,).
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Always co-locate targets with the logits (previously only done when squeezing,
            # which broke 1-D targets under model parallelism).
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # Positions past the end are clamped to `ignored_index` (= sequence length),
            # which the loss below ignores.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Same layout as the previous `outputs[2:]` with caching on: the non-None
            # values among hidden states, attentions and cross-attentions.
            extras = tuple(
                value
                for value in (outputs.hidden_states, outputs.attentions, outputs.cross_attentions)
                if value is not None
            )
            output = (start_logits, end_logits) + extras
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|