| import logging |
| import json |
| import torch |
| from torch import nn |
| from .config import InternVideo2Config, EasyDict |
| from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224 |
| from transformers.utils import logging |
| import warnings |
|
|
# Silence all Python warnings raised while building / running the model.
# NOTE(review): this is a module-level blanket filter — it also hides
# deprecation and user warnings for any code that imports this module;
# consider scoping it if that is not intended.
warnings.filterwarnings("ignore")
|
|
class InternVideo2_Stage2(nn.Module):
    """Stage-2 pretraining wrapper around an InternVideo2 vision encoder.

    Builds the backbone selected by ``config.model.vision_encoder.name``,
    adds a linear projection from the backbone's CLIP embedding width to
    ``config.model.embed_dim``, and holds a learnable contrastive
    temperature initialized from ``config.model.temp``.

    Args:
        config: Nested, EasyDict-like config with ``model`` and
            ``criterion`` sections (only the keys read below are required).
        is_pretrain (bool): Whether the model is used for pretraining.
    """

    def __init__(self, config, is_pretrain=True):
        super().__init__()

        self.config = config
        self.is_pretrain = is_pretrain
        # Width of the backbone's CLIP projection output.
        self.vision_width = config.model.vision_encoder.clip_embed_dim
        self.embed_dim = config.model.embed_dim

        self.vision_encoder = self.build_vision_encoder()
        if config.model.get("freeze_vision", False):
            self.freeze_vision()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)

        # Learnable temperature for contrastive logits.
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)
        self.uta_image_only = config.criterion.get('uta_image_only', False)

    def freeze_vision(self):
        """Freeze every vision-encoder parameter (no gradient updates)."""
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def no_weight_decay(self):
        """Return the set of parameter names excluded from weight decay.

        Includes ``temp`` plus the backbone's own no-decay names, prefixed
        with ``vision_encoder.`` to match this module's parameter namespace.
        """
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )
        return ret

    @property
    def dtype(self):
        # Dtype of the patch-embedding conv weight, taken as the model dtype.
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self, image):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): Input clip of shape (B, T, C, H, W);
                T == 1 is treated as a single image.

        Returns:
            tuple:
                - vision_embeds (torch.Tensor): Patch features. Shape [B, N, C].
                - pooled_vision_embeds (torch.Tensor): Pooled features.
                  Shape [B, 1, C].
        """
        T = image.shape[1]
        use_image = T == 1  # a single frame takes the encoder's image path
        # The encoder expects channels-first video: (B, C, T, H, W).
        image = image.permute(0, 2, 1, 3, 4)

        # The encoder also returns alignment/CLIP outputs; they are unused here.
        vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
            image, None, use_image)
        return vision_embeds, pooled_vision_embeds

    def build_vision_encoder(self):
        """Instantiate the vision backbone named in the config.

        Returns:
            nn.Module: The InternVideo2 vision encoder.

        Raises:
            ValueError: If the configured encoder name is not supported.
        """
        encoder_name = self.config.model.vision_encoder.name
        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
        elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
            vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
        else:
            raise ValueError(f"Not implemented: {encoder_name}")
        return vision_encoder
|
|
|
|