| import json |
| import logging |
| import math |
| import os |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from open_clip.factory import get_model_config, load_state_dict |
| from open_clip.model import (CLIPTextCfg, CLIPVisionCfg, _build_text_tower, |
| _build_vision_tower, |
| convert_to_custom_text_state_dict) |
| from open_clip.transformer import text_global_pool |
| from torch import nn |
| from torchvision.ops import roi_align |
| from transformers import (CONFIG_MAPPING, AutoConfig, AutoModel, |
| AutoModelForCausalLM, GenerationConfig, |
| PretrainedConfig, PreTrainedModel, StoppingCriteria, |
| StoppingCriteriaList) |
| from transformers.activations import ACT2FN |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.generation import GenerationConfig |
| from transformers.modeling_utils import load_state_dict |
| from transformers.utils import logging, strtobool |
|
|
| from .convnext import ConvNextVisionEncoder |
|
|
# NOTE: ``logging`` here is ``transformers.utils.logging`` (the later import at
# the top of the file shadows the stdlib module), so this is a transformers logger.
logger = logging.get_logger(__name__)


# XLA/TPU bf16 flags, snapshotted from the environment at import time.
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()


# Label value ignored by the cross-entropy loss.
IGNORE_INDEX = -100
# Default padding token id.
DEFAULT_PAD_TOKEN_INDEX = 0
# Sentinel token id marking where image features are spliced into the input.
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"


# Object (visual-prompt) placeholder tokens; "<obj<i>>" is a template where
# "<i>" stands for the object index.
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
# Sentinel token id marking where object (box) features are spliced in.
DEFAULT_OBJECT_INDEX = -300


