| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| ModelOutput, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| logging, |
| replace_return_docstrings, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| ) |
| from transformers.activations import ACT2FN |
| from transformers.modeling_attn_mask_utils import ( |
| _prepare_4d_attention_mask, |
| _prepare_4d_attention_mask_for_sdpa, |
| _prepare_4d_causal_attention_mask, |
| _prepare_4d_causal_attention_mask_for_sdpa, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| Seq2SeqLMOutput, |
| Seq2SeqModelOutput, |
| ) |
|
|
| from transformers.cache_utils import Cache, HybridCache |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| SequenceClassifierOutputWithPast, |
| TokenClassifierOutput, |
| ) |
|
|
| from typing import List, Optional, Tuple, Union |
|
|
| from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2ForCausalLM,Gemma2DecoderLayer,Gemma2RMSNorm |
| from .configuration_feynmodel import FeynModelConfig,Florence2VisionConfig |
|
|
| from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM |
| import json |
| import math |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import logging |
|
|
| from transformers.utils import ( |
| ModelOutput, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| logging, |
| replace_return_docstrings, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| ) |
|
|
| from transformers.modeling_utils import PreTrainedModel |
|
|
| from collections import OrderedDict |
| from einops import rearrange |
| from timm.models.layers import DropPath, trunc_normal_ |
|
|
| logger = logging.get_logger(__name__) |
|
|
class MySequential(nn.Sequential):
    """``nn.Sequential`` variant whose stages may take and return several values.

    When a stage returns a plain ``tuple`` it is unpacked into the positional
    arguments of the next stage; any other value is forwarded as one argument.
    """

    def forward(self, *inputs):
        """Run the registered sub-modules in order and return the last result."""
        current = inputs
        for stage in self._modules.values():
            # Exact-type test on purpose: tuple subclasses are passed on whole.
            current = stage(*current) if type(current) is tuple else stage(current)
        return current
|
|
|
|
class PreNorm(nn.Module):
    """Residual wrapper computing ``x + drop_path(fn(norm(x)))``.

    Args:
        norm: Module applied to ``x`` before ``fn``, or ``None`` to skip it.
        fn: Module called as ``fn(x, *args, **kwargs)``; must return ``(x, size)``.
        drop_path: Optional module applied to the output of ``fn`` before the
            residual addition.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        """Return ``(shortcut + branch(x), size)`` where ``size`` comes from ``fn``."""
        shortcut = x
        # Identity comparison is the correct way to test for "no norm configured".
        if self.norm is not None:
            x = self.norm(x)
        x, size = self.fn(x, *args, **kwargs)

        # Truthiness kept as in the original: a missing (None) drop_path is skipped.
        if self.drop_path:
            x = self.drop_path(x)

        return shortcut + x, size
|
|
|
|
class Mlp(nn.Module):
    """Two-layer feed-forward network (``fc1`` -> activation -> ``fc2``).

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to ``in_features``.
        out_features: Size of the last output dimension; falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Named entries keep the parameter keys ``net.fc1.*`` / ``net.fc2.*``.
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden_features)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden_features, out_features)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        """Apply the network to ``x``; ``size`` is returned untouched."""
        projected = self.net(x)
        return projected, size
|
|
|
|
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2D convolution applied to a flattened token sequence.

    Args:
        dim_in: Number of channels (also used as the number of conv groups).
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,  # one filter per channel
            stride=stride,
            bias=bias,
        )
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_kwargs)

    def forward(self, x, size):
        """Convolve ``x`` of shape ``(B, H*W, C)`` and return ``(tokens, (H', W'))``."""
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        # Tokens -> feature map, convolve, then back to tokens.
        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        out_size = (feature_map.size(-2), feature_map.size(-1))
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, out_size
|
|
|
|
class ConvEmbed(nn.Module):
    """Image-to-patch embedding implemented as a strided convolution.

    Args:
        patch_size: Kernel size of the projection convolution.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        stride: Stride of the projection convolution.
        padding: Padding of the projection convolution.
        norm_layer: Optional factory ``norm_layer(dim)``; ``None`` disables norm.
        pre_norm: If True the norm is sized for ``in_chans`` and applied before
            the projection; otherwise it is sized for ``embed_dim`` and applied after.
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
        )

        if norm_layer:
            self.norm = norm_layer(in_chans if pre_norm else embed_dim)
        else:
            self.norm = None

        self.pre_norm = pre_norm

    def forward(self, x, size):
        """Embed ``x`` and return ``(tokens, (H_out, W_out))``.

        ``x`` is either a token sequence ``(B, H*W, C)`` with ``size == (H, W)``
        or an already spatial 4D tensor that is fed to the convolution as is.
        """
        height, width = size
        if x.dim() == 3:
            # The pre-norm is only applied on the token-sequence path.
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)

        x = self.proj(x)

        out_h, out_w = x.shape[-2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (out_h, out_w)
|
|
|
|
class ChannelAttention(nn.Module):
    """Attention computed across channels within each channel group.

    Args:
        dim: Channel dimension of the input tokens.
        groups: Number of channel groups; ``dim`` is split evenly across them.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()

        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        """Return ``(attended tokens, size)`` for ``x`` of shape ``(B, N, C)``."""
        batch, num_tokens, channels = x.shape
        group_dim = channels // self.groups

        fused = self.qkv(x).reshape(batch, num_tokens, 3, self.groups, group_dim)
        # Each of q, k, v: (B, groups, N, C // groups)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scores are channel-by-channel, scaled by the token count.
        q = q * (float(num_tokens) ** -0.5)
        scores = (q.transpose(-1, -2) @ k).softmax(dim=-1)

        mixed = (scores @ v.transpose(-1, -2)).transpose(-1, -2)
        mixed = mixed.transpose(1, 2).reshape(batch, num_tokens, channels)
        return self.proj(mixed), size
