| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, |
| designed for processing image inputs in vision-language models. |
| |
| This module follows Mistral's vision encoder implementation (for their Pistral-12B VLM): |
| https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py |
| """ |
| from functools import partial |
| from typing import Any, Callable, Mapping, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .ar_module_normalization import create_norm |
| from .ar_network_transformer import TransformerBlock |
| from .log import log |
|
|
|
|
def get_vit_config(model_name: str) -> Mapping[str, Any]:
    """
    Look up the ViT hyperparameter set registered under ``model_name``.

    Args:
        model_name (str): Identifier of a supported vision encoder.

    Returns:
        Mapping[str, Any]: Constructor keyword arguments for ``VisionTransformer``.

    Raises:
        ValueError: If ``model_name`` is not a known configuration.
    """
    known_configs = {
        "pixtral-12b-vit": {
            "dim": 1024,
            "num_channels": 3,
            "image_size": 1024,
            "patch_size": 16,
            "rope_theta": 10000,
            "ffn_hidden_size": 4096,
            "n_layers": 24,
            "n_heads": 16,
            "n_kv_heads": 16,
            "norm_type": "rmsnorm",
            "norm_eps": 1e-5,
            "image_token_id": 10,
        },
    }
    if model_name not in known_configs:
        raise ValueError(f"Unknown model name: {model_name}")
    return known_configs[model_name]
|
|
|
|
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Precompute the 2D complex rotary-embedding table for a patch grid.

    Even-indexed frequency channels encode the row (height) coordinate and
    odd-indexed channels encode the column (width) coordinate, so each grid
    cell receives a unit-magnitude complex rotation per frequency channel.

    Args:
        dim (int): Per-head model dimension (hidden size / number of heads).
        height (int): Height of the image in patches.
        width (int): Width of the image in patches.
        theta (float): Base for the inverse-frequency schedule.

    Returns:
        torch.Tensor: Complex tensor of shape (height, width, dim // 2).
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    row_idx = torch.arange(height, device=inv_freq.device)
    col_idx = torch.arange(width, device=inv_freq.device)

    # Split the frequency channels between the two axes.
    row_angles = torch.outer(row_idx, inv_freq[::2]).float()
    col_angles = torch.outer(col_idx, inv_freq[1::2]).float()

    # Broadcast each axis' angles across the full (height, width) grid.
    grid_angles = torch.cat(
        (
            row_angles.unsqueeze(1).expand(height, width, -1),
            col_angles.unsqueeze(0).expand(height, width, -1),
        ),
        dim=-1,
    )
    # Unit-magnitude complex numbers: exp(i * angle).
    return torch.polar(torch.ones_like(grid_angles), grid_angles)
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View ``freqs_cis`` so it broadcasts elementwise against ``x``.

    ``freqs_cis`` must match ``x`` along the sequence axis (dim 1) and the
    last axis; every other axis is collapsed to size 1 so that multiplication
    broadcasts over batch/head dimensions.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor of shape
            (x.shape[1], x.shape[-1]).
        x (torch.Tensor): Tensor the frequencies will be multiplied into.

    Returns:
        torch.Tensor: ``freqs_cis`` reshaped with singleton dims everywhere
        except dims 1 and -1.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}"
    assert freqs_cis.shape == (
        x.shape[1],
        x.shape[-1],
    ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}"
    target_shape = [
        size if axis in (1, ndim - 1) else 1 for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*target_shape)
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    *args,
    freqs_cis: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors by the precomputed rotary embeddings.

    The last dimension of each tensor is treated as interleaved
    (real, imaginary) pairs; multiplying by the unit-magnitude complex
    frequencies rotates every pair by its positional angle.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        freqs_cis (torch.Tensor): Precomputed complex frequencies from
            precompute_freqs_cis_2d.
        *args: Unused; accepted for call-site compatibility.
        **kwargs: Unused; accepted for call-site compatibility.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors,
        cast back to the input dtypes.
    """
    # View (..., D) as (..., D/2) complex pairs.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    broadcastable = reshape_for_broadcast(freqs_cis, q_complex)
    # Complex multiplication implements the 2D rotation; flatten the
    # (real, imag) pairs back into the last dimension.
    q_rotated = torch.view_as_real(q_complex * broadcastable).flatten(3)
    k_rotated = torch.view_as_real(k_complex * broadcastable).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
