| import torch |
| import torch.nn as nn |
| import re |
| import math |
| from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig |
|
|
|
|
| def build_vision_tower(): |
| vision_tower = '/mnt/petrelfs/share_data/dongxiaoyi/share_models/clip_l_336' |
| return CLIPVisionTower(vision_tower) |
|
|
|
|
| def build_vision_projector(): |
| projector_type = 'mlp2x_gelu' |
| mm_hidden_size = 1024 |
| hidden_size = 4096 |
|
|
| mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) |
| if mlp_gelu_match: |
| mlp_depth = int(mlp_gelu_match.group(1)) |
| modules = [nn.Linear(mm_hidden_size, hidden_size)] |
| for _ in range(1, mlp_depth): |
| modules.append(nn.GELU()) |
| modules.append(nn.Linear(hidden_size, hidden_size)) |
| return nn.Sequential(*modules) |
|
|
| if projector_type == 'identity': |
| return IdentityMap() |
|
|
| raise ValueError(f'Unknown projector type: {projector_type}') |
|
|
| class IdentityMap(nn.Module): |
| def __init__(self): |
| super().__init__() |
|
|
| def forward(self, x, *args, **kwargs): |
| return x |
|
|
| @property |
| def config(self): |
| return {"mm_projector_type": 'identity'} |
|
|
|
|
| class CLIPVisionTower(nn.Module): |
| def __init__(self, vision_tower): |
| super().__init__() |
|
|
| self.is_loaded = False |
| self.is_resize_pos = False |
|
|
| self.vision_tower_name = vision_tower |
| self.select_layer = -1 |
| self.select_feature = 'patch' |
| self.load_model() |
| |
|
|
| def load_model(self): |
| self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) |
| self.vision_tower.requires_grad_(False) |
|
|
| self.is_loaded = True |
| def resize_pos(self): |
| pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight |
| pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0) |
| orig_size = 24 |
| new_size = 16 |
|
|
| if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1: |
| self.is_resize_pos = True |
| else: |
| embedding_size = pos_embed_checkpoint.shape[-1] |
| num_extra_tokens = 1 |
| new_num = new_size ** 2 + num_extra_tokens |
| print("Position interpolate from %dx%d to %dx%d" % |
| (orig_size, orig_size, new_size, new_size)) |
| extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] |
| |
| pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] |
| pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, |
| embedding_size).permute( |
| 0, 3, 1, 2) |
| pos_tokens = torch.nn.functional.interpolate(pos_tokens, |
| size=(new_size, |
| new_size), |
| mode='bicubic', |
| align_corners=False) |
| pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) |
| new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) |
|
|
| new_pos_embed = new_pos_embed.squeeze(0) |
|
|
| self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024) |
| |
| |
|
|
| self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(new_pos_embed.to(pos_embed_checkpoint.device).to(pos_embed_checkpoint.dtype)) |
| self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1)).to(pos_embed_checkpoint.device) |
| self.is_resize_pos = True |
|
|
| def feature_select(self, image_forward_outs): |
| image_features = image_forward_outs.hidden_states[self.select_layer] |
| if self.select_feature == 'patch': |
| image_features = image_features[:, 1:] |
| elif self.select_feature == 'cls_patch': |
| image_features = image_features |
| else: |
| raise ValueError(f'Unexpected select feature: {self.select_feature}') |
| return image_features |
|
|
| def forward(self, images): |
| if not self.is_loaded: |
| self.load_model() |
| if type(images) is list: |
| image_features = [] |
| for image in images: |
| image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) |
| image_feature = self.feature_select(image_forward_out).to(image.dtype) |
| image_features.append(image_feature) |
| else: |
| image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) |
| image_features = self.feature_select(image_forward_outs).to(images.dtype) |
|
|
| return image_features |
|
|
| @property |
| def dummy_feature(self): |
| return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) |
|
|
| @property |
| def dtype(self): |
| return self.vision_tower.dtype |
|
|
| @property |
| def device(self): |
| return self.vision_tower.device |
|
|
| @property |
| def config(self): |
| if self.is_loaded: |
| return self.vision_tower.config |
| else: |
| return self.cfg_only |
|
|
| @property |
| def hidden_size(self): |
| return self.config.hidden_size |
|
|
| @property |
| def num_patches(self): |
| return (self.config.image_size // self.config.patch_size) ** 2 |
|
|
| class PLoRA(nn.Linear): |
| def __init__(self, |
| in_features: int, |
| out_features: int, |
| bias: bool = True, |
| device=None, |
| dtype=None, |
| lora_r=8, |
| lora_alpha=16, |
| lora_dropout=0.05, |
| lora_len=0, |
| **kwargs) -> None: |
| super().__init__(in_features, out_features, bias, device, dtype) |
| self.lora_r = lora_r |
| self.lora_alpha = lora_alpha |
| self.lora_len = lora_len |
| if lora_dropout > 0.: |
| self.lora_dropout = nn.Dropout(p=lora_dropout) |
| else: |
| self.lora_dropout = lambda x: x |
| self.lora_scaling = self.lora_alpha / self.lora_r |
|
|
| self.Plora_A = nn.Linear(in_features, |
| self.lora_r, |
| bias=False, |
| device=device, |
| dtype=dtype) |
| self.Plora_B = nn.Linear(self.lora_r, |
| out_features, |
| bias=False, |
| device=device, |
| dtype=dtype) |
|
|
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| if hasattr(self, 'lora_A'): |
| |
| nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B.weight) |
| |
|
|
| def forward(self, x, im_mask=None): |
| res = super().forward(x) |
| if im_mask is not None: |
| if torch.sum(im_mask) > 0: |
| part_x = x[im_mask] |
| res[im_mask] += self.Plora_B(self.Plora_A( |
| self.lora_dropout(part_x))) * self.lora_scaling |
| else: |
| part_x = x[:, :1] |
| res[:, :1] += self.Plora_B(self.Plora_A( |
| self.lora_dropout(part_x))) * 0 |
| return res |
|
|