|
|
|
|
class ChannelBlock(nn.Module):
    """Channel-attention block: optional depth-wise conv, channel attention,
    optional depth-wise conv, then an MLP, each wrapped in a residual ``PreNorm``.

    Args:
        dim: Channel dimension.
        groups: Number of channel groups for ``ChannelAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        drop_path_rate: ``DropPath`` rate; values ``<= 0`` use ``nn.Identity``.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalisation factory called as ``norm_layer(dim)``.
        conv_at_attn: Add a depth-wise conv before the attention.
        conv_at_ffn: Add a depth-wise conv before the MLP.
    """

    def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
            drop_path,
        )
        self.channel_attn = attention

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        hidden_dim = int(dim * mlp_ratio)
        # The same drop_path module instance is shared with the attention branch.
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        """Run the block on tokens ``x`` with spatial ``size``; returns ``(x, size)``."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        return self.ffn(x, size)
|
|
|
|
def window_partition(x, window_size: int):
    """Split ``x`` of shape ``(B, H, W, C)`` into non-overlapping square windows.

    ``H`` and ``W`` must be multiples of ``window_size``. Returns a tensor of
    shape ``(B * H/ws * W/ws, ws, ws, C)``.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes together before flattening them into the batch.
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(
        -1, window_size, window_size, channels
    )