# Markers wrapping grounded phrases and their associated object lists.
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
|
def is_fsdp_enabled():
    """Return True when running under Accelerate FSDP with CPU RAM-efficient loading.

    All four conditions must hold: torch.distributed is available, the process
    group is initialized, and both ACCELERATE_USE_FSDP and
    FSDP_CPU_RAM_EFFICIENT_LOADING are truthy in the environment.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return False
    fsdp_requested = strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
    ram_efficient = (
        strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )
    return fsdp_requested and ram_efficient
|
|
|
|
|
|
|
|
def get_token_slices(input_ids: torch.Tensor):
    """Split a 1-D token sequence into contiguous typed slices.

    Args:
        input_ids (torch.Tensor): 1-D tensor of token IDs where
            IMAGE_TOKEN_INDEX marks an image token, DEFAULT_OBJECT_INDEX marks
            an object token, and every other value is an ordinary text token.

    Returns:
        Tuple[List[Dict[str, Any]], bool]: a list of dicts, each holding the
        slice ``type`` ('text', 'image' or 'object') and its ``span`` as
        ``[start, end)`` indices, plus a flag saying whether any object token
        was present.
    """
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    # Visit special tokens in the order they appear in the sequence.
    special_indices, _ = torch.sort(torch.cat((image_indices, object_indices)))
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0
    for i, idx_tensor in enumerate(special_indices):
        idx = idx_tensor.item()
        # Text run between the previous special token and this one.
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx]})
        slices.append(
            {"type": type_map[special_tokens[i].item()], "span": [idx, idx + 1]}
        )
        start_idx = idx + 1

    # Trailing text after the last special token, if any.
    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})

    return slices, has_object
|
|
|
|
def prepare_inputs_labels_for_multimodal(
    llm,
    input_ids: torch.LongTensor = None,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    bbox_feats=None,
    extra_llm_input_embed: nn.Embedding = None,
    **kwargs,
):
    """Splice image and object-box features into the LLM input embeddings.

    Each IMAGE_TOKEN_INDEX placeholder is expanded into the corresponding
    image feature rows and each DEFAULT_OBJECT_INDEX placeholder into one box
    feature row; the batch is then re-padded to a common length and labels,
    attention mask and position ids are rebuilt accordingly.

    Args:
        llm: Language model providing ``get_input_embeddings()``.
        input_ids: (batch, seq) token ids, possibly containing the negative
            placeholder ids for images/objects.
        position_ids / attention_mask / past_key_values / labels: standard
            transformer inputs; any may be None.
        pixel_values: per-image feature tensors; when None the inputs pass
            through unchanged (pure-text fast path).
        bbox_feats: per-sample tensors of box features for object tokens.
        extra_llm_input_embed: unused here; kept for interface compatibility.

    Returns:
        dict with ``input_ids`` (None on the multimodal path), ``position_ids``,
        ``attention_mask``, ``past_key_values``, ``inputs_embeds``, ``labels``.
    """
    # Pure-text fast path: nothing to splice, pass inputs through unchanged.
    if pixel_values is None:
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "inputs_embeds": None,
            "labels": labels,
        }

    # Remember which optional inputs the caller actually supplied so we can
    # return None for the ones synthesized below.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Drop padding positions; each sample becomes a variable-length 1-D tensor.
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_embeds = []
    new_labels = []
    cur_image_idx = 0
    cur_object_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # No image token in this sample: embed the text and append a
            # zero-length slice of the image features (keeps the vision branch
            # connected to the autograd graph).
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat(
                [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0
            )
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            cur_object_idx += 1
            continue

        cur_labels = labels[batch_idx]
        token_slices, has_object = get_token_slices(cur_input_ids)
        result_input_embeddings = []
        result_output_labels = []
        cur_gt_bbox_idx = 0
        for token_slice in token_slices:
            slice_type = token_slice["type"]
            slice_span = token_slice["span"]
            if slice_type == "text":
                cur_input_ids_noim = cur_input_ids[slice_span[0] : slice_span[1]]
                cur_labels_noim = cur_labels[slice_span[0] : slice_span[1]]
                cur_input_embeds = llm.get_input_embeddings()(cur_input_ids_noim)
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(cur_labels_noim)
            elif slice_type == "image":
                # An image placeholder expands to all feature rows of that
                # image; those rows never contribute to the loss.
                cur_input_embeds = pixel_values[cur_image_idx]
                result_input_embeddings.append(cur_input_embeds)
                result_output_labels.append(
                    torch.full(
                        (cur_input_embeds.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
                cur_image_idx += 1
            elif slice_type == "object":
                # Each object placeholder consumes exactly one box feature row.
                try:
                    result_input_embeddings.append(
                        bbox_feats[cur_object_idx][cur_gt_bbox_idx].unsqueeze(0)
                    )
                except (IndexError, TypeError) as err:
                    raise ValueError(
                        f"current bbox_feats.shape: {bbox_feats[cur_object_idx].shape}, "
                        f"failed to read object feature {cur_gt_bbox_idx}"
                    ) from err
                cur_gt_bbox_idx += 1
                result_output_labels.append(
                    torch.full(
                        (1,),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
        # bbox_feats is indexed per sample, so advance once per batch element.
        cur_object_idx += 1
        result_input_embeddings = torch.cat(result_input_embeddings)
        result_output_labels = torch.cat(result_output_labels)
        assert len(result_output_labels) == len(result_input_embeddings)
        new_inputs_embeds.append(result_input_embeddings)
        new_labels.append(result_output_labels)

    # Re-pad every sample to the longest spliced sequence in the batch.
    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )

    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_inputs_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        # Right-pad embeddings with zeros up to max_len.
        new_inputs_embeds_padded.append(
            torch.cat(
                (
                    cur_new_embed,
                    torch.zeros(
                        (max_len - cur_len, cur_new_embed.shape[1]),
                        dtype=cur_new_embed.dtype,
                        device=cur_new_embed.device,
                    ),
                ),
                dim=0,
            )
        )
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0, cur_len, dtype=position_ids.dtype, device=position_ids.device
            )

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    # Only return tensors for the inputs the caller originally supplied.
    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    return {
        "input_ids": None,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "inputs_embeds": new_inputs_embeds,
        "labels": new_labels,
    }
|
|
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded text ends with a given stop word."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Cached length of the stop word for the tail comparison below.
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Decode the first sequence and test whether its tail is the stop word."""
        decoded = self.tokenizer.decode(input_ids[0])
        # Newlines/carriage returns must not break the tail match.
        decoded = decoded.replace('\r', '').replace('\n', '')
        return decoded[-self.length:] == self.stop_word
|
|
def get_stop_criteria(
    tokenizer,
    stop_words=None,
):
    """Build a ``StoppingCriteriaList`` with one criterion per stop word.

    Args:
        tokenizer: tokenizer used to decode generated ids inside each criterion.
        stop_words: iterable of stop strings; None/empty yields an empty list.
            (Default changed from a mutable ``[]`` to ``None`` — same behavior,
            avoids the shared-mutable-default pitfall.)

    Returns:
        StoppingCriteriaList: criteria usable with ``model.generate``.
    """
    stop_criteria = StoppingCriteriaList()
    for word in stop_words or []:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
