| |
| |
| |
| |
| |
|
|
| import warnings |
| from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, |
| Set, Tuple, Type, TypedDict, Union) |
|
|
| import torch.utils.checkpoint |
| import transformers |
| import torch |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM, |
| LlamaTokenizer) |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ModelOutput, logging |
|
|
| from transformer_engine.pytorch import RMSNorm |
| from transformers.activations import ACT2FN |
|
|
| from .configuration_yuanvl import YuanVLChatConfig |
| from .conversation import get_conv_template |
| from .modeling_intern_vit import InternVisionModel, has_flash_attn |
| from .modeling_yuanlm2 import YuanForCausalLM |
| from .utils import flatten_bn, merge_multimodal_embeddings |
|
|
| logger = logging.get_logger(__name__) |
|
|
class InternVLImagePixelInputs(TypedDict):
    # Raw pixel-value input for the vision tower; the "type" key acts as the
    # runtime discriminator against InternVLImageEmbeddingInputs.
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """
    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    """
|
|
|
|
class InternVLImageEmbeddingInputs(TypedDict):
    # Precomputed image embeddings that bypass the vision tower; the "type"
    # key acts as the runtime discriminator against InternVLImagePixelInputs.
    type: Literal["image_embeds"]
    # Annotation tightened from `Any` to match the shape contract below.
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
|
|
|
|
# Discriminated union of the two accepted image-input formats; consumers
# branch on the TypedDict's "type" key at runtime.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]
|
|
|
|
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings with the named `operator` function.

    ``op`` is any comparison name exposed by the :mod:`operator` module
    ('eq', 'ge', 'gt', ...). Returns the boolean result of comparing the
    parsed versions.
    """
    import operator

    from packaging import version

    compare = getattr(operator, op)
    return compare(version.parse(v1), version.parse(v2))
|
|
class YuanImageMLP(nn.Module):
    """SwiGLU projector MLP mapping vision features into the LLM hidden space.

    Computes ``down_proj(silu(up_proj(x)) * gate_proj(x))``. Only the
    ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()

        # Bias-free gated-MLP projections (order preserved: up, gate, down).
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported for now.")

        self.act_fn = ACT2FN[hidden_act]

    @torch.compile
    def swiglu(self, y_1, y_2):
        # Gated activation: silu(y_1) gates y_2 elementwise.
        return self.act_fn(y_1) * y_2

    def forward(self, x):
        # up/gate projections feed the gated activation, then project down.
        gated = self.swiglu(self.up_proj(x), self.gate_proj(x))
        return self.down_proj(gated)