|
|
|
|
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Reassemble windows of shape ``(num_windows * B, ws, ws, C)`` into ``(B, H, W, C)``.

    ``batch_size`` is passed explicitly rather than derived from ``windows``.
    """
    rows = H // window_size
    cols = W // window_size
    grid = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into full height / width.
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, H, W, -1)
|
|
|
|
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        window_size: Side length of each attention window.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = float(dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """Attend within windows of ``x`` (shape ``(B, H*W, C)``); returns ``(x, size)``."""
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        ws = self.window_size
        x = x.view(B, H, W, C)

        # Zero-pad the right / bottom edges so both sides are multiples of ws.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        Hp, Wp = x.shape[1], x.shape[2]

        windows = window_partition(x, ws).view(-1, ws * ws, C)
        num_windows, tokens, _ = windows.shape

        fused = self.qkv(windows).reshape(
            num_windows, tokens, 3, self.num_heads, C // self.num_heads
        )
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        attn = self.softmax((q * self.scale) @ k.transpose(-2, -1))

        out = (attn @ v).transpose(1, 2).reshape(num_windows, tokens, C)
        out = self.proj(out)

        # Merge the windows back into the padded feature map.
        out = window_reverse(out.view(-1, ws, ws, C), B, ws, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            # Drop the padded border again.
            out = out[:, :H, :W, :].contiguous()

        return out.view(B, H * W, C), size
|
|
|
|
class SpatialBlock(nn.Module):
    """Spatial block: optional depth-wise conv, window attention, optional
    depth-wise conv, then an MLP, each wrapped in a residual ``PreNorm``.

    Args:
        dim: Channel dimension.
        num_heads: Number of heads for ``WindowAttention``.
        window_size: Attention window side length.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        drop_path_rate: ``DropPath`` rate; values ``<= 0`` use ``nn.Identity``.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalisation factory called as ``norm_layer(dim)``.
        conv_at_attn: Add a depth-wise conv before the attention.
        conv_at_ffn: Add a depth-wise conv before the MLP.
    """

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            drop_path,
        )
        self.window_attn = attention

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        hidden_dim = int(dim * mlp_ratio)
        # The same drop_path module instance is shared with the attention branch.
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        """Run the block on tokens ``x`` with spatial ``size``; returns ``(x, size)``."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        return self.ffn(x, size)
|
|
|
|
class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (tuple(int)): Number of spatial+channel block pairs in each stage. Default: (1, 1, 3, 1).
        patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, perform norm before convolution layer. Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enable checkpointing. Default: False.
        conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
        conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        # One drop-path rate per block; every depth unit holds a spatial and a
        # channel block, hence the factor of two.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]

        depth_offset = 0
        convs = []
        blocks = []
        for i in range(self.num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)

            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                num_heads[i],
                                window_size,
                                drop_path_rate=dpr[depth_offset + j * 2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                num_groups[i],
                                drop_path_rate=dpr[depth_offset + j * 2 + 1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Channel dimension produced by the last stage."""
        return self.embed_dims[-1]

    def _init_weights(self, m):
        """Initialise one sub-module; used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        forward until avg pooling
        Args:
            x (_type_): input image tensor
        """
        if self.enable_checkpoint:
            # Imported here because the file has no module-level import of
            # torch.utils.checkpoint; the original call raised NameError.
            from torch.utils.checkpoint import checkpoint

        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                x, input_size = checkpoint(block, x, input_size, use_reentrant=False)
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        """Return pooled, normalised features of shape ``(B, embed_dims[-1])``."""
        x = self.forward_features_unpool(x)

        # (B, N, C) -> (B, C, N) -> average over tokens -> (B, C)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        """Apply the backbone followed by the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a ``DaViT`` from a config object; arguments not listed here
        keep their constructor defaults."""
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
|
|
|
|
|
|
|
|
# Name of the configuration class referenced by the generated documentation.
_CONFIG_FOR_DOC = "FeynModelConfig"


# Class-level docstring text passed to `add_start_docstrings` on the model class.
FEYNMODEL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FeynModelConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation passed to `add_start_docstrings_to_model_forward` on `forward`.
# NOTE(review): the `attention_mask` entry still contains text about head masks and
# `modeling_opt`, which looks copied from another model — confirm and prune.
FEYNMODEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
| |
| def _prepare_4d_causal_attention_mask_with_cache_position( |
| attention_mask: torch.Tensor, |
| sequence_length: int, |
| target_length: int, |
| dtype: torch.dtype, |
| device: torch.device, |
| min_dtype: float, |
| cache_position: torch.Tensor, |
| batch_size: int, |
| ): |
| |
| |
| if attention_mask is not None and attention_mask.dim() == 4: |
| |
| |
| |
| |
| causal_mask = attention_mask[:, :, -sequence_length:, :] |
| |
| |
| else: |
| |
| |
| causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=torch.float32, device=device) |
| |
| if sequence_length != 1: |
| causal_mask = torch.triu(causal_mask, diagonal=1) |
| |
| causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| |
| if attention_mask is not None: |
| |
| causal_mask = causal_mask.clone() |
| mask_length = attention_mask.shape[-1] |
| padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| padding_mask = padding_mask == 0 |
| causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| padding_mask, min_dtype |
| ) |
| |
|
|
| return causal_mask |
|
|
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Args:
        embedding_dim: Size of the returned embedding. Rows get
            ``embedding_dim // 2`` dimensions and columns get the remainder.
        num_pos: Maximum supported height and width.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim), with the column
            embedding first and the row embedding last along the final axis

        Raises:
            ValueError: if ``pixel_values`` is not 4D.
        """
        if len(pixel_values.shape) != 4:
            raise ValueError('pixel_values must be a 4D tensor')
        batch_size, height, width = pixel_values.shape[:3]
        device = pixel_values.device

        x_emb = self.column_embeddings(torch.arange(width, device=device))   # (W, D_col)
        y_emb = self.row_embeddings(torch.arange(height, device=device))     # (H, D_row)

        # Broadcast both tables over the (H, W) grid and join them per position.
        pos = torch.cat(
            [
                x_emb.unsqueeze(0).expand(height, -1, -1),
                y_emb.unsqueeze(1).expand(-1, width, -1),
            ],
            dim=-1,
        )
        # Already channel-last, so only the batch axis has to be added.
        return pos.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
|
class PositionalEmbeddingCosine1D(nn.Module):
    """
    This class implements a very simple positional encoding. It follows closely
    the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html
    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum length to precompute the positional encodings.
    """
    def __init__(
            self,
            embed_dim: int = 512,
            max_seq_len: int = 1024) -> None:
        super(PositionalEmbeddingCosine1D, self).__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        factor = math.log(10000)
        denominator = torch.exp(
            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)

        # (max_seq_len, ceil(embed_dim / 2)) table of position * frequency.
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))

        # Sines fill the even columns, cosines the odd ones. With an odd
        # embed_dim there is one cosine column fewer than sine columns, so the
        # frequency table is trimmed to fit (this used to raise a shape error).
        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        pos_idx_to_embed[:, 1::2] = torch.cos(frequencies[:, :self.embed_dim // 2])

        self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.
        Returns the positional embeddings for the first T positions, shaped
        [1, T, D] for a batched input or [T, D] otherwise.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:len_seq, :]

        if shape_len == 3:
            # Leading singleton axis so the result broadcasts over the batch.
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
|
|
|
|
class LearnedAbsolutePositionEmbedding1D(nn.Module):
    """
    Learnable absolute positional embeddings for 1D sequences.
    Args:
        embedding_dim: The dimension of the embeddings.
        num_pos: The number of positions in the embedding table, i.e. the
            longest sequence that can be embedded.
    """
    def __init__(
            self,
            embedding_dim: int = 512,
            num_pos: int = 1024) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(num_pos, embedding_dim)
        self.num_pos = num_pos

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order, shaped either [T, D]
                or [B, T, D] (B batch size, T sequence length, D embedding
                dimension).
        Returns the learned embeddings of positions 0..T-1, shaped [1, T, D]
        for a batched input or [T, D] otherwise.
        """
        rank = seq_embeds.dim()
        assert 2 <= rank <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.num_pos

        positions = torch.arange(len_seq, device=seq_embeds.device)
        pos_embeds = self.embeddings(positions)

        if rank != 3:
            return pos_embeds
        # Leading singleton axis so the result broadcasts over the batch.
        return pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
|
|
def create_git_attention_mask(
    tgt: torch.Tensor,
    memory: torch.Tensor,
    max_length: int
) -> torch.Tensor:
    """Build an additive attention mask over ``[memory tokens | target tokens | padding]``.

    Layout of the returned ``(batch, 1, num_memory + num_tgt, max_length)`` mask,
    where 0 means "may attend" and ``-3.4028e+38`` means "blocked":

    * every row may attend all memory columns;
    * memory rows may not attend any target column;
    * target rows attend target columns causally (lower triangle);
    * columns beyond ``num_memory + num_tgt`` are blocked.

    Args:
        tgt: Target tensor; dim 0 is the batch size and dim 1 the target length.
        memory: Memory tensor; dim 1 is the memory length.
        max_length: Total number of key columns in the mask.

    Raises:
        ValueError: if ``max_length`` is smaller than ``num_memory + num_tgt``.
    """
    batch_size = tgt.size(0)
    num_tgt = tgt.shape[1]
    num_memory = memory.shape[1]
    total_length = num_memory + num_tgt
    if max_length < total_length:
        raise ValueError(
            f"max_length ({max_length}) must be at least num_memory + num_tgt ({total_length})"
        )

    blocked = float(-3.4028e+38)
    # Start fully blocked on the inputs' device (the mask used to be built on CPU).
    mask = torch.full((total_length, max_length), blocked, device=tgt.device)

    # Memory columns are visible to every row.
    mask[:, :num_memory] = 0
    # Target-to-target block: zeros on and below the diagonal, blocked above it.
    mask[num_memory:, num_memory:total_length] = torch.triu(
        torch.full((num_tgt, num_tgt), blocked, device=tgt.device), diagonal=1
    )

    # Add head and batch axes; expand() shares memory across the batch.
    return mask[None, None, :, :].expand(batch_size, 1, total_length, max_length)
|
|
def get_position_ids_from_binary_attention_mask(mask):
    """
    Build sequential position IDs matching the query length of a 4D attention mask.

    Only the shape and device of ``mask`` are used; its values are not read.

    Args:
        mask (torch.Tensor): A 4D attention mask whose third dimension is the sequence length.

    Returns:
        torch.LongTensor: Tensor of shape (1, seq_len) holding 0, 1, ..., seq_len - 1,
        placed on the same device as ``mask``.
    """
    seq_len = mask.shape[2] if len(mask.shape) == 4 else None
    if seq_len is None:
        # Same failure mode as unpacking a non-4D shape into four names.
        _, _, seq_len, _ = mask.shape
    return torch.arange(seq_len, dtype=torch.long, device=mask.device)[None, :]
|
|
def ensure_tensor(variable):
    """Return ``variable`` as a ``torch.Tensor``.

    Tensors are returned as the same object; anything else goes through
    ``torch.tensor``. A conversion failure is printed and then re-raised.
    """
    if not isinstance(variable, torch.Tensor):
        try:
            return torch.tensor(variable)
        except Exception as e:
            print(f"Error converting to tensor: {e}")
            raise
    return variable
|
|
| @add_start_docstrings( |
| "The bare Model outputting raw hidden-states without any specific head on top.", |
| FEYNMODEL_START_DOCSTRING, |
| ) |
| class FeynModel(Gemma2Model): |
| """ |
| Transformer decoder consisting of *config.num_hidden_layers* layers. |
| Each layer is a [`FeynModelDecoderLayer`] + [`LoraLayer`] for the *proj* modules. |
| NB: LoRA layers are added and activated on the proj modules only if `pixel_values` is not None. |
| |
| Args: |
| config: FeynModelConfig |
| """ |
|
|
def __init__(self, config: FeynModelConfig):
    """Initialise the parent model from ``config`` and set the FeynModel-specific attributes."""
    super().__init__(config)

    # The mode flag is set to 'llm' on construction.
    self.mode = 'llm'

    # Fixed number of image patch tokens. A disabled earlier version derived it as
    # (vision_config.image_size / vision_config.patch_size) ** 2 + 1, optionally
    # multiplied by config.num_image_with_embedding.
    # NOTE(review): 577 presumably equals 24 ** 2 + 1 — confirm against the vision config.
    self.image_patch_tokens = 577

    self.post_init()
|
|
def get_input_embeddings(self):
    """Return the module stored in ``self.embed_tokens``."""
    return self.embed_tokens
|
|
def set_input_embeddings(self, value):
    """Replace ``self.embed_tokens`` with ``value``."""
    self.embed_tokens = value
|
|
| |
| |
| |
@add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    causal_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Run the decoder stack on token ids or on pre-computed embeddings.

    Two input paths exist:

    * ``input_ids``: tokens are embedded with ``self.embed_tokens`` and the
      causal mask comes from ``self._update_causal_mask``.
    * ``inputs_embeds``: when ``causal_attention_mask`` is given it is used as
      the mask as is, and position ids are derived from a 4D
      ``attention_mask``; when it is missing the mask is built by
      ``self._update_causal_mask`` as on the ``input_ids`` path.

    Args:
        causal_attention_mask: Ready-made mask for the ``inputs_embeds`` path;
            converted with ``ensure_tensor`` and passed to every decoder layer.
        **kwargs: Accepted for call compatibility and ignored.

    Raises:
        ValueError: unless exactly one of ``input_ids`` / ``inputs_embeds`` is given.
    """
    # Validate first so a missing input gives this error, not an AttributeError.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    embeds_supplied = inputs_embeds is not None
    if not embeds_supplied:
        inputs_embeds = self.embed_tokens(input_ids)

    # Default cache positions are one index per sequence position. The earlier
    # default of zeros(batch_size) made this fallback unreachable and produced
    # all-zero position ids.
    if cache_position is None:
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

    if embeds_supplied and causal_attention_mask is not None:
        causal_mask = ensure_tensor(causal_attention_mask)
        # Position ids follow the query length of the 4D mask when one is given.
        if attention_mask is not None and attention_mask.dim() == 4:
            position_ids = get_position_ids_from_binary_attention_mask(attention_mask)
    else:
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    if not isinstance(position_ids, torch.Tensor):
        position_ids = torch.tensor(position_ids, dtype=torch.long, device=inputs_embeds.device)

    hidden_states = inputs_embeds

    # Scale embeddings by sqrt(hidden_size); the scalar is created in the
    # activations' dtype so the product keeps that dtype.
    normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
    hidden_states = hidden_states * normalizer

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # Also record the final, normalised hidden state.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = past_key_values if use_cache else None

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
| |
|
|
|
|
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """Build the attention mask handed to the decoder layers.

    Args:
        attention_mask: Mask supplied by the caller, or ``None``.
        input_tensor: Embedded inputs; supplies dtype, device, batch size and
            the query length (``shape[1]``).
        cache_position: Forwarded unchanged to the mask builder.
        past_key_values: Cache object; a ``HybridCache`` dictates the key
            length through ``get_max_length()``.
        output_attentions: Accepted for signature compatibility; not read here.

    Returns:
        ``attention_mask`` untouched when the configured attention
        implementation is ``"flash_attention_2"``, otherwise the result of
        ``_prepare_4d_causal_attention_mask_with_cache_position``.
    """
    # The flash-attention path receives the caller's mask as-is.
    if self.config._attn_implementation == "flash_attention_2":
        return attention_mask

    mask_dtype = input_tensor.dtype
    lowest_value = torch.finfo(mask_dtype).min
    query_length = input_tensor.shape[1]

    # Key length: cache capacity for a HybridCache, else the width of the
    # supplied mask, else the query length itself.
    if isinstance(past_key_values, HybridCache):
        key_length = past_key_values.get_max_length()
    elif attention_mask is not None:
        key_length = attention_mask.shape[-1]
    else:
        key_length = input_tensor.shape[1]

    return _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_length,
        target_length=key_length,
        dtype=mask_dtype,
        device=input_tensor.device,
        min_dtype=lowest_value,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )
