import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from torch import Tensor
from open_clip import get_tokenizer, create_model_from_pretrained
import torchvision.transforms as T
from .utils import imagenet_templates


# Normalization statistics used by OpenAI CLIP (and compatible open_clip checkpoints).
OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))


class MaskClip(nn.Module):
    def __init__(
        self,
        clip_model="ViT-B-16",
        pretrained="laion2b_s34b_b88k",
        patch_size=16,
        img_size=(224, 224),
        in_channels=768,
        text_channels=512,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.img_size = img_size
        model, _ = create_model_from_pretrained(clip_model, pretrained=pretrained)
        model.eval()
        self.clip_T = OPENAI_NORMALIZE
        self.hook_features = {}
        self.backbone = model

        # Capture the output of the penultimate transformer block; the last
        # block is re-applied manually in extract_v to obtain value features.
        def hook_fn_forward(module, input, output):
            self.hook_features["v"] = output

        self.backbone.visual.transformer.resblocks[-2].register_forward_hook(hook_fn_forward)

        # Keep a copy of the original positional embedding so it can be
        # restored after being resized for non-default input resolutions.
        self._positional_embd = nn.Parameter(self.backbone.visual.positional_embedding.data.clone())

        # 1x1 convolution projecting patch features into the text embedding
        # space, initialized from CLIP's visual projection matrix.
        self.proj = nn.Conv2d(in_channels, text_channels, 1, bias=False)
        self.proj.weight = nn.Parameter(model.visual.proj.t()[:, :, None, None])
        self.tokenizer = get_tokenizer(clip_model)

    @torch.no_grad()
    def extract_feat(self, inputs: Tensor) -> Tensor:
        """Extract per-patch features from images."""
        pos_embed = self.backbone.visual.positional_embedding

        B, C, H, W = inputs.shape
        hw_shape = (H // self.patch_size, W // self.patch_size)
        # Both lengths count the class token.
        x_len, pos_len = hw_shape[0] * hw_shape[1] + 1, pos_embed.shape[0]

        if x_len != pos_len:
            if pos_len == (self.img_size[0] // self.patch_size) * (self.img_size[1] // self.patch_size) + 1:
                pos_h = self.img_size[0] // self.patch_size
                pos_w = self.img_size[1] // self.patch_size
            else:
                raise ValueError(
                    f'Unexpected token length: got {x_len} tokens but the '
                    f'positional embedding has {pos_len} entries.')

            # Interpolate the positional embedding to the current input resolution.
            self.backbone.visual.positional_embedding.data = self.resize_pos_embed(
                self._positional_embd[None], hw_shape, (pos_h, pos_w), 'bicubic')[0]

        _ = self.backbone(inputs)
        v = self.hook_features["v"]
        v = self.extract_v(v, self.backbone.visual.transformer.resblocks[-1]).permute(1, 0, 2)
        v = self.backbone.visual.ln_post(v)

        # Drop the class token and reshape to (B, C, H', W').
        v = v.permute(1, 0, 2)[:, 1:]
        v = v.reshape(B, hw_shape[0], hw_shape[1], -1).permute(0, 3, 1, 2).contiguous()

        # Restore the original positional embedding.
        self.backbone.visual.positional_embedding.data = self._positional_embd
        return v

    @torch.no_grad()
    def extract_v(self, x, block):
        """Re-run the last transformer block keeping only the value pathway
        (the MaskCLIP trick): query-key attention is dropped, so each patch
        retains its own value features instead of an attention-weighted mix."""
        y = block.ln_1(x)
        # Joint in-projection; reshape so q, k, v are stacked along dim 0.
        y = F.linear(y, block.attn.in_proj_weight, block.attn.in_proj_bias)
        B, N, C = y.shape
        y = y.view(B, N, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * B, N, C // 3)
        # Apply the attention output projection, then keep only the values.
        y = F.linear(y, block.attn.out_proj.weight, block.attn.out_proj.bias)
        q, k, v = y.tensor_split(3, dim=0)
        # Residual connection and MLP as in a standard block.
        v += x
        v += block.mlp(block.ln_2(v))
        return v

    @staticmethod
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
        """Resize pos_embed weights.

        Resize pos_embed using the given interpolation method.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C].
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of the downsampled original
                training image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``.

        Returns:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = pos_shape
        # Separate the class token from the grid of patch position embeddings.
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed

    @torch.no_grad()
    def decode_head(self, x: Tensor) -> Tensor:
        """Project patch features into the text embedding space."""
        feat = self.proj(x)
        return feat

    @torch.no_grad()
    def forward(self, inputs: Tensor) -> Tensor:
        """Normalize images, encode them with the backbone, and project the
        patch features into the text embedding space. The output has shape
        (B, text_channels, H // patch_size, W // patch_size)."""
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)
        feats = self.decode_head(x)
        return feats

    @torch.no_grad()
    def get_classifier(self, classnames: List[str]) -> Tensor:
        """Build a (num_classes, text_channels) classifier from class names
        via prompt-ensembled text embeddings."""
        aug_embeddings = torch.stack([self._embed_label(label) for label in classnames])
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        return aug_embeddings.squeeze(1)

    @torch.no_grad()
    def _embed_label(self, label: str) -> Tensor:
        """Encode a label name into a single prompt-averaged text embedding."""
        # Tokenize the label with every ImageNet prompt template, encode all
        # prompts, then average the normalized embeddings.
        all_prompts = [self.tokenizer(template.format(label)) for template in imagenet_templates]
        all_prompts = torch.cat(all_prompts)
        all_prompts = all_prompts.to(self.backbone.visual.positional_embedding.device)
        out = self.backbone.encode_text(all_prompts)
        out /= out.norm(dim=-1, keepdim=True)
        out = out.mean(dim=0)
        return out
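

# Minimal usage sketch (an illustration, not part of the original module): it
# shows how the pieces above are typically combined for open-vocabulary
# segmentation -- project patch features with forward(), build a text
# classifier with get_classifier(), and take per-patch cosine similarities as
# class logits. The class names below are placeholders; downloading the
# pretrained checkpoint is assumed to succeed.
if __name__ == "__main__":
    model = MaskClip().eval()
    images = torch.rand(1, 3, 224, 224)  # values in [0, 1]; forward() applies CLIP normalization

    feats = model(images)                                        # (1, 512, 14, 14) patch embeddings
    classifier = model.get_classifier(["cat", "dog", "grass"])   # (3, 512), L2-normalized rows

    feats = feats / feats.norm(dim=1, keepdim=True)              # normalize patch embeddings
    logits = torch.einsum("bchw,nc->bnhw", feats, classifier)    # cosine similarity per patch
    seg = logits.argmax(dim=1)                                   # (1, 14, 14) patch-level class map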