| |
|
|
| import math |
|
|
| import clip |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from nncore.nn import MODELS, build_loss, build_model |
|
|
| from .generator import PointGenerator |
|
|
# Per-architecture CLIP dimensions, keyed by `clip.load` arch name:
#   (video_emb_dim, query_emb_dim, num_tokens)
# where num_tokens = (input_resolution / patch_size) ** 2 + 1 (the extra
# token is the CLS embedding prepended in the visual tower). These values
# are consumed as self.cfg[0] / self.cfg[1] / self.cfg[2] when reshaping
# pre-extracted features in R2Tuning.forward.
_CLIP_ARCHS = {
    'ViT-B/32': (768, 512, 50),        # (224 / 32) ** 2 + 1 = 50
    'ViT-B/16': (768, 512, 197),       # (224 / 16) ** 2 + 1 = 197
    # Fixed: (224 / 14) ** 2 + 1 = 257 tokens. The previous value (50)
    # duplicated the ViT-B/32 entry and would break the feature reshape.
    'ViT-L/14': (1024, 768, 257),
    'ViT-L/14-336px': (1024, 768, 577)  # (336 / 14) ** 2 + 1 = 577
}
|
|
|
|
@MODELS.register()
class R2Tuning(nn.Module):
    """R2-Tuning model: moment retrieval / highlight detection on top of a
    frozen CLIP backbone.

    When ``init=True`` the CLIP model is loaded (on CPU) and frozen; per-layer
    token embeddings from both towers are fed to a trainable adapter, a
    temporal feature pyramid, and classification / regression heads. When
    ``init=False`` the model instead expects pre-extracted per-layer CLIP
    features in ``data['video']`` / ``data['query']``.

    Args:
        arch (str): CLIP architecture name, a key of ``_CLIP_ARCHS``.
        init (bool): Whether to load and freeze the CLIP backbone.
        dims (int): Hidden dimension of the trainable modules.
        strides (tuple[int]): Temporal strides of the pyramid levels.
        buffer_size (int): Point buffer size for :class:`PointGenerator`.
        max_num_moment (int): Max number of moment proposals kept at test time.
        merge_cls_sal (bool): Whether to multiply saliency scores by the
            video-query cosine similarity.
        adapter_cfg, pyramid_cfg, pooling_cfg, class_head_cfg, coord_head_cfg
            (dict | None): Configs forwarded to ``build_model``.
        loss_cfg (dict | None): Config forwarded to ``build_loss``.
    """

    def __init__(self,
                 arch='ViT-B/32',
                 init=True,
                 dims=256,
                 strides=(1, 2, 4, 8),
                 buffer_size=1024,
                 max_num_moment=50,
                 merge_cls_sal=True,
                 adapter_cfg=None,
                 pyramid_cfg=None,
                 pooling_cfg=None,
                 class_head_cfg=None,
                 coord_head_cfg=None,
                 loss_cfg=None):
        super(R2Tuning, self).__init__()

        if init:
            # Load CLIP on CPU and freeze every parameter; train() below
            # keeps the backbone in eval mode during training as well.
            self.clip, _ = clip.load(arch, device='cpu')
            for param in self.clip.parameters():
                param.requires_grad = False

        # (video_emb_dim, query_emb_dim, num_tokens) for the chosen arch.
        self.cfg = _CLIP_ARCHS[arch]
        self.adapter = build_model(adapter_cfg, dims, self.cfg[:2])
        self.pyramid = build_model(pyramid_cfg, dims, strides)
        self.pooling = build_model(pooling_cfg, dims)

        # 1-channel foreground classifier and 2-channel (start, end) regressor.
        self.class_head = build_model(class_head_cfg, dims, 1)
        self.coord_head = build_model(coord_head_cfg, dims, 2)

        self.generator = PointGenerator(strides, buffer_size)

        # Learnable per-pyramid-level scale applied to regressed offsets.
        self.coef = nn.Parameter(torch.ones(len(strides)))
        self.loss = build_loss(loss_cfg)

        self.max_num_moment = max_num_moment
        self.merge_cls_sal = merge_cls_sal

    def train(self, mode=True):
        """Switch train/eval mode while always keeping CLIP frozen in eval."""
        super(R2Tuning, self).train(mode=mode)
        if hasattr(self, 'clip'):
            self.clip.eval()

    @torch.no_grad()
    def clip_video_tower(self, video):
        """Run frames through the frozen CLIP visual tower.

        Args:
            video (Tensor): Frames of shape ``(b * t, 3, d, d)``.

        Returns:
            Tensor: Per-layer token embeddings stacked along a new leading
            dimension, shape ``(num_layers + 1, b * t, p, c)`` — the input
            embedding plus the output of every transformer block.
        """
        video = video.type(self.clip.dtype)
        video = self.clip.visual.conv1(video)
        # (b*t, c, g, g) -> (b*t, g*g, c): flatten the patch grid into tokens.
        video = video.reshape(video.size(0), video.size(1), -1).permute(0, 2, 1)
        # Prepend the CLS embedding as an extra token.
        c_emb = video.new_zeros(video.size(0), 1, video.size(-1))
        c_emb = self.clip.visual.class_embedding.to(video.dtype) + c_emb
        video = torch.cat((c_emb, video), dim=1)
        video = video + self.clip.visual.positional_embedding.to(video.dtype)
        # Transformer blocks expect (seq, batch, dim).
        video = self.clip.visual.ln_pre(video).permute(1, 0, 2)
        emb = [video]
        for blk in self.clip.visual.transformer.resblocks:
            emb.append(blk(emb[-1]))
        # Back to (batch, seq, dim) per layer, then stack layers.
        video = torch.stack([e.permute(1, 0, 2) for e in emb])
        return video

    @torch.no_grad()
    def clip_query_tower(self, query):
        """Run token ids through the frozen CLIP text tower.

        Args:
            query (Tensor): Token ids of shape ``(b, l)``.

        Returns:
            Tensor: Per-layer token embeddings of shape
            ``(num_layers + 1, b, l, c)`` (input embedding included).
        """
        query = self.clip.token_embedding(query).type(self.clip.dtype)
        query = query + self.clip.positional_embedding.type(self.clip.dtype)
        # Transformer blocks expect (seq, batch, dim).
        query = query.permute(1, 0, 2)
        emb = [query]
        for blk in self.clip.transformer.resblocks:
            emb.append(blk(emb[-1]))
        query = torch.stack([e.permute(1, 0, 2) for e in emb])
        return query

    def forward(self, data, mode='test'):
        """Run the model in ``'train'``, ``'val'`` or ``'test'`` mode.

        Args:
            data (dict): Must contain ``'video'`` and ``'query'``; at train
                time also the loss targets, at test time also ``'fps'``.
            mode (str): ``'train'`` computes losses only; ``'test'`` computes
                predictions only; any other value computes both.

        Returns:
            dict: Losses and/or an ``'_out'`` dict with ``'saliency'`` and
            (if a coord head exists) ``'boundary'`` predictions.
        """
        video, query = data['video'], data['query']

        if hasattr(self, 'clip'):
            # Raw inputs: video is (b, t, 3*d*d) flattened frames, query is
            # (b, l) token ids. Non-finite frames act as padding.
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query == 0, 0, 1)

            video[~video.isfinite()] = 0

            (b, t), d = video.size()[:2], int(math.sqrt(video.size(2) / 3))
            video = video.view(b * t, 3, d, d)

            video_emb = self.clip_video_tower(video)
            query_emb = self.clip_query_tower(query)

            # (n_layers, b*t, p, c) -> (n_layers, b, t, p, c)
            n, _, p, c = video_emb.size()
            video_emb = video_emb.view(n, b, t, p, c)
        else:
            # Pre-extracted per-layer CLIP features; padding is non-finite.
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query[:, :, 0].isfinite(), 1, 0)

            video[~video.isfinite()] = 0
            query[~query.isfinite()] = 0

            # Recover the (layers, batch, time, tokens, dim) layout from the
            # flattened feature files using the arch config.
            (b, t), l = video.size()[:2], query.size(1)
            video = video.view(b, t, -1, self.cfg[2], self.cfg[0]).permute(2, 0, 1, 3, 4)
            query = query.view(b, l, -1, self.cfg[1]).permute(2, 0, 1, 3)

            video_emb = video.float()
            query_emb = query.float()

        # Trainable adapter over the frozen per-layer features; coll_* are
        # intermediate collections used by the training losses.
        video_emb, query_emb, coll_v, coll_q = self.adapter(video_emb, query_emb,
                                                            video_msk, query_msk)

        # Multi-scale temporal pyramid; masks are only needed for the loss.
        pymid, pymid_msk = self.pyramid(video_emb, video_msk, return_mask=mode != 'test')
        point = self.generator(pymid)

        # NOTE(review): heads, loss and post-processing are all kept under a
        # disabled-autocast scope (fp32) — confirm against the reference
        # implementation; the original indentation was unrecoverable here.
        with torch.autocast('cuda', enabled=False):
            video_emb = video_emb.float()
            query_emb = self.pooling(query_emb.float(), query_msk)

            out_class = [self.class_head(e.float()) for e in pymid]
            out_class = torch.cat(out_class, dim=1)

            if self.coord_head is not None:
                # exp() keeps offsets positive; self.coef rescales per level.
                out_coord = [
                    self.coord_head(e.float()).exp() * self.coef[i]
                    for i, e in enumerate(pymid)
                ]
                out_coord = torch.cat(out_coord, dim=1)
            else:
                out_coord = None

            output = dict(_avg_factor=b)

            if mode != 'test':
                data['coll_v'] = [e.float() for e in coll_v]
                data['coll_q'] = [self.pooling(e.float(), query_msk) for e in coll_q]

                data['point'] = point
                data['video_emb'] = video_emb
                data['query_emb'] = query_emb
                data['video_msk'] = video_msk
                data['pymid_msk'] = pymid_msk
                data['out_class'] = out_class
                data['out_coord'] = out_coord

                output = self.loss(data, output)

            if mode != 'train':
                assert b == 1, 'batch size larger than 1 is not supported for inference'
                out_class = out_class.sigmoid()
                out_score = F.cosine_similarity(video_emb, query_emb, dim=-1)

                output['_out'] = dict(label=data.get('label', [None])[0])

                # Split per-level class scores and resample each level back
                # to the original t frames.
                pyd_shape = [e.size(1) for e in pymid]
                pyd_class = out_class[0, :, 0].split(pyd_shape)

                saliency = []
                for shape, score in zip(pyd_shape, pyd_class):
                    if t >= shape:
                        # Upsample by repetition, padding the tail with the
                        # last score to reach exactly t entries.
                        score = score.repeat_interleave(int(t / shape))
                        postfix = score[-1:].repeat(t - score.size(0))
                        score = torch.cat((score, postfix))
                    else:
                        # Downsample by max-pooling.
                        scale = int(shape / t)
                        score = F.max_pool1d(score.unsqueeze(0), scale, stride=scale)[0]
                    saliency.append(score)

                # Element-wise max across pyramid levels.
                saliency = torch.stack(saliency).amax(dim=0)

                if self.merge_cls_sal:
                    saliency *= out_score[0]

                output['_out']['saliency'] = saliency

                if self.coord_head is not None:
                    # Convert (left, right) offsets around each point into
                    # absolute (start, end) times in seconds.
                    boundary = out_coord[0]
                    boundary[:, 0] *= -1
                    boundary *= point[:, 3, None].repeat(1, 2)
                    boundary += point[:, 0, None].repeat(1, 2)
                    boundary /= data['fps'][0]
                    # Append the (sigmoided) confidence as a third column.
                    boundary = torch.cat((boundary, out_class[0]), dim=-1)

                    # Keep the top-scoring moments only.
                    _, inds = out_class[0, :, 0].sort(descending=True)
                    boundary = boundary[inds[:self.max_num_moment]]

                    output['_out']['boundary'] = boundary

        return output
|
|