|
|
|
|
|
|
| class FeynModelForCausalLM(Gemma2ForCausalLM): |
| _tied_weights_keys = ["lm_head.weight"] |
| config_class = FeynModelConfig |
def __init__(self, config):
    """Assemble the language model, the vision tower and the image projection.

    Args:
        config: Model configuration. ``config.vision_config`` may be a plain
            ``dict`` (it is then converted in place to a
            ``Florence2VisionConfig``) or an already converted object.
    """
    super().__init__(config)

    # The conversion result is written back onto ``config``; guard so that a
    # config object reused for a second model is not passed through
    # ``from_dict`` again, which expects a mapping.
    if isinstance(config.vision_config, dict):
        config.vision_config = Florence2VisionConfig.from_dict(config.vision_config)

    # Replace the decoder created by the parent constructor.
    self.model = FeynModel(config)

    self.vision_tower = DaViT.from_config(config=config.vision_config)
    self._build_image_projection_layers(config)

    # Mask cached between the image prefill and later decoding steps; set by
    # ``forward``/``generate`` and read back by ``forward``.
    self.__causal_attention_mask = None

    self.post_init()
|
|
| |
def _build_image_projection_layers(self, config):
    """Create the layers that map vision-tower features into the text space.

    Registers ``image_projection``, ``image_proj_norm``, ``image_pos_embed``,
    ``image_feature_source`` and ``visual_temporal_embed`` on ``self``.

    Args:
        config: Model configuration; only ``config.vision_config`` is read.

    Raises:
        NotImplementedError: If the configured positional-embedding type is
            not ``'learned_abs_2d'`` or the temporal-embedding type is not
            ``'COSINE'``.
    """
    vision_cfg = config.vision_config
    feature_dim = vision_cfg.dim_embed[-1]
    projected_dim = vision_cfg.projection_dim

    # ``torch.empty`` leaves the values uninitialised; presumably they are
    # filled by weight init or a checkpoint load elsewhere -- TODO confirm.
    self.image_projection = nn.Parameter(torch.empty(feature_dim, projected_dim))
    self.image_proj_norm = nn.LayerNorm(projected_dim)

    pos_cfg = vision_cfg.image_pos_embed
    if pos_cfg['type'] != 'learned_abs_2d':
        raise NotImplementedError('Not implemented yet')
    self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
        embedding_dim=feature_dim,
        num_pos=pos_cfg['max_pos_embeddings'],
    )

    self.image_feature_source = vision_cfg.image_feature_source

    temporal_cfg = vision_cfg.visual_temporal_embedding
    if temporal_cfg['type'] != 'COSINE':
        raise NotImplementedError('Not implemented yet')
    self.visual_temporal_embed = PositionalEmbeddingCosine1D(
        embed_dim=feature_dim,
        max_seq_len=temporal_cfg['max_temporal_embeddings'],
    )