|
|
class DualPathFuseModule(nn.Module):
    """Fuse low-resolution token features with a high-resolution feature map.

    A "slow" path projects the high-res map into the low-res embedding width,
    a "fast" path refines the low-res tokens with a depthwise residual conv,
    and a learned scalar gate blends the slow path into the fast one.
    """

    def __init__(self, low_res_dim, high_res_dim, zero_init=True):
        super().__init__()

        # Slow (high-res) path: 1x1 mixing conv followed by projection to d.
        self.slow_conv = nn.Conv2d(high_res_dim, high_res_dim, 1)
        self.slow_proj = nn.Conv2d(high_res_dim, low_res_dim, 1)

        # Fast (low-res) path: depthwise 7x7 conv plus 1x1 projection.
        self.fast_conv = nn.Conv2d(
            low_res_dim, low_res_dim, 7, padding=3, groups=low_res_dim
        )
        self.fast_proj = nn.Conv2d(low_res_dim, low_res_dim, 1)

        # One scalar gate per sample, computed from the concatenated,
        # spatially averaged features of both paths.
        self.gate = nn.Sequential(
            nn.Linear(low_res_dim * 2, low_res_dim // 2),
            nn.GELU(),
            nn.Linear(low_res_dim // 2, 1),
        )

        nn.init.xavier_uniform_(self.slow_conv.weight)
        nn.init.xavier_uniform_(self.fast_conv.weight)
        nn.init.zeros_(self.slow_conv.bias)
        nn.init.zeros_(self.fast_conv.bias)
        if zero_init:
            # Zero projections make the module start as an identity residual.
            nn.init.zeros_(self.slow_proj.weight)
            nn.init.zeros_(self.fast_proj.weight)
        else:
            nn.init.xavier_uniform_(self.slow_proj.weight)
            nn.init.xavier_uniform_(self.fast_proj.weight)
        nn.init.zeros_(self.slow_proj.bias)
        nn.init.zeros_(self.fast_proj.bias)

    def forward(self, low_res_feat, high_res_feat, sampler=None):
        """Blend ``high_res_feat`` (b, c, h, w) into ``low_res_feat`` (b, n, d).

        Assumes n is a perfect square (tokens form a square grid) and that the
        high-res map has exactly n = h*w positions after projection.
        """
        batch, _, _, _ = high_res_feat.shape
        embed_dim = low_res_feat.shape[-1]

        # Slow path: conv -> GELU -> project, then flatten to (b, hw, d).
        hi = self.slow_proj(F.gelu(self.slow_conv(high_res_feat)))
        hi = hi.view(batch, embed_dim, -1).transpose(1, 2)

        # Fast path: residual depthwise refinement in the 2-D token layout.
        side = int(math.sqrt(low_res_feat.shape[1]))
        lo = low_res_feat.transpose(1, 2).view(batch, embed_dim, side, side)
        lo = lo + self.fast_proj(F.gelu(self.fast_conv(lo)))
        lo = lo.view(batch, embed_dim, side * side).transpose(1, 2)

        # Gated fusion: a single tanh-squashed scalar per sample scales the
        # slow-path contribution.
        gate = self.gate(torch.cat([lo, hi], -1).mean(1)).unsqueeze(1)
        return lo + hi * gate.tanh()
|
|
class ProjectorConfig(PretrainedConfig):
    """Configuration for the MLP projector bridging vision and LLM features.

    Args:
        visual_hidden_size: width of the incoming visual features.
        llm_hidden_size: width of the LLM embedding space (output width).
        depth: number of linear layers in the projector MLP.
        hidden_act: activation name looked up in ``ACT2FN`` between layers.
        bias: whether the linear layers carry a bias term.
    """

    model_type = "projector"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        **kwargs,
    ):
        # Record projector hyper-parameters, then delegate the remaining
        # generic config handling to PretrainedConfig.
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)
|
|
class ProjectorModel(PreTrainedModel):
    """MLP projector mapping visual features into the LLM embedding space."""

    _auto_class = "AutoModel"
    config_class = ProjectorConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # First layer maps visual width -> LLM width; each additional depth
        # level appends an activation followed by an LLM-width linear layer.
        layers = [
            nn.Linear(
                config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
            )
        ]
        for _ in range(config.depth - 1):
            layers.append(ACT2FN[config.hidden_act])
            layers.append(
                nn.Linear(
                    config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
                )
            )
        self.model = nn.Sequential(*layers)

    def enable_input_require_grads(self):
        """Make projector outputs require grad (needed for adapter training)."""

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # HF hook: toggle gradient checkpointing on this module.
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        """Project visual features ``x`` through the MLP."""
        return self.model(x)
