| """ |
| @file mm.py |
| @brief This file contains the code for the multimodal model. It is a modified version of the CLIP model from the huggingface transformers library. |
| @author yutangli |
| """ |
| import torch |
| from torch.nn import CrossEntropyLoss |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.models.clip.configuration_clip import CLIPConfig |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.activations import ACT2FN |
| from transformers.utils import logging, ModelOutput |
| from typing import Optional, Union, Tuple, Dict |
| import math |
| from dataclasses import dataclass |
|
|
| from torch import Tensor, device, dtype, nn |
|
|
| from transformers.modeling_utils import ( |
| PreTrainedModel, |
| apply_chunking_to_forward, |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
|
|
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPastAndCrossAttentions, |
| BaseModelOutputWithPoolingAndCrossAttentions, |
| CausalLMOutputWithCrossAttentions, |
| MaskedLMOutput, |
| MultipleChoiceModelOutput, |
| NextSentencePredictorOutput, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutput, |
| TokenClassifierOutput, |
| ) |
|
|
| from transformers.models.bert.configuration_bert import BertConfig |
|
|
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
| |
| def _make_causal_mask( |
| input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 |
| ): |
| """ |
| Make causal mask used for bi-directional self-attention. |
| """ |
| bsz, tgt_len = input_ids_shape |
| mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) |
| mask_cond = torch.arange(mask.size(-1), device=device) |
| mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) |
| mask = mask.to(dtype) |
|
|
| if past_key_values_length > 0: |
| mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) |
| return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
|
|
|
|
| |
| def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): |
| """ |
| Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
| """ |
| bsz, src_len = mask.size() |
| tgt_len = tgt_len if tgt_len is not None else src_len |
|
|
| expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
|
|
| inverted_mask = 1.0 - expanded_mask |
|
|
| return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) |
|
|
|
|
@dataclass
class BaseModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    The field declaration order below also defines the order of the
    ``ModelOutput`` tuple / key view, so fields must not be reordered.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_hidden_state (`Dict[str, torch.FloatTensor]`, *optional*):
            Extra hidden states keyed by string; an addition of this module over the upstream class.
            NOTE(review): presumably filled with the layers selected by
            `VisionConfig.intermediate_transformer_output`; the producer is not in this part of the file --
            confirm the key format there.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    intermediate_hidden_state: Optional[Dict[str, torch.FloatTensor]] = None
|
|
@dataclass
class BaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    The field declaration order below also defines the order of the
    ``ModelOutput`` tuple / key view, so fields must not be reordered.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_hidden_state (`Dict[str, torch.FloatTensor]`, *optional*):
            Extra hidden states keyed by string; an addition of this module over the upstream class.
            NOTE(review): presumably filled with the layers selected by
            `VisionConfig.intermediate_transformer_output`; the producer is not in this part of the file --
            confirm the key format there.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    intermediate_hidden_state: Optional[Dict[str, torch.FloatTensor]] = None