|
|
| |
|
|
| def _merge_input_ids_with_image_features(self, image_features, inputs_embeds): |
| batch_size, image_token_length = image_features.size()[:-1] |
| device = image_features.device |
| image_attention_mask = torch.ones(batch_size, image_token_length, device=device) |
|
|
| if inputs_embeds is None: |
| return image_features, image_attention_mask |
|
|
| task_prefix_embeds = inputs_embeds |
| task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device) |
|
|
| |
| if len(task_prefix_attention_mask.shape) == 3: |
| task_prefix_attention_mask = task_prefix_attention_mask.squeeze(1) |
|
|
| |
| if image_features.size(0) != task_prefix_embeds.size(0): |
| raise ValueError("Batch sizes of image_features and task_prefix_embeds do not match") |
|
|
| |
| if image_features.dim() < task_prefix_embeds.dim(): |
| image_features = image_features.unsqueeze(-1) |
| elif task_prefix_embeds.dim() < image_features.dim(): |
| task_prefix_embeds = task_prefix_embeds.unsqueeze(-1) |
|
|
| |
| if image_features.size(2) != task_prefix_embeds.size(2): |
| |
| raise ValueError("Internal dimensions of image_features and task_prefix_embeds do not match") |
|
|
| inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1) |
| attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1) |
|
|
| return inputs_embeds, attention_mask |
|
|
| def _encode_image(self, pixel_values): |
| if len(pixel_values.shape) == 4: |
| batch_size, C, H, W = pixel_values.shape |
| T = 1 |
| x = self.vision_tower.forward_features_unpool(pixel_values) |
| else: |
| |
| pixel_values = pixel_values.unsqueeze(0) |
| batch_size, C, H, W = pixel_values.shape |
| T = 1 |
| x = self.vision_tower.forward_features_unpool(pixel_values) |
| |
| if self.image_pos_embed is not None: |
| x = x.view(batch_size * T, -1, x.shape[-1]) |
| num_tokens = x.shape[-2] |
| h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) |
| assert h * w == num_tokens, 'only support square feature maps for now' |
| x = x.view(batch_size * T, h, w, x.shape[-1]) |
| pos_embed = self.image_pos_embed(x) |
| x = x + pos_embed |
| x = x.view(batch_size, T * h*w, x.shape[-1]) |
|
|
| if self.visual_temporal_embed is not None: |
| visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) |
| x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) |
|
|
| x_feat_dict = {} |
|
|
| spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) |
| x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x |
|
|
| temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) |
| x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x |
|
|
| x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] |
| x_feat_dict['last_frame'] = x |
|
|
| new_x = [] |
| for _image_feature_source in self.image_feature_source: |
| if _image_feature_source not in x_feat_dict: |
| raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) |
| new_x.append(x_feat_dict[_image_feature_source]) |
|
|
| x = torch.cat(new_x, dim=1) |
|
|
| x = x @ self.image_projection |
| x = self.image_proj_norm(x) |
|
|
| return x |
| |
|
|
def get_input_embeddings(self):
    """Return the ``embed_tokens`` attribute of the wrapped decoder."""
    decoder = self.model
    return decoder.embed_tokens
