| |
| |
| |
| |
| |
|
|
| import torch, os |
| import torch.nn as nn |
| from torch.utils.checkpoint import checkpoint |
|
|
| from .siglip_vision_tower import SiglipVisionTower |
|
|
| import torch.nn.functional as F |
| from torch.nn.init import trunc_normal_ |
| from copy import deepcopy |
| import random |
| import math |
|
|
class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """Run several vision backbones on the same images and combine their features.

    The backbones are selected by a ``;``-separated name string and stored in
    ``self.vision_towers``.  How their outputs are combined depends on
    ``args.moe_version_type``:

    * ``None`` / ``'all_tiling'``: every tower sees every image; per-tower
      token grids are concatenated along the last (channel) dimension.
    * ``'convnext_512_siglip_448'``: every tower sees every image; the raw
      tower outputs are returned in a dict keyed by ``vision_tower.name``.
    * ``'seq_concat'`` / ``'feat_concat'``: the input is a dict with
      ``'pixel_values'`` and ``'num_patches'``; the first tower sees all
      images, the remaining towers only see one selected image per group.
    """

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32,
                 convnext_img_size=1024,
                 normalize_type=None, raw_config=None):
        """Build the module and eagerly load every requested backbone.

        Args:
            vision_tower: ``;``-separated backbone names (see
                ``load_vision_towers`` for the recognised names).
            args: configuration object; this class reads ``normalize_type``,
                ``moe_version_type``, ``freeze_backbones``, ``delay_load``,
                ``vision_tower_convnext_path`` and ``vision_tower_siglip_path``.
            grid_size: side length of the output token grid; the number of
                output tokens is ``grid_size ** 2``.
            convnext_img_size: input resolution configured for the ConvNeXt tower.
            normalize_type: accepted for interface compatibility only; the
                value actually used is ``args.normalize_type``.
            raw_config: forwarded unchanged to ``SiglipVisionTower``.
        """
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        # NOTE: the `normalize_type` keyword argument is intentionally not read
        # here; the original implementation always used the value from `args`.
        self.normalize_type = args.normalize_type
        self.moe_version_type = args.moe_version_type
        self.raw_config = raw_config
        print("moe_version_type: ", self.moe_version_type)
        assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"

        vision_tower_name_list = vision_tower.split(";")
        # Reference resolution used by the 'seq_concat' / 'feat_concat' path to
        # decide whether a tower's input has to be resized.
        self.input_image_size = 1024
        self.convnext_img_size = convnext_img_size
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        """Instantiate and load the backbones named in ``vision_tower_name_list``.

        Recognised names are ``'convnext-1024'`` and ``'palisiglip'``; any
        other name is skipped without error.  A backbone is frozen
        (``freeze_vision = True`` on its copy of ``args``) when its name occurs
        in ``args.freeze_backbones``.
        """
        self.vision_towers = nn.ModuleList()

        freeze_backbone_list = args.freeze_backbones
        if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
            print("The frozen backbones: ", freeze_backbone_list)
        else:
            # An empty string makes the `in` checks below evaluate to False.
            freeze_backbone_list = ""

        for name in vision_tower_name_list:
            if name == 'convnext-1024':
                # Each tower gets its own copy so per-tower settings do not
                # leak into the shared `args` object.
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = 'convnext-1024' in freeze_backbone_list

                # Imported lazily, only when a ConvNeXt tower is requested.
                from .convnext_encoder import ConvNextVisionTower
                convnext_args.input_image_size = self.convnext_img_size
                convnext_vision_tower = ConvNextVisionTower(
                    args.vision_tower_convnext_path,
                    convnext_args,
                    delay_load=args.delay_load,
                    normalize_type=self.normalize_type)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == 'palisiglip':
                palisiglip_args = deepcopy(args)
                palisiglip_args.input_image_size = 448
                palisiglip_args.freeze_vision = 'palisiglip' in freeze_backbone_list

                palisiglip_vision_tower = SiglipVisionTower(
                    args.vision_tower_siglip_path,
                    palisiglip_args,
                    delay_load=args.delay_load,
                    raw_config=self.raw_config)
                palisiglip_vision_tower.load_model()
                self.vision_towers.append(palisiglip_vision_tower)

        self.image_processor = None
        self.is_loaded = True

    def load_model(self):
        """No-op check: all towers are already loaded by ``__init__``."""
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    @staticmethod
    def _resize_images(images, target_size, reference_size):
        """Bilinearly resize ``images`` to a ``target_size`` square unless it equals ``reference_size``."""
        if target_size == reference_size:
            return images
        # Interpolate in float32 and cast back to the caller's dtype.
        resized = F.interpolate(images.float(),
                                size=(target_size, target_size),
                                mode='bilinear',
                                align_corners=True)
        return resized.to(dtype=images.dtype)

    def _to_grid_tokens(self, feature, out_dtype):
        """Convert one tower output to a ``(b, grid_size ** 2, c)`` token tensor.

        A 3-D ``(b, n, c)`` input that already has ``num_tokens`` tokens is
        returned untouched.  Other 3-D inputs are folded into a square
        ``(b, c, h, w)`` map; 4-D inputs are taken as ``(b, c, h, w)``.  The
        map is resized to the target grid when its width differs from
        ``grid_size`` and then flattened back to tokens.
        """
        if len(feature.shape) == 3:
            b, n, c = feature.shape
            if n == self.num_tokens:
                return feature
            w = h = int(n ** 0.5)
            feature = feature.transpose(1, 2).reshape(b, c, h, w)
        else:
            b, c, h, w = feature.shape

        if w != self.grid_size:
            feature = F.interpolate(feature.float(),
                                    size=(self.grid_size, self.grid_size),
                                    mode='bilinear',
                                    align_corners=True).to(dtype=out_dtype)
        return feature.flatten(2, 3).transpose(1, 2)

    def _forward_channel_concat(self, x):
        """``None`` / ``'all_tiling'`` mode: concatenate all towers' token grids along the last dim."""
        image_input_size = x.shape[2]
        assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"

        features = []
        for vision_tower in self.vision_towers:
            resized_x = self._resize_images(x, vision_tower.input_image_size, image_input_size)
            features.append(self._to_grid_tokens(vision_tower(resized_x), x.dtype))
        return torch.cat(features, dim=-1)

    def _forward_per_tower(self, x):
        """``'convnext_512_siglip_448'`` mode: return raw tower outputs keyed by ``vision_tower.name``."""
        image_input_size = x.shape[2]
        assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"

        features = {}
        for vision_tower in self.vision_towers:
            resized_x = self._resize_images(x, vision_tower.input_image_size, image_input_size)
            features[vision_tower.name] = vision_tower(resized_x)
        return features

    def _forward_tiled(self, x):
        """``'seq_concat'`` / ``'feat_concat'`` mode operating on a dict input.

        Returns a dict with ``'clip_embeds'`` (first tower), ``'no_tiling_embeds'``
        (remaining towers, or ``None`` when there is only one tower) and the
        unchanged ``'num_patches'``.
        """
        assert isinstance(x, dict), "x is expected to be a dict but {}".format(type(x))
        pixel_values = x['pixel_values']
        num_patches = x['num_patches']

        # In 'seq_concat' mode each group contributes one image fewer to
        # `pixel_values` than its `num_patches` entry states.
        if self.moe_version_type == 'seq_concat':
            image_in_num_patches = [i - 1 for i in num_patches]
        else:
            image_in_num_patches = [i for i in num_patches]

        assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))

        # Index of the last image of every group (running total minus one);
        # judging by the name this is the thumbnail view -- TODO confirm.
        thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
        image_no_tiling = pixel_values[thumbnail_image_id]

        features = []
        for layer_id, vision_tower in enumerate(self.vision_towers):
            # Only the first tower processes every image; the others get the
            # one selected image per group.
            tower_input = pixel_values if layer_id == 0 else image_no_tiling
            # NOTE: unlike the other modes, the resize decision compares with
            # the fixed `self.input_image_size`, not the actual input size.
            resized_x = self._resize_images(tower_input, vision_tower.input_image_size, self.input_image_size)
            features.append(self._to_grid_tokens(vision_tower(resized_x), tower_input.dtype))

        clip_embeds = features[0]
        if len(features) <= 1:
            no_tiling_embeds = None
        else:
            no_tiling_embeds = torch.cat(features[1:], dim=-1)

        if self.moe_version_type == 'feat_concat':
            # Move the first tower's features for the selected images over to
            # the channel-concatenated tensor ...
            clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
            if no_tiling_embeds is not None:
                no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
            else:
                no_tiling_embeds = clip_thumbnail_embeds

            # ... and drop those rows from `clip_embeds`.
            all_ids = torch.arange(clip_embeds.shape[0], device=clip_embeds.device)
            clip_embeds_mask = ~torch.isin(all_ids, thumbnail_image_id)
            clip_embeds = clip_embeds[clip_embeds_mask]

        return {
            'clip_embeds': clip_embeds,
            'no_tiling_embeds': no_tiling_embeds,
            'num_patches': num_patches
        }

    def forward(self, x):
        """Encode ``x`` with every backbone, dispatching on ``moe_version_type``.

        Args:
            x: an image tensor indexed as ``(b, c, h, w)`` with ``h == w`` for
                the ``None`` / ``'all_tiling'`` / ``'convnext_512_siglip_448'``
                modes, or a dict with ``'pixel_values'`` and ``'num_patches'``
                for the ``'seq_concat'`` / ``'feat_concat'`` modes.

        Returns:
            A tensor (channel-concatenated tokens), a dict keyed by tower name,
            or a dict with ``'clip_embeds'`` / ``'no_tiling_embeds'`` /
            ``'num_patches'``, depending on the mode.
        """
        if self.moe_version_type in [None, 'all_tiling']:
            return self._forward_channel_concat(x)
        if self.moe_version_type == 'convnext_512_siglip_448':
            return self._forward_per_tower(x)
        return self._forward_tiled(x)

    @property
    def dummy_feature(self):
        """A ``(1, hidden_size)`` zero tensor on this module's device and dtype.

        NOTE(review): in ``'convnext_512_siglip_448'`` mode ``hidden_size`` is a
        dict, which ``torch.zeros`` does not accept -- confirm this property is
        not used in that mode.
        """
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the first parameter found in the loaded vision towers."""
        # The previous code read `self.clip_vision_tower`, an attribute that is
        # never assigned on this class, so the property always failed.
        return next(self.vision_towers.parameters()).dtype

    @property
    def device(self):
        """Device of the first parameter found in the loaded vision towers."""
        # Same fix as `dtype`: use the towers that actually exist.
        return next(self.vision_towers.parameters()).device

    @property
    def config(self):
        """Not implemented for the multi-backbone tower; evaluates to ``None``.

        The previous body was ``assert NotImplementedError``, which always
        passes (the class object is truthy), so this property has always
        returned ``None``.  That behaviour is kept so existing callers do not
        start failing.
        """
        return None

    @property
    def hidden_size(self):
        """Per-tower hidden sizes.

        Returns a ``{tower.name: hidden_size}`` dict in
        ``'convnext_512_siglip_448'`` mode, otherwise the sum over all towers
        (the channel width after concatenation).
        """
        if self.moe_version_type == 'convnext_512_siglip_448':
            return {vision_tower.name: vision_tower.hidden_size for vision_tower in self.vision_towers}
        return sum(vision_tower.hidden_size for vision_tower in self.vision_towers)

    @property
    def num_patches(self):
        """Number of output tokens per image (``grid_size ** 2``)."""
        return self.num_tokens
|
|
|
|
|
|
|
|