|
|
|
|
class BertConfig(PretrainedConfig):
    r"""
    Configuration of the BERT-style text branch used in this module.

    This class mirrors the Hugging Face BERT configuration; its defaults correspond to
    [bert-base-uncased](https://huggingface.co/bert-base-uncased). Because it is defined
    after `from transformers.models.bert.configuration_bert import BertConfig`, this local
    class replaces the imported name for the rest of the module. Any extra keyword
    arguments are forwarded unchanged to [`PretrainedConfig`].

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Number of distinct token ids the word-embedding table can represent.
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and of the pooler.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            How many Transformer layers the encoder stacks.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Attention heads per attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward ("intermediate") sub-layer.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            Feed-forward activation; either a callable or a name such as `"gelu"`, `"relu"`,
            `"silu"` or `"gelu_new"`.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for fully connected layers in the embeddings, encoder and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability applied to attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            Longest sequence the position tables support (commonly 512, 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            Size of the `token_type_ids` vocabulary.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Epsilon of the LayerNorm layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id; handed to [`PretrainedConfig`].
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            One of `"absolute"`, `"relative_key"` (Shaw et al., https://arxiv.org/abs/1803.02155) or
            `"relative_key_query"` (Huang et al., *Method 4*, https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/value attentions; only relevant for decoders.
        classifier_dropout (`float`, *optional*):
            Dropout ratio for a classification head.

    Example:

    ```python
    >>> configuration = BertConfig()
    >>> configuration.hidden_size
    768
    ```"""

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Model geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        # Regularisation.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # Embedding tables.
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        # Initialisation / normalisation.
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # Behaviour switches.
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
|
|
|
|
class VisionConfig(PretrainedConfig):
    r"""
    Configuration of the CLIP-style vision encoder used in this module. With the defaults it resembles the vision
    tower of [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Stored on the config; presumably the width of an image projection head defined elsewhere (not used in
            this part of the file).
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels of the input images (input channels of the patch-embedding convolution).
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 32):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        intermediate_transformer_output (`list[int]`, *optional*, defaults to `[4, 6, 8]`):
            Encoder layer indices whose hidden states are presumably returned as intermediate outputs.
            NOTE(review): the consumer is not in this part of the file -- confirm the index base there.
            Lists and tuples are copied into a fresh list, so instances never share one list object.

    Example:

    ```python
    >>> configuration = VisionConfig()
    >>> configuration.intermediate_transformer_output
    [4, 6, 8]
    ```"""

    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        intermediate_transformer_output=(4, 6, 8),
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        # The default is an immutable tuple and sequences are copied: a mutable list
        # default would be one object shared (and mutable) across every instance.
        if isinstance(intermediate_transformer_output, (list, tuple)):
            intermediate_transformer_output = list(intermediate_transformer_output)
        self.intermediate_transformer_output = intermediate_transformer_output
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
|
|
|
|
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    There is no token-type embedding. Absolute position embeddings are added only
    when ``position_embedding_type == "absolute"``. The sum is then
    layer-normalized and passed through dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # (1, max_position_embeddings) table of position ids, sliced in forward() for default positions.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        """Embed tokens and add position information.

        Args:
            input_ids: token ids; dimension 1 is read as the sequence length.
            position_ids: explicit position ids. When omitted, positions
                ``past_key_values_length .. past_key_values_length + seq_len - 1`` are used.
            inputs_embeds: precomputed token embeddings, used when ``input_ids`` is None.
                The tensor passed in is not modified.
            past_key_values_length: number of already-cached tokens; offsets the default positions.

        Returns:
            Embeddings after (optional) absolute position addition, LayerNorm and dropout.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            # Out-of-place add. `embeddings` aliases `inputs_embeds` here, so an in-place
            # `+=` would overwrite a caller-supplied tensor, and it fails for leaf
            # tensors that require grad.
            embeddings = embeddings + self.position_embeddings(position_ids)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