|
|
class YuanVLChatModel(PreTrainedModel):
    """Vision-language chat model pairing an InternViT tower with a Yuan LM.

    Pixel values are encoded by ``vision_model``, spatially compressed with
    ``PixelUnshuffle``, projected into the LLM hidden size by a SwiGLU MLP
    (``imagemlp`` + ``imagemlp_layernorm``), and spliced into the text
    embedding sequence at ``img_context_token_id`` positions before being fed
    to ``language_model``.
    """

    config_class = YuanVLChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _supports_flash_attn_2 = True
    # NOTE(review): 'YuanDeocderLayer' looks like a misspelling of
    # 'YuanDecoderLayer'; kept verbatim because it must match the decoder
    # layer class name in modeling_yuanlm2 — confirm before changing.
    _no_split_modules = ['InternVisionModel', 'YuanDeocderLayer']

    def __init__(self, config: YuanVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Build the vision tower, the projector MLP and the language model.

        Args:
            config: Combined config with ``vision_config`` and ``llm_config``.
            vision_model: Optional pre-built vision tower; built from
                ``config.vision_config`` when ``None``.
            language_model: Optional pre-built LLM; built from
                ``config.llm_config`` when ``None``.
            use_flash_attn: Request flash attention; silently downgraded to
                eager when flash-attn is unavailable.

        Raises:
            NotImplementedError: If ``config.llm_config.architectures[0]`` is
                not ``'YuanForCausalLM'``.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Tokens per image after the downsample step (patch grid scaled by
        # downsample_ratio in each spatial dimension).
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        use_flash_attn = use_flash_attn if has_flash_attn else False
        config.vision_config.use_flash_attn = True if use_flash_attn else False
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'YuanForCausalLM':
                self.language_model = YuanForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        # Space-to-depth: halves H and W, multiplies channel count by 4.
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(downscale_factor=2)

        layernorm_epsilon = config.llm_config.rms_norm_eps

        # Projector input width: ViT hidden size grown by the channel
        # expansion of the downsample (hidden / ratio**2 == hidden * 4 for
        # ratio == 0.5, matching PixelUnshuffle(2) above).
        self.imagemlp_input_hiddensize = int(config.vision_config.hidden_size / self.downsample_ratio ** 2)
        self.imagemlp_ffn_hidden_size = config.llm_config.ffn_hidden_size

        self.imagemlp = YuanImageMLP(self.imagemlp_input_hiddensize, self.imagemlp_ffn_hidden_size,
                                     output_size=config.llm_config.hidden_size, hidden_act="silu")
        self.imagemlp_layernorm = RMSNorm(config.llm_config.hidden_size, eps=layernorm_epsilon)

        self.img_context_token_id = config.img_context_token_id
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def _validate_pixel_values(self,
                               data: Union[torch.Tensor, List[torch.Tensor]]
                               ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every image tensor is ``(3, image_size, image_size)``.

        Args:
            data: Iterable of per-image pixel tensors.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If any element has an unexpected shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                # Fix: removed the double space and redundant parentheses from
                # the original error message.
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_dims}. You supplied {actual_dims}.")

        for d in data:
            _validate_shape(d)
        return data

    def _parse_and_validate_image_input(self,
                                        pixel_values: Optional[List[torch.Tensor]] = None,
                                        image_token_id: Optional[torch.Tensor] = None,
                                        image_embeds: Optional[torch.Tensor] = None,
                                        ) -> Optional[InternVLImageInputs]:
        """Normalize raw multimodal inputs into a typed image-input dict.

        Args:
            pixel_values: Per-request stacks of image patches; each element
                has its patch count on dim 0.
            image_token_id: Unused; kept for interface compatibility.
            image_embeds: Precomputed image embeddings, bypassing the vision
                tower. Takes precedence over ``pixel_values``.

        Returns:
            ``None`` when no image input is present, otherwise a
            ``pixel_values`` or ``image_embeds`` TypedDict.

        Raises:
            ValueError: On unexpected input types or pixel shapes.
        """
        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # pixel_values is guaranteed non-None here, so the original trailing
        # "unreachable" assertion was removed.
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        # Record how many patches each request contributed so the flattened
        # feature tensor can be split back per request later.
        patches_per_image = [request_pixel_values.shape[0]
                             for request_pixel_values in pixel_values]

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(flatten_bn(pixel_values)),
            patches_per_image=patches_per_image)

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Turn a parsed image input into LLM-space embeddings.

        Precomputed embeddings are returned as-is; raw pixels are run through
        the vision tower and projector, then split back per request.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]
        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["data"])
        patches_per_image = image_input["patches_per_image"]

        # Single-image fast path: flatten to (total_tokens, 1, hidden).
        if len(patches_per_image) == 1:
            image_embeds = image_embeds.view(-1, self.config.llm_config.hidden_size).unsqueeze(1)
            return image_embeds

        # Multiple images: flatten everything, then split into one chunk of
        # (num_patches * feature_size) rows per image.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, self.config.llm_config.hidden_size)
        image_feature_sizes = [num_patches * feature_size for num_patches in patches_per_image]
        image_embeds = image_embeds.split(image_feature_sizes)

        return image_embeds

    def get_multimodal_embeddings(self,
                                  pixel_values: Optional[List[torch.Tensor]] = None,
                                  image_token_id: Optional[List[torch.Tensor]] = None,
                                  image_embeds: Optional[List[torch.Tensor]] = None,
                                  image_input: InternVLImageInputs = None,
                                  ):
        """Parse and encode image inputs; returns ``None`` without images.

        Note: the ``image_input`` argument is ignored — it is recomputed from
        the other arguments (parameter kept for interface compatibility).
        """
        image_input = self._parse_and_validate_image_input(pixel_values, image_token_id, image_embeds)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice image embeddings into the sequence.

        Image embeddings replace the placeholder positions whose token id
        equals ``img_context_token_id``.
        """
        inputs_embeds = self.language_model.model.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.img_context_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: List[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[List[torch.Tensor]] = None,
        image_token_id: Optional[List[torch.Tensor]] = None,
        image_embeds: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal forward pass through the language model.

        When ``inputs_embeds`` is not supplied, it is built from ``input_ids``
        plus any image inputs and permuted to (seq, batch, hidden) —
        presumably the Yuan LM consumes sequence-first embeddings; confirm.
        """
        # Fix: removed leftover `import pdb; pdb.set_trace()` debugging hook
        # that halted every forward call.
        if inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(pixel_values, image_token_id, image_embeds)
            # (batch, seq, hidden) -> (seq, batch, hidden)
            inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings).permute(1, 0, 2)
            input_ids = None

        hidden_states = self.language_model.model(input_ids, attention_mask, position_ids, past_key_values,
                                                  inputs_embeds, labels, use_cache, output_attentions,
                                                  output_hidden_states, return_dict)

        return hidden_states

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Spatially downsample (N, W, H, C) features by ``scale_factor``,
        folding the removed pixels into the channel dimension.

        With ``ps_version == 'v1'`` the final H/W transpose is skipped,
        yielding a transposed image (kept for backward compatibility).
        """
        n, w, h, c = x.size()
        # N, W, H, C -> N, W, H*scale, C/scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H*scale, C/scale -> N, H*scale, W, C/scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # -> N, H*scale, W*scale, C/(scale^2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode pixel values into LLM-space patch embeddings.

        Pipeline: vision tower -> drop first (presumably CLS) token ->
        PixelUnshuffle (4x fewer tokens, 4x more channels) -> SwiGLU
        projector -> RMSNorm.
        """
        # NOTE(review): dtype is hard-coded to bfloat16 regardless of the
        # model's configured dtype — confirm this is intentional.
        pixel_values = pixel_values.to(torch.bfloat16)
        output = self.vision_model(pixel_values=pixel_values)
        vit_embeds = output[0]
        # Drop the leading token, keeping only the patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]

        pn, phw, pc = vit_embeds.shape
        ph = pw = int(phw ** 0.5)  # assumes a square patch grid
        vit_embeds = vit_embeds.view(pn, ph, pw, pc).permute(0, 3, 1, 2)
        vit_embeds = self.pixel_unshuffle(vit_embeds)
        pn, pc, ph, pw = vit_embeds.shape
        vit_embeds = vit_embeds.view(pn, pc, ph * pw).permute(0, 2, 1)
        num_images, cvs, chs = vit_embeds.shape

        # Project with all images merged into one sequence-first batch, then
        # restore the per-image view.
        vit_embeds = vit_embeds.reshape(1, -1, vit_embeds.shape[-1]).permute(1, 0, 2)
        vit_embeds = self.imagemlp(vit_embeds)
        vit_embeds = self.imagemlp_layernorm(vit_embeds)
        vit_embeds = vit_embeds.view(num_images, cvs, -1)
        return vit_embeds

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.LongTensor:
        """Generate token ids conditioned on text and (optionally) images.

        ``visual_features`` may carry precomputed image embeddings; otherwise
        they are extracted from ``pixel_values`` when given.

        NOTE(review): unlike ``forward``, the embeddings are NOT permuted to
        sequence-first here — confirm which layout
        ``language_model.generate`` expects.
        """
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.get_multimodal_embeddings(pixel_values)
            inputs_embeds = self.get_input_embeddings(input_ids, vit_embeds)
        else:
            # Fix: the original left `inputs_embeds` unbound on the text-only
            # path, raising UnboundLocalError; embed the prompt directly.
            inputs_embeds = self.get_input_embeddings(input_ids, None)
        input_ids = None

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            position_ids=position_ids,
            max_length=8192,
            use_cache=True,
        )

        return outputs
|
|