|
|
|
|
def _sincos_embed(values: torch.Tensor, dim_t: torch.Tensor) -> torch.Tensor:
    """Interleave sin/cos of ``values / dim_t`` -> (batch, N, len(dim_t))."""
    pos = values[:, :, None] / dim_t
    return torch.stack((pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3).flatten(2)


def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine position embedding from a box/point tensor.

    Args:
        pos_tensor (torch.Tensor): shape [batch_size, N, 2] or [batch_size, N, 4];
            the last dimension is [cx, cy] or [cx, cy, w, h] in normalized
            coordinates in range [0, 1].
        dim_of_pos_feats (int): embedding dimension used per coordinate.

    Returns:
        torch.Tensor: shape [batch_size, N, dim_of_pos_feats * pos_tensor.size(-1)],
        ordered (y, x) or (y, x, w, h) along the last dimension.

    Raises:
        ValueError: if the last dimension of ``pos_tensor`` is not 2 or 4.
    """
    scale = 2 * math.pi
    # Classic transformer frequency schedule: 10000^(2i/d) per coordinate dim.
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    pos_x = _sincos_embed(pos_tensor[:, :, 0] * scale, dim_t)
    pos_y = _sincos_embed(pos_tensor[:, :, 1] * scale, dim_t)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        pos_w = _sincos_embed(pos_tensor[:, :, 2] * scale, dim_t)
        pos_h = _sincos_embed(pos_tensor[:, :, 3] * scale, dim_t)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
|
|
|
|
class MultiLevelROIVisualPrompt(nn.Module):
    """Multi-level RoI-Align visual prompt encoder.

    Upsamples every feature level to the first level's spatial size,
    concatenates them along the channel axis, RoI-Aligns the box regions and
    average-pools each RoI into a single feature vector, optionally adding a
    sine position embedding derived from the (normalized) box coordinates.

    Args:
        output_size (Optional[int]): The RoI-Align output size. Default is None.
        channel_per_level (List[int]): List of channels per level. Default is [192, 384, 768, 1536].
        spatail_scale (Optional[float]): The spatial scale factor passed to RoI-Align.
            NOTE(review): the name is a typo of "spatial_scale", kept because callers
            pass it by keyword.
        visual_prompt_hidden_size (int): The hidden size of the visual prompt.
            NOTE(review): accepted but never stored or used in this class.
        add_pos_embedding (bool): Whether to add a position embedding per box. Default is False.
        pos_embedding_dim (int): The dimension of the position embedding. Default is 1024.
    """

    def __init__(
        self,
        output_size: Optional[int] = None,
        # NOTE(review): mutable default list — safe only while it is never mutated.
        channel_per_level: List[int] = [192, 384, 768, 1536],
        spatail_scale: Optional[float] = None,
        visual_prompt_hidden_size: int = 1024,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super(MultiLevelROIVisualPrompt, self).__init__()
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def __call__(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Performs Region of Interest (RoI) Align on multi-level features.

        All levels are first resized to the level-0 spatial size and channel-
        concatenated; RoI features are then average pooled per box.

        Args:
            multi_level_features (List[Tensor[N, C, H, W]]): Feature maps from different levels.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): box coordinates in (x1, y1, x2, y2)
                format where the regions will be taken from.
        Returns:
            Tensor[1, K, C]: one pooled feature vector per RoI.
        """
        # NOTE(review): mutates the caller's box list/tensor in place (float cast
        # here, and the cxcywh conversion below when add_pos_embedding is set) —
        # confirm callers do not reuse the boxes afterwards.
        boxes[0] = boxes[0].float()
        concat_multi_level_feature = []
        max_height = max([feature.shape[2] for feature in multi_level_features])
        max_width = max([feature.shape[3] for feature in multi_level_features])
        # Bring every non-first level up to the level-0 spatial size so the
        # channel concatenation below is well-defined.
        for level, feature in enumerate(multi_level_features):
            if level != 0:
                concat_multi_level_feature.append(
                    F.interpolate(
                        feature.float(),
                        size=(max_height, max_width),
                        mode="bilinear",
                        align_corners=False,
                    )
                )
            else:
                concat_multi_level_feature.append(feature.float())
        concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1)

        # Crop + resample each box region from the concatenated feature map.
        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )

        # Average pool each RoI window to a single vector: (K, C, h, w) -> (1, K, C).
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )
        if self.add_pos_embedding:
            # Normalize boxes to [0, 1] image coordinates, convert from
            # (x1, y1, x2, y2) to (cx, cy, w, h), then embed with sin/cos.
            boxes = boxes[0]
            boxes = boxes.to(out_box_feat.dtype)
            original_img_width = max_width / self.spatail_scale
            original_img_height = max_height / self.spatail_scale
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / original_img_width
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / original_img_height
            # w = x2 - x1, h = y2 - y1 (order matters: widths before centers).
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            pos_embed = gen_sineembed_for_position(
                boxes.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat
|
|
|
|
|
|
class ChatRexAuxConfig(PretrainedConfig):
    r"""
    Configuration class for the ChatRexAux model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
    outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
            The config object or dictionary of the vision backbone.
        vision_aux_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the auxiliary (high-resolution) vision backbone.
        visual_prompt_encoder_config (`Union[AutoConfig, dict]`, *optional*):
            The config object or dictionary of the visual prompt encoder.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision
            backbone. Can be one of `"default"` or `"full"`.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        projector_depth (`int`, *optional*, defaults to 2):
            Number of linear layers in the multimodal projector.
        visual_prompt_hidden_size (`int`, *optional*, defaults to 2880):
            Hidden size of the visual prompt (box) features.

    Example:

    ```python
    >>> # Initializing a ChatRexAux configuration with default sub-configs
    >>> configuration = ChatRexAuxConfig()

    >>> # Initializing a model from that configuration
    >>> model = ChatRexAuxForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "chatrex"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        vision_aux_config=None,
        visual_prompt_encoder_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        projector_depth=2,
        visual_prompt_hidden_size=2880,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.projector_depth = projector_depth
        self.visual_prompt_hidden_size = visual_prompt_hidden_size
        self.visual_prompt_encoder_config = visual_prompt_encoder_config

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'."
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        # Resolve the vision sub-config: dicts are materialized through the HF
        # config registry, None falls back to a CLIP ViT-L/14-336 layout.
        if isinstance(vision_config, dict):
            vision_config.setdefault("model_type", "clip_vision_model")
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )

        self.vision_config = vision_config
        self.vision_aux_config = vision_aux_config

        # Resolve the text sub-config analogously, defaulting to Llama.
        if isinstance(text_config, dict):
            text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        super().__init__(**kwargs)
|
|
|
|
class ChatRexAuxPreTrainedModel(PreTrainedModel):
    """Base class wiring ChatRexAux into the HF ``PreTrainedModel`` machinery."""

    config_class = ChatRexAuxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    @property
    def _supports_sdpa(self):
        """Defer the SDPA capability flag to the wrapped language model.

        NOTE(review): this reads ``self.language_model`` while the concrete
        subclass in this file stores the LLM as ``self.llm`` — confirm where
        ``language_model`` is assigned before relying on this property.
        """
        return self.language_model._supports_sdpa
|
|
|
|
class ChatRexAuxForConditionalGeneration(ChatRexAuxPreTrainedModel):
    """ChatRexAux: a LLaVA-style multimodal LLM with an auxiliary high-resolution
    ConvNeXt tower and RoI-based box-level visual prompts."""

    def __init__(self, config: ChatRexAuxConfig):
        super().__init__(config)
        # Low-resolution vision tower (CLIP/SigLIP family, per vision_config).
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        # Auxiliary high-resolution ConvNeXt tower.
        self.vision_encoder_aux = ConvNextVisionEncoder()

        # Projects fused image features into the LLM embedding space.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.projector = ProjectorModel(projector_config)

        # Projects visual-prompt (box) features into the LLM embedding space.
        vp_projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_prompt_hidden_size,
            llm_hidden_size=config.text_config.hidden_size,
            depth=config.projector_depth,
        )
        self.vp_projector = ProjectorModel(vp_projector_config)

        # Fuses low-res tower tokens with the ConvNeXt last feature map
        # (1536 = ConvNeXt-L last-stage channel width).
        self.fuser = DualPathFuseModule(
            low_res_dim=config.vision_config.hidden_size,
            high_res_dim=1536,
        )

        # Multi-level RoI-Align visual prompt encoder over the aux features.
        self.vp_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )

        # Generation config is attached externally before calling generate().
        self.gen_config = None

        self.vocab_size = config.text_config.vocab_size
        self.llm = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

    def _prepare_data_for_llm(self, data):
        """Encode images and boxes in ``data`` and splice them into LLM inputs.

        Mutates ``data`` in place (``pixel_values`` is replaced with projected
        features, ``bbox_feats`` is added) and returns the dict produced by
        ``prepare_inputs_labels_for_multimodal``.
        """
        if "pixel_values" in data:
            visual_outputs = self.vision_encoder(
                data["pixel_values"].to(self.vision_encoder.dtype),
                output_hidden_states=True,
            )
            # Take patch features from the penultimate layer; CLIP-style towers
            # additionally drop the leading CLS token.
            if type(self.vision_encoder).__name__ in [
                "CLIPVisionModel",
                "CLIPVisionModelAnyRes",
            ]:
                visual_outputs = visual_outputs.hidden_states[-2][:, 1:]
            elif type(self.vision_encoder).__name__ == "SiglipVisionModel":
                visual_outputs = visual_outputs.hidden_states[-2]
            else:
                raise NotImplementedError

            # High-resolution auxiliary tower.
            if self.vision_encoder_aux is not None:
                pixels_aux = []
                for pixels in data["pixel_values_aux"]:
                    if pixels.dim() == 3:
                        pixels = pixels.unsqueeze(0)
                    elif pixels.dim() == 4:
                        pixels = pixels.permute(1, 0, 2, 3)
                    pixels_aux.append(pixels)
                visual_outputs_aux = torch.cat(pixels_aux, dim=0)
                aux_output = self.vision_encoder_aux(visual_outputs_aux)
                visual_outputs_aux = aux_output["image_features"]
                last_feat = aux_output["last_feat"]

                # Fuse low-res tokens with the high-res last feature map, then
                # project into the LLM embedding space.
                fuse_features = self.fuser(
                    low_res_feat=visual_outputs, high_res_feat=last_feat
                )
                pixel_values = self.projector(fuse_features)
                data["pixel_values"] = pixel_values

            # Box-level visual prompts via multi-level RoI-Align.
            bbox_visual_outputs = []
            if "gt_boxes" in data:
                for batch_idx, boxes in enumerate(data["gt_boxes"]):
                    if len(boxes) == 0:
                        # Keep batch alignment even for samples without boxes.
                        bbox_visual_outputs.append(None)
                        continue
                    multi_level_aux_features = [
                        visual_output_aux[batch_idx].unsqueeze(0)
                        for visual_output_aux in visual_outputs_aux
                    ]
                    boxes = boxes.to(torch.float32)
                    out_vp_feat = self.vp_encoder(
                        multi_level_aux_features,
                        [boxes],
                    ).squeeze(0)
                    out_vp_feat = out_vp_feat.to(pixel_values.dtype)
                    out_vp_feat = self.vp_projector(out_vp_feat)
                    bbox_visual_outputs.append(out_vp_feat)
            data["bbox_feats"] = bbox_visual_outputs

        data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
        return data

    def generate(self, data_dict: Dict[str, Any], gen_config=None, tokenizer=None):
        """Run multimodal generation on a prepared data dict.

        Args:
            data_dict (Dict[str, Any]): Batch produced by the preprocessing
                pipeline (pixel values, optional boxes, input ids, ...).
            gen_config: Optional generation config overriding ``self.gen_config``.
            tokenizer: Tokenizer used for the BOS id and output decoding.

        Returns:
            str: The decoded answer with ``<s>``/``</s>`` markers stripped.
        """
        data_dict = self._prepare_data_for_llm(data_dict)
        data_dict["inputs_embeds"] = data_dict["inputs_embeds"].to(self.llm.dtype)
        stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=[])
        generate_output = self.llm.generate(
            **data_dict,
            generation_config=self.gen_config if gen_config is None else gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
        )
        # Was a bare debug print; use the module logger so library users are
        # not spammed on stdout.
        logger.debug("generate_output: %s", generate_output)
        prediction = tokenizer.decode(
            generate_output[0], skip_special_tokens=False
        ).strip()
        prediction = prediction.replace("<s>", "").replace("</s>", "").strip()
        return prediction
|
|
|
|
# Register ChatRex with the HF Auto* factories so that AutoConfig /
# AutoModelForCausalLM can resolve model_type "chatrex" from a checkpoint.
AutoConfig.register("chatrex", ChatRexAuxConfig)
AutoModelForCausalLM.register(ChatRexAuxConfig, ChatRexAuxForConditionalGeneration)
|
|