class VisionEmbeddings(nn.Module):
    """Patch, class-token and learned absolute position embeddings for the vision tower.

    Images are split into non-overlapping ``patch_size`` x ``patch_size`` patches by a
    stride-``patch_size`` convolution without bias. A learned class token is
    prepended, and one learned position vector per token (patches + class token)
    is added.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Parameters are created in a fixed order (class token, conv, position table)
        # so parameter ordering and random initialisation stay reproducible.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches + 1  # one extra slot for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return ``(batch, num_patches + 1, hidden_size)`` token embeddings.

        The position table has a fixed length, so the input resolution must match
        ``config.image_size``; otherwise the final addition fails to broadcast.
        """
        n_images = pixel_values.shape[0]
        # Conv output (B, D, H', W') -> flatten spatial dims -> (B, H'*W', D).
        patch_tokens = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(n_images, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention for the BERT text branch.

    There are two modes:

    * Self-attention: keys/values are projected from ``hidden_states``, optionally
      prefixed with a cached ``past_key_value``.
    * Cross-attention: the module is built with ``is_cross_attention=True`` and
      called with ``encoder_hidden_states``. Keys/values are then projected from
      that tensor, whose width is ``config.encoder_width``.

    ``position_embedding_type`` may be ``"absolute"`` (nothing is added here),
    ``"relative_key"`` or ``"relative_key_query"``; the relative types add learned
    distance embeddings to the scores.

    When ``save_attention`` is True, cross-attention probabilities are stored via
    ``save_attention_map``. If those probabilities are part of the autograd graph,
    their gradients are also captured via ``save_attn_gradients``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values are projected from encoder outputs of width `encoder_width`.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One learned vector per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move heads to dim 1: (B, T, H*D) -> (B, H, T, D)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute attention over ``hidden_states``.

        Args:
            hidden_states: query-side input.
            attention_mask: additive mask added to the scaled scores (self-attention).
            head_mask: multiplicative mask applied to the dropped-out probabilities.
            encoder_hidden_states: if given, keys/values are projected from it
                (cross-attention), and ``encoder_attention_mask`` replaces ``attention_mask``.
            encoder_attention_mask: additive mask used for cross-attention.
            past_key_value: ``(key, value)`` from earlier steps, concatenated in front of
                the new keys/values along dim 2 (self-attention only).
            output_attentions: whether to also return the attention probabilities.

        Returns:
            ``(context,)`` or ``(context, attention_probs)``, always followed by the
            ``(key_layer, value_layer)`` pair that can be passed back as ``past_key_value``.
        """
        mixed_query_layer = self.query(hidden_states)

        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Raw scores: (B, H, T_query, T_key).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length = query_layer.size(2)
            key_length = key_layer.size(2)
            # The queries are the last `query_length` key positions. Without a cache this
            # is 0..L-1 (the same as the keys); with P cached tokens it is P..P+L-1. The
            # query positions must therefore not restart at 0 when keys were extended.
            position_ids_l = torch.arange(
                key_length - query_length, key_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift signed distances into the non-negative index range of distance_embedding.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask: masked positions carry large negative values.
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            # register_hook raises on tensors that do not require grad (e.g. under
            # torch.no_grad()), so attach the hook only when a gradient can flow.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # Dropout on the probabilities drops whole tokens to attend to. The
        # undropped `attention_probs` is what gets returned/saved.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads back: (B, H, T, D) -> (B, T, H*D).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
|
|
|
|
class BertSelfOutput(nn.Module):
    """Post-attention sub-layer: dense projection, dropout, then residual add + LayerNorm.

    Maps the attention context back to ``hidden_size`` and normalizes it together
    with the sub-layer input (post-LayerNorm, as in the original BERT).
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
class BertAttention(nn.Module):
    """Attention sub-layer: ``BertSelfAttention`` followed by ``BertSelfOutput``.

    Also supports pruning individual heads, tracking which heads were already removed.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output projections."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Drop the pruned heads' output rows of q/k/v ...
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        # ... and the matching input columns of the output projection.
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the smaller layers.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, then the residual/LayerNorm output block.

        Returns a tuple whose first element is the sub-layer output; any further
        elements returned by ``BertSelfAttention`` are passed through unchanged.
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        return (self.output(context, hidden_states), *extras)
|
|
|
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed by the activation.

    ``config.hidden_act`` may be an ``ACT2FN`` key (string) or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
class BertOutput(nn.Module):
    """Feed-forward contraction: ``intermediate_size -> hidden_size``, dropout, residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        out_width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, out_width)
        self.LayerNorm = nn.LayerNorm(out_width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
|
|
|
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper (CLIP variant).

    Uses separate q/k/v/out projections of width ``config.hidden_size``. Queries
    are pre-scaled by ``head_dim ** -0.5``. Optional additive ``causal_attention_mask``
    and ``attention_mask`` tensors of shape ``(bsz, 1, tgt_len, src_len)`` are added
    to the scores before softmax.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """(bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has shape
            ``(bsz, tgt_len, embed_dim)``. ``attn_weights`` is
            ``(bsz, num_heads, tgt_len, src_len)`` when ``output_attentions`` is true,
            otherwise ``None``.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()

        # Scale queries up front instead of scaling the scores afterwards.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold heads into the batch dim for bmm: (bsz * num_heads, seq, head_dim).
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Masks are (bsz, 1, tgt, src) and broadcast over heads.
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # The reshape/view-back round trip makes the returned 4-D weights share
            # storage and the autograd graph with the weights used below.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Unfold heads and merge them back into the channel dim.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder.

    ``fc1`` (``hidden_size -> intermediate_size``), then the activation, then
    ``fc2`` (``intermediate_size -> hidden_size``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # `hidden_act` may be an ACT2FN key (e.g. "quick_gelu") or a callable, as the
        # vision config documents; previously a callable raised on the ACT2FN lookup.
        if isinstance(config.hidden_act, str):
            self.activation_fn = ACT2FN[config.hidden_act]
        else:
            self.activation_fn = config.hidden_act
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
|
|
|
|
class EncoderLayer(nn.Module):
    """Pre-LayerNorm transformer encoder layer (CLIP style).

    Computes ``x = x + Attention(LN1(x))`` and then ``x = x + MLP(LN2(x))``.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Submodules are created in a fixed order (attn, ln1, mlp, ln2) to keep
        # parameter ordering and random initialisation unchanged.
        self.self_attn = Attention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor` or `None`): additive padding mask of shape
                `(batch, 1, tgt_len, src_len)`; hidden positions hold very large negative values.
            causal_attention_mask (`torch.FloatTensor` or `None`): additive causal mask with the same shape.
            output_attentions (`bool`, *optional*):
                Whether to also return this layer's attention weights.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is true.
        """
        # Attention sub-block with residual connection.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
class BertLayer(nn.Module):
    """
    One BERT block: self-attention, optional cross-attention, then a (possibly chunked) feed-forward sub-layer.

    The cross-attention module is only built when ``config.add_cross_attention`` is true, and it is only run when
    ``forward`` is called with ``mode == 'multimodal'``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Sequence axis over which `apply_chunking_to_forward` splits the feed-forward computation.
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        """
        Run the block.

        Returns a tuple ``(layer_output, *extras, present_key_value)``. ``extras`` are the self-attention outputs
        between its first and last element, followed (in multimodal mode) by the same slice of the
        cross-attention outputs. ``present_key_value`` is the self-attention module's last output.
        """
        # Only the first two cached entries are handed to self-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attn_result = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attn_result[0]
        extras = self_attn_result[1:-1]
        present_key_value = self_attn_result[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
            cross_attn_result = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attn_result[0]
            extras = extras + cross_attn_result[1:-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output,) + extras + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Intermediate projection followed by the output projection (which also receives ``attention_output``)."""
        return self.output(self.intermediate(attention_output), attention_output)
| |
|
|
class VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is an
    [`EncoderLayer`].

    Args:
        config: VisionConfig
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        intermediate_hidden_state: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask passed unchanged to every [`EncoderLayer`].
            causal_attention_mask (`torch.Tensor`, *optional*):
                Causal mask passed unchanged to every [`EncoderLayer`].
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (the input embeddings plus the output of
                every layer). Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a `BaseModelOutput` instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            intermediate_hidden_state (`bool`, *optional*):
                If true, collect the outputs of selected layers into a dict. A layer is selected when its 1-based
                index (`idx + 1`) appears in `config.intermediate_transformer_output`. Its output is stored under
                the key `'layer_<idx>'`, which uses the 0-based index, so layer number 1 is stored as `'layer_0'`.
                The dict is only available through `return_dict=True`; the tuple output does not include it.

        Returns:
            `BaseModelOutput` when `return_dict` is true. Otherwise a tuple containing, in order and skipping
            entries that were not requested: the last hidden state, all hidden states, all attentions.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Separate name so the boolean request flag is not rebound to a dict.
        collected_states = {} if intermediate_hidden_state else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            # The config lists 1-based layer numbers; the dict key keeps the 0-based index.
            if collected_states is not None and (idx + 1) in self.config.intermediate_transformer_output:
                collected_states['layer_' + str(idx)] = hidden_states

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            # NOTE: intermediate states are not part of the tuple output (kept for positional compatibility).
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_state=collected_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
|
|
|
|
class BertEncoder(nn.Module):
    """
    Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules applied in sequence.

    ``mode`` is forwarded to every layer. A layer runs its cross-attention sub-layer only when
    ``mode == 'multimodal'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @staticmethod
    def _checkpointable_forward(module, past_key_value, output_attentions, mode):
        """
        Wrap ``module`` so ``torch.utils.checkpoint.checkpoint`` can call it with positional inputs only.

        The non-tensor arguments (cache, attention flag and ``mode``) are bound here for two reasons. The
        re-entrant checkpoint rejects extra keyword arguments. The wrapped call is also replayed during backward,
        so it must not depend on loop variables that change afterwards.
        """

        def custom_forward(*inputs):
            return module(*inputs, past_key_value, output_attentions, mode=mode)

        return custom_forward

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        """
        Run ``hidden_states`` through every layer.

        Args:
            hidden_states: input to the first layer.
            attention_mask: mask passed unchanged to every layer.
            head_mask: optional per-layer head masks, indexed by layer number.
            encoder_hidden_states / encoder_attention_mask: inputs for the cross-attention sub-layers.
            past_key_values: optional per-layer caches, indexed by layer number.
            use_cache: when true, each layer's last output is collected as the next decoder cache. Forced off
                when gradient checkpointing is active during training.
            output_attentions: collect per-layer self-attention outputs. Also collects cross-attention outputs
                when ``config.add_cross_attention`` is set and ``mode == 'multimodal'``.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions`` instead of a tuple.
            mode: forwarded to each layer; ``'multimodal'`` enables cross-attention.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions``, or a tuple of the non-None values among
            (hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                # `mode` is bound inside the wrapper: passing it as a keyword to `checkpoint` raises
                # "Unexpected keyword arguments" and would never reach the layer anyway.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    self._checkpointable_forward(layer_module, past_key_value, output_attentions, mode),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer appends its cross-attention outputs after the self-attention ones, and only when it
                # actually ran cross-attention (mode == 'multimodal').
                if all_cross_attentions is not None and mode == 'multimodal':
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
class VisionTransformer(nn.Module):
    """
    Image encoder: patch embeddings -> pre-LayerNorm -> :class:`VisionEncoder` -> post-LayerNorm.

    The pooled output is the post-normalised hidden state at sequence position 0.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        eps = config.layer_norm_eps

        self.embeddings = VisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(width, eps=eps)
        self.encoder = VisionEncoder(config)
        self.post_layrnorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        intermediate_hidden_state: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode ``pixel_values`` (format as expected by ``VisionEmbeddings``; presumably a batch of images,
        TODO confirm).

        Unset flags fall back to ``config.output_attentions``, ``config.output_hidden_states`` and
        ``config.use_return_dict``. ``intermediate_hidden_state`` is forwarded to the encoder.

        Returns:
            ``BaseModelOutputWithPooling`` when ``return_dict`` is true. Otherwise the tuple
            ``(last_hidden_state, pooled_output, *encoder_outputs[1:])``.

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            intermediate_hidden_state=intermediate_hidden_state
        )

        last_hidden_state = self.post_layrnorm(encoder_outputs[0])
        # Position 0 is used as the pooled representation (presumably the class-embedding token).
        pooled_output = last_hidden_state[:, 0, :]

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                intermediate_hidden_state=encoder_outputs.intermediate_hidden_state
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence into one vector: ``tanh(dense(hidden_states[:, 0]))``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Project and squash the hidden state at sequence position 0 (one vector per batch item)."""
        return self.activation(self.dense(hidden_states[:, 0]))
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """
    Dense -> activation -> LayerNorm applied before the vocabulary projection.

    ``config.hidden_act`` may be an activation name (looked up in ``ACT2FN``) or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Apply the three stages in order and return the transformed hidden states."""
        for stage in (self.dense, self.transform_act_fn, self.LayerNorm):
            hidden_states = stage(hidden_states)
        return hidden_states
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Map hidden states to vocabulary logits: ``decoder(transform(hidden_states))``."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        vocab_size = config.vocab_size
        # Built without a bias of its own; the separately registered `self.bias` is attached right below.
        self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

        # `decoder.bias` and `bias` are the same Parameter object. Presumably this keeps them in sync when
        # embedding-resizing utilities replace one of them (TODO confirm against PreTrainedModel).
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary logits for each position of ``hidden_states``."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a :class:`BertLMPredictionHead` under the ``predictions`` attribute."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction head's vocabulary logits for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
class VisionTrainedModel(PreTrainedModel):
    """
    Abstract base for the vision models in this file. It supplies the weight-initialisation hook
    (``_init_weights``) and the gradient-checkpointing toggle used through ``PreTrainedModel``.
    """

    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """
        Initialise ``module`` in place.

        All standard deviations are scaled by ``config.initializer_factor``:
        - ``VisionEmbeddings``: class embedding ~ N(0, embed_dim**-0.5); patch and position embeddings
          ~ N(0, initializer_range).
        - ``Attention``: q/k/v ~ N(0, embed_dim**-0.5 * (2 * num_hidden_layers)**-0.5); out_proj
          ~ N(0, embed_dim**-0.5).
        - ``MLP``: fc1 ~ N(0, (2 * hidden_size)**-0.5); fc2 ~ N(0, hidden_size**-0.5 * (2 * num_hidden_layers)**-0.5).
        - ``nn.LayerNorm``: weight 1, bias 0. Any ``nn.Linear`` with a bias: bias 0.
        """
        factor = self.config.initializer_factor

        if isinstance(module, VisionEmbeddings):
            embed_std = module.config.initializer_range * factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=embed_std)
            nn.init.normal_(module.position_embedding.weight, std=embed_std)
        elif isinstance(module, Attention):
            width_scale = module.embed_dim**-0.5
            depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
            in_std = width_scale * depth_scale * factor
            for projection in (module.q_proj, module.k_proj, module.v_proj):
                nn.init.normal_(projection.weight, std=in_std)
            nn.init.normal_(module.out_proj.weight, std=width_scale * factor)
        elif isinstance(module, MLP):
            hidden = module.config.hidden_size
            depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
            nn.init.normal_(module.fc1.weight, std=(2 * hidden) ** -0.5 * factor)
            nn.init.normal_(module.fc2.weight, std=(hidden**-0.5) * depth_scale * factor)

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on ``VisionEncoder`` instances; other modules are ignored."""
        if isinstance(module, VisionEncoder):
            module.gradient_checkpointing = value
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base for the BERT models in this file. It supplies the config class, the checkpoint prefix
    (``"bert"``) and the weight-initialisation hook (``_init_weights``).
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """
        Initialise ``module`` in place.

        - ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, ``config.initializer_range``); Linear biases are zeroed.
        - ``nn.LayerNorm``: bias 0, weight 1.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
class BertModel(BertPreTrainedModel):
    """
    BERT text model. It can run as a bidirectional encoder (self-attention only) or as a causal decoder, and it
    can additionally cross-attend to ``encoder_hidden_states``. The architecture follows `Attention is all you
    need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    Decoder (causal-mask) behaviour is selected per call with the :obj:`is_decoder` argument of :meth:`forward`.
    Cross-attention layers are only built when :obj:`config.add_cross_attention` is :obj:`True`. They only run
    when :obj:`mode` is ``'multimodal'``, in which case :obj:`encoder_hidden_states` is expected as an input to
    the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        # Optional pooler over the first token; callers that only need token outputs pass False.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 2-D
                ``(batch, src_len)`` or already 3-D ``(batch, tgt_len, src_len)``.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model, ``(batch_size, seq_length)``.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                If true and the mask is 2-D, combine it with a lower-triangular causal mask.

        Returns:
            :obj:`torch.Tensor`: additive mask in ``self.dtype``, broadcastable to
            ``(batch, heads, tgt_len, src_len)``. It is 0 where attention is allowed and -10000 where it is
            blocked.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # causal_mask[b, i, j] is true when key position j <= query position i.
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                causal_mask = causal_mask.to(attention_mask.dtype)

                # When the mask also covers cached (past) positions, those prefix positions are always visible.
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        dim=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Convert {1: attend, 0: ignore} into an additive mask in the model dtype.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Pre-computed input embeddings. When given, ``self.embeddings`` is skipped and this tensor is fed
            directly to the encoder.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states used by the cross-attention layers. A list of such tensors is also accepted;
            its first element determines the default encoder mask shape.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. It may be a list
            parallel to ``encoder_hidden_states``. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`). Only honoured when ``is_decoder`` is true (defaulting to
            ``config.use_cache``); otherwise it is forced to :obj:`False`.
        is_decoder (:obj:`bool`):
            Apply a causal mask to a 2-D ``attention_mask``.
        mode (:obj:`str`):
            Forwarded to every layer; ``'multimodal'`` enables the cross-attention sub-layers.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # Number of cached positions, read from dim 2 of the first layer's first cached tensor.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        if encoder_hidden_states is not None:
            first_encoder_states = (
                encoder_hidden_states[0] if isinstance(encoder_hidden_states, list) else encoder_hidden_states
            )
            encoder_batch_size, encoder_sequence_length, _ = first_encoder_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
| |
|
|
|
|
class BertLMHeadModel(BertPreTrainedModel):
    """
    :class:`BertModel` (built without a pooler) topped with an MLM-style prediction head, used as a
    left-to-right language model. The loss scores the logits at position ``t`` against the label at ``t + 1``.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
        label_smoothing=0.1,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
            ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
            Passing labels disables ``use_cache``.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        return_logits (:obj:`bool`):
            If true, return only the logits for positions ``0 .. seq_len - 2`` (i.e. without the last position)
            and skip the loss.
        reduction (:obj:`str`):
            ``CrossEntropyLoss`` reduction. With ``'none'``, per-token losses are summed per sequence, giving a
            loss of shape ``(batch_size,)``.
        mode (:obj:`str`):
            Forwarded to :class:`BertModel`; ``'multimodal'`` enables cross-attention.
        label_smoothing (:obj:`float`):
            Label smoothing passed to ``CrossEntropyLoss`` (default 0.1, the previously hard-coded value).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: logits at position t are scored against the label at position t + 1.
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == 'none':
                # Sum the per-token losses back into one value per sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        """
        Build the ``forward`` kwargs for one ``generate()`` step.

        When a cache is present, only the last token is fed. An all-ones attention mask is created when none is
        given. ``is_decoder`` is always set to True.
        """
        # Newer `generate()` implementations pass the cache as `past_key_values` rather than `past`; without
        # this fallback the cache would silently be ignored and every step would recompute the full prefix.
        if past is None:
            past = model_kwargs.get("past_key_values")

        input_shape = input_ids.shape
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # With cached keys/values, only the newest token needs to be processed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dim 0 by ``beam_idx`` (used by beam search)."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past
        )