| import torch, os |
| import torch.nn as nn |
| from timm import create_model |
| from transformers import CLIPImageProcessor |
| from .convnext import convnext_xxlarge |
| from torch.utils.checkpoint import checkpoint |
| import torch |
| from torchvision import transforms as T |
| from PIL import Image |
|
|
|
|
|
|
# Keyword arguments used to construct a HF CLIPImageProcessor (see
# ConvNextFPNVisionTower.load_model). image_mean/image_std are the standard
# OpenAI-CLIP normalization statistics; resample=3 is PIL's BICUBIC filter.
cfg={
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [
        0.48145466,
        0.4578275,
        0.40821073
    ],
    "image_std": [
        0.26862954,
        0.26130258,
        0.27577711
    ],
    "resample": 3,
    "size": 256
}
|
|
|
|
|
|
# Per-channel image-normalization statistics. "SLIP" here denotes SigLIP-style
# normalization (mean = std = 0.5); "CLIP" is the OpenAI CLIP statistics.
MEAN_SLIP = [0.5, 0.5, 0.5]
STD_SLIP = [0.5, 0.5, 0.5]

MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
STD_CLIP = [0.26862954, 0.26130258, 0.27577711]

# Affine coefficients converting a SigLIP-normalized tensor to CLIP
# normalization, per channel: x_clip = a * x_siglip + b.
a = [ss / sc for ss, sc in zip(STD_SLIP, STD_CLIP)]
b = [(ms - mc) / sc for ms, mc, sc in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
|
|
|
|
class SlipToClipTransform:
    """Per-channel affine map from SigLIP-normalized tensors to CLIP normalization.

    Applies x * a + b with `a` and `b` broadcast over the spatial dimensions,
    so inputs are expected to be channel-first (C at dim -3).
    """

    def __init__(self, a, b):
        # Reshape to (C, 1, 1) so the coefficients broadcast over H and W.
        self.a = torch.tensor(a).view(-1, 1, 1)
        self.b = torch.tensor(b).view(-1, 1, 1)

    def __call__(self, x_slip):
        # Move the coefficients to the input's device on every call; this keeps
        # the transform usable for tensors on any device.
        scale = self.a.to(x_slip.device)
        shift = self.b.to(x_slip.device)
        return x_slip * scale + shift
# Module-level transform mapping SigLIP-normalized inputs to CLIP normalization.
slip_to_clip = SlipToClipTransform(a, b)
|
|
class ConvNextVisionTower(nn.Module):
    """Vision tower wrapping a ConvNeXt backbone (ConvNeXt-XXLarge by name).

    `forward` returns flattened per-stage feature maps shaped (B, H*W, C);
    depending on `args.mm_vision_select_layer` either the last stage or the
    last four stages are returned (see `feature_select`). The backbone is
    always cast to bfloat16 and runs with per-stage gradient checkpointing.
    """

    def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
        """
        Args:
            vision_tower: model name or checkpoint path. Names containing
                'xxlarge' are built via `convnext_xxlarge`; otherwise an
                existing path is loaded with `torch.load`.
            args: namespace providing `freeze_vision`, `input_image_size`,
                `mm_vision_select_layer`, and optionally
                `mm_vision_select_feature` (default 'patch').
            delay_load: when True (and the name contains 'xxlarge'), build the
                architecture without pretrained weights.
            normalize_type: if 'siglip', inputs are re-normalized from SigLIP
                to CLIP statistics before the forward pass (see _forward_images).
        """
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.name = 'convnext'
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.pre_norm = normalize_type

        print('pre_norm: ', self.pre_norm)
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        """Instantiate the backbone, cast it to bfloat16, optionally freeze it,
        and enable per-stage gradient checkpointing."""
        if 'xxlarge' in self.vision_tower_name:
            if self.delay_load:
                # Architecture only; weights are expected to be loaded later.
                self.vision_tower = convnext_xxlarge(pretrained=False)
            else:
                self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 3072)
        elif os.path.exists(self.vision_tower_name):
            # NOTE(review): torch.load unpickles arbitrary data — only load
            # trusted checkpoints (consider weights_only/map_location).
            self.vision_tower = torch.load(self.vision_tower_name)
        else:
            assert False, 'Not implemented'

        # The tower always runs in bfloat16.
        self.vision_tower = self.vision_tower.to(torch.bfloat16)

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Gradient checkpointing is enabled per stage regardless of freezing.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick features from the per-stage outputs.

        select_layer > 100 is used as a sentinel meaning "return the last four
        stages as a list"; otherwise only the final stage is returned.
        """
        if self.select_layer > 100:
            image_features = image_forward_outs[-4:]
        else:
            image_features = image_forward_outs[-1]
        return image_features

    def forward_features(self, x):
        """Run stem + all stages, returning each stage output flattened to
        (B, H*W, C)."""
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            b, c, h, w = x.shape
            # (B, C, H, W) -> (B, H*W, C)
            image_forward_out.append(x.view(b, c, -1).transpose(1, 2))
        return image_forward_out

    def forward(self, images):
        """Extract features; gradients are disabled when the tower is frozen."""
        if self.freeze_vision:
            with torch.no_grad():
                image_features = self._forward_images(images)
        else:
            image_features = self._forward_images(images)

        return image_features

    def _forward_images(self, images):
        """Handle both a list of single images and a batched tensor, applying
        the optional SigLIP->CLIP re-normalization first."""
        if type(images) is list:
            image_features = []
            for image in images:
                if self.pre_norm == 'siglip':
                    # Re-normalize in float32 for precision, then restore dtype.
                    dtype = image.dtype
                    image = slip_to_clip(image.to(torch.float32)).to(dtype)
                image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_forward_out)
                image_features.append(image_feature)
        else:
            if self.pre_norm == 'siglip':
                dtype = images.dtype
                images = slip_to_clip(images.to(torch.float32)).to(dtype)
            image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
            image_features = self.feature_select(image_forward_outs)

        return image_features

    @property
    def dummy_feature(self):
        # Zero feature used as a placeholder when no image is present.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        # BUG FIX: the original `assert NotImplementedError` asserted the
        # (truthy) exception class, so it always passed and silently returned
        # None. Raise explicitly instead.
        raise NotImplementedError('ConvNextVisionTower does not expose a config')

    @property
    def num_attention_heads(self):
        # Hard-coded to mirror a ViT-style interface; ConvNeXt has no attention.
        return 16

    @property
    def num_layers(self):
        # Number of ConvNeXt stages.
        return 4

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        # NOTE(review): self.patch_embed is never defined on this module, so
        # accessing this property raises AttributeError — confirm the intended
        # patch-size source before relying on it.
        return (self.input_image_size // self.patch_embed.patch_size[0]) ** 2
|
|
|
|
class ConvNextFPNVisionTower(nn.Module):
    """ConvNeXt vision tower exposing multi-scale (FPN-style) feature maps.

    `forward` returns the raw (B, C, H, W) stage outputs selected by
    `fpn_layer_idx`. The backbone always runs under torch.no_grad().
    """

    def __init__(self,
                 vision_tower,
                 args,
                 fpn_target_level=1,
                 fpn_layer_idx=None,
                 fpn_input_dim=None,
                 delay_load=False):
        """
        Args:
            vision_tower: model name; '-fpn' is stripped to 'fpn' before use.
            args: namespace providing `mm_vision_select_layer` and optionally
                `frozen_backbone` (default True) and `mm_vision_select_feature`.
            fpn_target_level: currently unused; kept for interface stability.
            fpn_layer_idx: stage indices to return (default [1, 2, 3]).
            fpn_input_dim: per-stage channel dims (default [768, 1536, 3072]).
            delay_load: currently only stored; load_model() runs eagerly.
        """
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
        self.freeze_vision = getattr(args, "frozen_backbone", True)
        # The FPN tower always operates on 1024x1024 inputs.
        self.input_image_size = 1024
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        self.need_fpn = True
        # BUG FIX: mutable default arguments replaced with None sentinels, and
        # the fpn_input_dim argument is now honored (it was previously ignored
        # and re-hard-coded to [768, 1536, 3072]).
        self.fpn_layer_idx = [1, 2, 3] if fpn_layer_idx is None else fpn_layer_idx
        self.fpn_input_dim = [768, 1536, 3072] if fpn_input_dim is None else fpn_input_dim
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        """Build the image processor and backbone; idempotent once loaded."""
        if self.is_loaded:
            return

        self.image_processor = CLIPImageProcessor(**cfg)
        if 'xxlarge' in self.vision_tower_name:
            self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
        else:
            # NOTE(review): convnext_large_mlp is not imported in this module,
            # so this branch raises NameError — confirm the intended import.
            self.vision_tower = convnext_large_mlp(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 1536)
        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Gradient checkpointing is enabled per stage regardless of freezing.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True

        if self.input_image_size is not None:
            # NOTE(review): recent CLIPImageProcessor versions expect `size`
            # to be a dict (e.g. {'shortest_edge': ...}) — confirm this bare
            # int assignment against the installed transformers version.
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }

        self.is_loaded = True

    @torch.no_grad()
    def forward_features(self, x):
        """Run stem + all stages, collecting every (B, C, H, W) stage output."""
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            image_forward_out.append(x)
        return image_forward_out

    @torch.no_grad()
    def forward(self, images):
        """Extract features, keeping only the stages listed in fpn_layer_idx."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                # BUG FIX: the batched path selected fpn_layer_idx, but the
                # list path returned all stages; select here too for parity.
                image_feature = [image_feature[idx] for idx in self.fpn_layer_idx]
                image_features.append(image_feature)
        else:
            image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
            image_features = [image_features[idx] for idx in self.fpn_layer_idx]

        return image_features

    @property
    def dummy_feature(self):
        # Zero feature used as a placeholder when no image is present.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        # BUG FIX: the original `assert NotImplementedError` asserted the
        # (truthy) exception class, so it always passed and silently returned
        # None. Raise explicitly instead.
        raise NotImplementedError('ConvNextFPNVisionTower does not expose a config')

    @property
    def num_attention_heads(self):
        # Hard-coded to mirror a ViT-style interface; ConvNeXt has no attention.
        return 16

    @property
    def num_layers(self):
        # Number of ConvNeXt stages.
        return 4

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        # BUG FIX: cfg has no 'image_size' key (the original raised KeyError);
        # the configured key is 'size'. NOTE(review): self.patch_embed is not
        # defined on this module either — this property still needs a real
        # patch-size source before it can work.
        return (cfg['size'] // self.patch_embed.patch_size[0]) ** 2
|
|
if __name__ == '__main__':
    # Smoke test for the SigLIP -> CLIP normalization round-trip.
    # BUG FIX: the original referenced undefined names (STD_SigLIP,
    # MEAN_SigLIP, normalize_clip, normalize_siglip, denormalize_siglip) and
    # crashed with NameError; they are built here from the module constants.
    normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
    normalize_siglip = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)
    # Inverse of normalize_siglip: x = x_norm * std + mean.
    denormalize_siglip = T.Normalize(
        mean=[-m / s for m, s in zip(MEAN_SLIP, STD_SLIP)],
        std=[1.0 / s for s in STD_SLIP],
    )

    COMBINED_STD = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
    COMBINED_MEAN = [(m_slip - m_clip) / s_clip
                     for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]

    # NOTE(review): T.Normalize computes (x - mean) / std, which is not the
    # same affine form as slip_to_clip (x * a + b) — verify the printed error
    # for c3 before trusting the combined transform.
    combined_normalize = T.Normalize(mean=COMBINED_MEAN, std=COMBINED_STD)

    # BUG FIX: run on GPU only when one is available (was unconditional .cuda()).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Renamed from a/b/c to avoid shadowing the module-level coefficients.
    x = torch.randn(1, 3, 256, 256, device=device)
    a_ref = normalize_clip(x).to(torch.bfloat16)
    b_sig = normalize_siglip(x).to(torch.bfloat16)
    c = denormalize_siglip(b_sig.to(torch.float32))
    c2 = normalize_clip(c).to(torch.bfloat16)
    c3 = combined_normalize(b_sig)
    print((c - x).abs().max())       # round-trip reconstruction error
    print((c2 - a_ref).abs().max())  # denorm + renorm vs direct CLIP norm
    print((c3 - a_ref).abs().max())  # combined transform vs direct CLIP norm
    try:
        from IPython import embed  # optional interactive inspection
        embed()
    except ImportError:
        pass