|
|
def set_input_embeddings(self, value):
    """Store ``value`` as the ``embed_tokens`` attribute of the wrapped decoder."""
    setattr(self.model, "embed_tokens", value)
|
|
def get_output_embeddings(self):
    """Return the ``lm_head`` attribute of this model."""
    head = self.lm_head
    return head
|
|
def set_output_embeddings(self, new_embeddings):
    """Store ``new_embeddings`` as the ``lm_head`` attribute of this model."""
    setattr(self, "lm_head", new_embeddings)
|
|
def set_decoder(self, decoder):
    """Store ``decoder`` as the ``model`` attribute of this wrapper."""
    setattr(self, "model", decoder)
|
|
def get_decoder(self):
    """Return the ``model`` attribute of this wrapper."""
    decoder = self.model
    return decoder
|
|
@add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        pixel_values (`torch.Tensor`, *optional*):
            Image input. When given, it is encoded with `_encode_image`, the resulting tokens are placed in
            front of the text embeddings, and the decoder is called with `inputs_embeds` instead of
            `input_ids`. The `attention_mask` argument is then replaced by the mask returned by
            `create_git_attention_mask`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            When the logits are longer than `labels` (image tokens in front of the text), the extra leading
            logits are skipped before the loss is computed.

    Returns:

    """
    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train FeynModel models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    decoder_input_ids = input_ids
    if pixel_values is not None:
        self.model.mode = 'vlm'

        if input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        image_features = self._encode_image(pixel_values)
        # Only the merged embeddings are kept; the all-ones mask returned by
        # the merge is superseded by the mask built just below.
        inputs_embeds, _ = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
        causal_attention_mask = create_git_attention_mask(tgt=input_ids, memory=image_features, max_length=8192)
        # Use the merged embeddings' device: it exists even when the caller
        # supplied ``inputs_embeds`` without ``input_ids``.
        causal_attention_mask = causal_attention_mask.to(inputs_embeds.device)
        self.__causal_attention_mask = causal_attention_mask

        # The decoder consumes the merged embeddings, so token ids are dropped
        # and the freshly built mask replaces the caller's one.
        decoder_input_ids = None
        attention_mask = causal_attention_mask

    # Without pixel_values the cached mask (set by a previous image pass or by
    # ``generate``) is forwarded; with pixel_values it is the one just built.
    outputs = self.model(
        input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        causal_attention_mask=self.__causal_attention_mask,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    if self.config.final_logit_softcapping is not None:
        logits = logits / self.config.final_logit_softcapping
        logits = torch.tanh(logits)
        logits = logits * self.config.final_logit_softcapping

    logits = logits.float()
    loss = None
    if labels is not None:
        # Logits cover every position fed to the decoder while labels cover
        # only the text tokens, so the length difference is the number of
        # leading (image) positions to skip. For a text-only pass this is 0.
        num_image_tokens = logits.shape[1] - labels.shape[1]
        if num_image_tokens < 0:
            raise ValueError(
                f"labels are longer ({labels.shape[1]}) than the model output ({logits.shape[1]})"
            )
        shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), shifted_labels.view(-1))

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Assemble the keyword arguments for one decoding step.

    Args:
        input_ids: Token ids generated so far.
        past_key_values: Cache object, or ``None`` on the first step.
        attention_mask: Binary mask, or ``None``.
        inputs_embeds: Pre-computed embeddings; only forwarded on the step
            where ``cache_position[0] == 0``.
        cache_position: Positions of the tokens processed in this step.
        position_ids: Explicit positions; derived from ``attention_mask``
            when omitted.
        use_cache: Forwarded unchanged.

    Returns:
        A dict holding either ``inputs_embeds`` or ``input_ids`` plus
        ``position_ids``, ``cache_position``, ``past_key_values``,
        ``use_cache`` and ``attention_mask``.

    Raises:
        TypeError: If ``self.lm_head.weight`` is neither a tensor nor callable.
    """
    # Keep only the tokens that are not yet covered by the cache.
    if past_key_values is not None:
        if inputs_embeds is not None:
            input_ids = input_ids[:, -cache_position.shape[0]:]
        elif input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]

    if attention_mask is not None and position_ids is None:
        # Positions count the attended tokens; masked slots get the filler 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}

    # A 2-D mask combined with a HybridCache is expanded to 4-D here.
    # ``attention_mask`` may legitimately be None, hence the explicit guard.
    if (
        isinstance(past_key_values, HybridCache)
        and attention_mask is not None
        and attention_mask.ndim == 2
    ):
        if inputs_embeds is not None and input_ids.size(1) != 0:
            batch_size, sequence_length, _ = inputs_embeds.shape
            device = inputs_embeds.device
        else:
            batch_size, sequence_length = position_ids.shape
            device = input_ids.device

        # Resolve the dtype of the output head once. Dynamically quantized
        # heads expose their weight through ``_weight_bias()``; other heads
        # may expose ``weight`` as a tensor or as a callable returning one.
        lm_head = self.lm_head
        if isinstance(lm_head, torch.ao.nn.quantized.dynamic.Linear):
            head_weight, _ = lm_head._weight_bias()
        else:
            head_weight = lm_head.weight
            if not isinstance(head_weight, torch.Tensor):
                if callable(head_weight):
                    head_weight = head_weight()
                else:
                    raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
        dtype = head_weight.dtype

        # The fill value must be representable in ``dtype`` (the float32
        # minimum is not for half-precision types), matching what
        # ``_update_causal_mask`` does.
        if dtype.is_floating_point:
            min_dtype = torch.finfo(dtype).min
        else:
            min_dtype = torch.iinfo(dtype).min

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
| |
def generate(
    self,
    input_ids,
    pixel_values=None,
    max_length=None,
    do_sample=True,
    temperature=0.7,
    **kwargs
):
    """Generate text, optionally conditioned on an image.

    Args:
        input_ids: Prompt token ids (may be ``None`` when ``pixel_values`` is
            given).
        pixel_values: Image input. When given, the image tokens are merged in
            front of the prompt embeddings and generation runs from
            ``inputs_embeds``.
        max_length: Forwarded to the parent ``generate`` and to
            ``create_git_attention_mask``.
        do_sample: Forwarded to the parent ``generate``.
        temperature: Forwarded to the parent ``generate``.
        **kwargs: Forwarded to the parent ``generate``.

    Returns:
        Whatever the parent class's ``generate`` returns.
    """
    try:
        if pixel_values is not None:
            # Stays None when no prompt is supplied, so the merge returns the
            # image tokens alone instead of hitting an unbound local.
            inputs_embeds = None
            if input_ids is not None:
                inputs_embeds = self.get_input_embeddings()(input_ids)

            image_features = self._encode_image(pixel_values)
            # Only the merged embeddings are kept; the mask is rebuilt below.
            inputs_embeds, _ = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
            causal_attention_mask = create_git_attention_mask(
                tgt=input_ids, memory=image_features, max_length=max_length
            )
            # Cached so that ``forward`` can hand it to the decoder while the
            # parent ``generate`` drives the decoding loop.
            self.__causal_attention_mask = causal_attention_mask.to(inputs_embeds.device)
            self.model.mode = 'vlm'
            return super().generate(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                max_length=max_length,
                do_sample=do_sample,
                temperature=temperature,
                **kwargs
            )

        # Assignment (the previous ``==`` was a no-op comparison).
        self.model.mode = 'llm'
        return super().generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            **kwargs
        )
    finally:
        # Always drop the cached mask, even if generation raised, so a stale
        # mask never leaks into a later call.
        self.__causal_attention_mask = None
| |
|
|