class VisionTransformer(nn.Module):
    """
    Vision Transformer that encodes an image into a sequence of patch features.

    The input image is split into non-overlapping patches with a strided
    convolution, pre-normalized, and processed by a stack of transformer
    blocks that use 2D rotary position embeddings to encode patch coordinates.

    Args:
        dim (int): Dimension of the model (hidden size).
        num_channels (int): Number of input image channels (e.g., 3 for RGB).
        patch_size (int): Size of each image patch (e.g., 16x16 pixels).
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        n_kv_heads (Optional[int]): Number of key/value heads; defaults to n_heads.
        ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks.
        norm_type (str): Type of normalization to use (e.g., "rmsnorm").
        norm_eps (float): Epsilon value for normalization layers.
        image_size (int): Size of the input image (assumed square).
        rope_theta (float): Base value for rotary position embedding calculation.
        image_token_id (Optional[int]): Token ID for the image token (if present).
    """

    def __init__(
        self,
        dim: int = 1024,
        num_channels: int = 3,
        patch_size: int = 16,
        n_layers: int = 24,
        n_heads: int = 16,
        n_kv_heads: Optional[int] = None,
        ffn_hidden_size: int = 4096,
        norm_type: str = "rmsnorm",
        norm_eps: float = 1e-5,
        image_size: int = 1024,
        rope_theta: float = 1000000.0,
        image_token_id: Optional[int] = None,
    ):
        super().__init__()
        # Patchify: each patch_size x patch_size tile becomes one `dim`-dim token.
        self.patch_conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps)
        if n_kv_heads is None:
            n_kv_heads = n_heads

        # Fail fast before building the (expensive) transformer stack: rotary
        # embeddings pair channels, so the per-head dimension must be even.
        head_dim = dim // n_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"

        layer_args = dict(
            n_layers=n_layers,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            dim=dim,
            use_qk_normalization=False,
            max_seq_len=None,
            max_batch_size=None,
            ffn_hidden_size=ffn_hidden_size,
            norm_type=norm_type,
            norm_eps=norm_eps,
            causal_mask=False,  # vision tokens attend bidirectionally
            head_dim=None,
            insert_cross_attn=False,
            attn_type="full",
        )

        self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args)

        self.dim = dim
        self.n_heads = n_heads
        self.max_patches_per_side = image_size // patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.rope_theta = rope_theta
        # Lazily computed 2D RoPE table; cached and kept on the model's device.
        self._freqs_cis: Optional[torch.Tensor] = None
        self.image_token_id = image_token_id

        num_params = self.get_num_params()
        log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M")

    @classmethod
    def build(
        cls,
        config: Mapping[str, Any],
    ) -> "VisionTransformer":
        """
        Create a Vision Transformer from a configuration dictionary.

        Args:
            config (Mapping[str, Any]): Constructor keyword arguments, e.g. as
                returned by ``get_vit_config``.

        Returns:
            VisionTransformer: Vision Transformer model instance.

        Raises:
            AssertionError: If any required key is missing from ``config``.
        """
        necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"]
        missing_keys = [k for k in necessary_keys if k not in config]
        assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}"
        return cls(
            **config,
        )

    def expand_in_channels(self, new_in_channels: int):
        """
        Expand the input channels of the patch convolution layer.

        Useful when the input is non-standard, e.g. a 4-channel image whose last
        channel is alpha. The pretrained weights are kept for the original
        channels and the extra channels are zero-initialized, so outputs for the
        original channels are unchanged. Only call this after weights are loaded.
        """
        assert (
            new_in_channels > self.patch_conv.in_channels
        ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels."
        log.debug(
            f"Vision encoder in_channels is {self.patch_conv.in_channels}. But you have specified to be {new_in_channels}. We will change it to {new_in_channels} channels with {new_in_channels - self.patch_conv.in_channels} channels of 0s."
        )
        new_conv = nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=self.patch_conv.out_channels,
            kernel_size=self.patch_conv.kernel_size,
            stride=self.patch_conv.stride,
            bias=False,
        )
        # Match the loaded weights' device/dtype so the swap also works for
        # models already moved to GPU or cast to lower precision. (The original
        # code left the new conv on CPU/float32.)
        new_conv = new_conv.to(
            device=self.patch_conv.weight.device,
            dtype=self.patch_conv.weight.dtype,
        )
        old_in_channels = self.patch_conv.in_channels
        new_conv.weight.data[:, :old_in_channels].copy_(self.patch_conv.weight.data)
        new_conv.weight.data[:, old_in_channels:].zero_()
        self.patch_conv = new_conv

    @property
    def device(self) -> torch.device:
        """Get the device of the model (taken from its first parameter)."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """
        Get or compute the frequency tensor for rotary position embedding.

        Lazily builds the 2D RoPE table for the maximum patch grid and caches
        it, moving it to the model's device if the model has been moved.

        Returns:
            torch.Tensor: Complex tensor of shape
            (max_patches_per_side, max_patches_per_side, head_dim // 2).
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.dim // self.n_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the Vision Transformer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is
                batch size, C is number of channels, and H, W are height and width.

        Returns:
            torch.Tensor: Output features of shape (B, N, D), where N = (H // patch_size)
            * (W // patch_size) is the number of patches and D is the embedding dimension.
        """
        # (B, C, H, W) -> (B, D, Hp, Wp) -> (B, Hp*Wp, D)
        patch_embeds = self.patch_conv(x)
        _, _, Hp, Wp = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2)
        patch_embeds = patch_embeds.transpose(1, 2)
        patch_embeds = self.ln_pre(patch_embeds)
        # (row, col) coordinate of every patch, in the same flattened order.
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(Hp),
                torch.arange(Wp),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2)

        # Select the per-position rotary frequencies and bind them into the
        # rope callable applied inside each transformer block.
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
        rope = partial(apply_rotary_emb, freqs_cis=freqs_cis)
        out = self.transformer(patch_embeds, rope=rope)

        return out

    def get_num_params(
        self,
    ) -> int:
        """
        Return the total number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
|
|
|
|
class VisionTransformerBlocks(nn.Module):
    """
    Stack of Transformer blocks used by the Vision Transformer.

    Args:
        n_layers (int): Number of transformer layers.
        args (Mapping[str, Any]): Arguments forwarded to every TransformerBlock
            (dimensions, head counts, normalization settings, etc.).
    """

    def __init__(
        self,
        n_layers: int,
        args: Mapping[str, Any],
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerBlock(layer_id=idx, args=args) for idx in range(n_layers)]
        )

    def forward(
        self,
        x: torch.Tensor,
        rope: Callable,
    ) -> torch.Tensor:
        """
        Apply every transformer block in sequence.

        Args:
            x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch
                size, N is the number of patches, and D is the embedding dimension.
            rope (Callable): Rotary position embedding function passed to each layer.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, input_pos=None, mask=None, rope=rope)
        return hidden
|
|