| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| from .modules import ConvBatchNormReLU, SFA |
| from .modules import * |
| from .position_encoding import * |
|
|
| import clip |
| import math |
| import sys |
|
|
| sys.path.append('../') |
| from utils.utils import * |
|
|
|
|
class Simple_fusion(nn.Module):
    """Fuse a visual feature map with a sentence-level language vector.

    The visual map is projected to ``proj_dim`` with a 1x1 conv, the
    language vector is projected to the same width, broadcast over the
    spatial grid, and the two are merged by element-wise product followed
    by BatchNorm + LeakyReLU.
    """

    def __init__(self, visual_dim=1024, text_dim=768, proj_dim=1024, jemb_drop_out=0.1, leaky=True):
        super(Simple_fusion, self).__init__()
        self.proj_dim = proj_dim
        # 1x1 conv projection of the visual feature map.
        self.mapping_visu = ConvBatchNormReLU(visual_dim, proj_dim, 1, 1, 0, 1, leaky=leaky)
        # Word-attention branch; registered (affects parameters/state dict)
        # but not used in forward(), where the attention path is disabled.
        self.lang_attn = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.Tanh(),
            nn.Dropout(jemb_drop_out),
            nn.Softmax(dim=1))
        # Projects the sentence vector into the joint embedding space.
        self.lang_proj = nn.Sequential(
            nn.Linear(text_dim, proj_dim),
            nn.BatchNorm1d(proj_dim),
            nn.LeakyReLU(0.1))
        # Post-fusion normalisation + activation.
        self.fusion = nn.Sequential(
            nn.BatchNorm2d(proj_dim),
            nn.LeakyReLU(0.1))

    def forward(self, visual_feat, lang_feat):
        """Return the fused (B, proj_dim, H, W) feature map."""
        projected_visual = self.mapping_visu(visual_feat)

        # Drop a possible singleton token dimension, then project the
        # sentence feature to proj_dim.
        sentence = self.lang_proj(lang_feat.squeeze(1))

        # Tile the sentence vector across the spatial grid of the visual map.
        h, w = visual_feat.shape[-2], visual_feat.shape[-1]
        tiled = sentence.view(-1, self.proj_dim, 1, 1).repeat(1, 1, h, w)

        return self.fusion(tiled * projected_visual)
|
|
class up_proj_cat_proj(nn.Module):
    """Upsample x by 2x (nearest), project y, concatenate, project again."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(up_proj_cat_proj, self).__init__()
        # 1x1 conv applied to the skip branch y.
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        # 1x1 conv that merges the concatenated channels down to `do`.
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        """Fuse a coarse map x with a 2x-larger skip map y."""
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        merged = torch.cat([upsampled, self.proj1(y)], dim=1)
        return self.proj2(merged)
|
|
class pool_proj_cat_proj(nn.Module):
    """Average-pool y by 2x, refine it, then concatenate with x and project."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(pool_proj_cat_proj, self).__init__()
        self.downsample = nn.AvgPool2d(2, 2)
        # Channel reduction followed by a 3x3 refinement conv.
        self.proj1 = ConvBatchNormReLU(input_2, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)
        # Merge concatenated channels down to `do`.
        self.proj3 = ConvBatchNormReLU(input_1 + do, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        """Fuse a fine map y (2x the resolution of x) into x."""
        pooled = self.proj2(self.proj1(self.downsample(y)))
        return self.proj3(torch.cat([x, pooled], dim=1))
|
|
class proj_cat_proj(nn.Module):
    """Project y, concatenate with x along channels, project the result."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat_proj, self).__init__()
        # 1x1 conv on the second input.
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        # 1x1 conv mapping the concatenation to `do` channels.
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        """Fuse two same-resolution feature maps."""
        fused = torch.cat([x, self.proj1(y)], dim=1)
        return self.proj2(fused)
|
|
class proj_cat(nn.Module):
    """Refine x with two convs, then concatenate with y along channels."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat, self).__init__()
        # Channel reduction then 3x3 refinement.
        self.proj1 = ConvBatchNormReLU(input_1, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)

    def forward(self, x, y):
        """Return cat([refined(x), y]) — note: no final projection here."""
        refined = self.proj2(self.proj1(x))
        return torch.cat([refined, y], dim=1)
|
|
class mask_decoder(nn.Module):
    """Decode a fused feature map into 32 mask-prototype channels.

    The decoder applies two 3x3 refinement convs, then up to three
    (2x nearest upsample + 3x3 conv) stages depending on the requested
    output stride: ``seg_out_stride`` 8 -> one stage, 4 -> two, 2 -> three.
    A final plain Conv2d maps to the 32 prototype channels.

    Args:
        input_1: number of input channels.
        seg_out_stride: default output stride used when forward() does not
            receive an explicit value.
        leaky: use LeakyReLU inside ConvBatchNormReLU blocks.
    """

    def __init__(self, input_1, seg_out_stride=2, leaky=True):
        super(mask_decoder, self).__init__()
        # Fix: this constructor argument used to be accepted but silently
        # ignored; it now provides the default for forward().
        self.seg_out_stride = seg_out_stride

        self.proj1 = ConvBatchNormReLU(input_1, input_1//2, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky)

        # One refinement conv per upsampling stage.
        self.proj3 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky)
        self.proj4 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky)
        self.proj5 = ConvBatchNormReLU(input_1//2, input_1//2, 3, 1, 1, 1, leaky=leaky)

        # Final projection to 32 prototype channels (no BN / activation).
        self.proj = nn.Conv2d(input_1//2, 32, 3, 1, 1, 1)

    def forward(self, x, seg_out_stride=None):
        """Decode x; ``seg_out_stride`` overrides the constructor default.

        Passing it positionally (as existing callers do) still works.
        """
        if seg_out_stride is None:
            seg_out_stride = self.seg_out_stride

        x = self.proj1(x)
        x = self.proj2(x)

        if seg_out_stride <= 8:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj3(x)

        if seg_out_stride <= 4:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj4(x)

        if seg_out_stride <= 2:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj5(x)

        return self.proj(x)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
class QuickGELU(nn.Module):
    """GELU approximation used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(x.mul(1.702))
        return x.mul(gate)
|
|
class ResidualAttentionblk(nn.Module):
    """Wrapper around one CLIP transformer residual block.

    Without ``lang_tokens`` it behaves exactly like the wrapped CLIP block
    (attention residual + MLP residual). With ``lang_tokens`` it runs
    language-guided token selection: the patch tokens most similar to the
    language tokens attend individually, while the remaining tokens are
    collapsed into a single score-weighted aggregate token before attention.
    """

    def __init__(self, clip_module):
        super().__init__()

        # The wrapped CLIP ResidualAttentionBlock; provides ln_1, attention,
        # ln_2 and mlp.
        self.clip_module = clip_module

        # Number of patch tokens kept for full attention: 80% of 676 patches
        # (676 = 26*26; assumes a 26x26 patch grid -- TODO confirm).
        self.selected_tokens = int(676 * 0.8)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, lang_tokens=None, index=0):
        # x: (tokens, batch, channels) = (1 cls + patches, B, C).
        # NOTE(review): `attn_mask` and `index` are accepted but unused here.

        if lang_tokens is None:
            # Plain CLIP residual attention path.
            x = x + self.clip_module.attention(self.clip_module.ln_1(x))
        else:

            N, B, C = x.shape
            # Split off the class token; x keeps only the patch tokens.
            cls_x = x[:1, :, :]
            x = x[1:, :, :]

            # Patch/language similarity: mean dot product of each patch token
            # against all language tokens -> (patches, B).
            score = torch.bmm(x.transpose(0, 1), lang_tokens.permute(1, 2, 0)).mean(dim=-1)
            score = score.transpose(0, 1)

            sorted_scores, sorted_indices = torch.sort(score, descending=True, dim=0)

            # Boolean mask marking the top-`selected_tokens` patches per sample.
            high_mask = torch.ones_like(sorted_scores)
            for i in range(B):
                high_mask[sorted_indices[self.selected_tokens:, i], i] = 0
            high_mask = high_mask > 0.5

            # NOTE(review): boolean-mask indexing flattens in (patch, batch)
            # order, so reshape(-1, B, C) only regroups correctly if the
            # selected positions line up across the batch — verify for B > 1.
            delta_x = x[high_mask].reshape(-1, B, C)
            low_x = x[~high_mask].reshape(-1, B, C)
            low_score = score[~high_mask].reshape(-1, B, 1)

            # Collapse the low-score patches into one token via a
            # softmax-over-scores weighted sum.
            low_x = low_x * torch.softmax(low_score, dim=0)
            low_x = low_x.sum(dim=0, keepdim=True)

            # Attention over [cls, selected patches, aggregated low token].
            delta_x = torch.cat([cls_x, delta_x, low_x], dim=0)
            delta_x = self.clip_module.attention(self.clip_module.ln_1(delta_x))

            # Scatter the attention updates back to their original positions;
            # every dropped patch receives the same aggregate-token update.
            temple = torch.zeros_like(x).type(delta_x.type())
            temple[high_mask] = delta_x[1:-1, :, :].reshape(-1, C)
            temple[~high_mask] = delta_x[-1:, :, :].reshape(-1, 1, C).repeat(1, 676 - self.selected_tokens, 1).reshape(-1, C)
            x = x + temple
            cls_x = cls_x + delta_x[:1, :, :]

            # Re-attach the class token.
            x = torch.cat([cls_x, x], dim=0)

        # Standard transformer MLP residual (applied in both branches).
        x = x + self.clip_module.mlp(self.clip_module.ln_2(x))
        return x
|
|
class Model(nn.Module):
    """CLIP-based referring image segmentation model.

    The CLIP visual transformer (each resblock wrapped in
    ResidualAttentionblk) and the CLIP text transformer run in lockstep.
    Per-layer vision/language agreement scores pick one intermediate visual
    layer from blocks 3-5 and one from 6-8; these, plus the final layer
    fused with the sentence (EOS) feature, go through an SFA neck and a
    transformer encoder, and are decoded into 32 mask prototypes combined
    by language-derived coefficients into a single-channel mask.
    """

    def __init__(self, clip_model='RN50', tunelang=False, fusion_dim=2048, num_query=16, do=512, leaky=True, length=17):
        # NOTE(review): `num_query` is accepted but unused in this body.
        super(Model, self).__init__()

        # If False, text features are detached in forward() (no gradients
        # flow into the language branch output).
        self.tunelang = tunelang
        # Text sequence length used for the positional embedding slice and
        # the causal attention mask.
        self.length = length

        # Load CLIP on CPU then move to GPU; index [0] is the model (the
        # second return value is the preprocess transform).
        clip_models = clip.load(clip_model, jit=False, device=torch.device("cpu"))[0].cuda()

        self.visumodel = clip_models.visual
        # Transformer width of the visual branch (assumes a ViT backbone --
        # TODO confirm; the default `clip_model` string says 'RN50').
        self.visu_dim = 768

        self.cut_list = []
        # Wrap each of the 12 CLIP visual resblocks.
        self.visu_resblocks = nn.ModuleList([ResidualAttentionblk(self.visumodel.transformer.resblocks[i]) for i in range(12)])
        # Empty while cut_list is empty; kept as-is.
        self.visu_proj = nn.ModuleList([nn.Linear(do, self.visu_dim) for _ in range(len(self.cut_list))])

        # Resize CLIP's grid positional embedding to a 26x26 patch grid
        # (1 cls token + 676 patches); the uninitialized parameter is
        # immediately overwritten by the resized values.
        self.positional_embedding = nn.Parameter(torch.FloatTensor(1, 26 ** 2 + 1, 768))
        v = self.resize_pos_embed(self.visumodel.positional_embedding.data.unsqueeze(0), self.positional_embedding, 26, 26)
        self.positional_embedding.data.copy_(v)

        # Text branch pieces taken from CLIP; the positional embedding is
        # truncated to `length` tokens and made trainable.
        self.textmodel = clip_models.transformer
        self.textmodel_token_embedding = clip_models.token_embedding
        self.textmodel_pos_embed = nn.Parameter(clip_models.positional_embedding[:self.length, :].unsqueeze(0))
        self.textmodel_ln_final = clip_models.ln_final
        self.textdim = self.textmodel_pos_embed.shape[-1]
        # Re-install a causal mask sized for the truncated sequence length.
        for module in self.textmodel.resblocks:
            module.attn_mask = self.build_attention_mask()

        # Projects the visual cls token before scoring it against the text
        # EOS token in the layer-selection loop.
        self.vis_select = nn.Linear(self.visu_dim, do, bias=False)

        # Vision/language fusion at the deepest visual layer.
        self.fusion = Simple_fusion(visual_dim=self.visu_dim, text_dim=self.textdim, proj_dim=fusion_dim)

        # NOTE(review): the four modules below are constructed but not used
        # in forward(); presumably kept for checkpoint compatibility.
        self.up_proj_cat_proj_1 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=fusion_dim)
        self.pool_proj_cat_proj_2 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=do)

        self.proj_cat = proj_cat(input_1=fusion_dim, input_2=do, do=do)
        self.up_proj_cat_2 = proj_cat_proj(input_1=fusion_dim, input_2=do * 2, do=do)
        self.proj_0 = ConvBatchNormReLU(do, do, 1, 1, 0, 1, leaky=leaky)

        # Multi-level feature aggregation neck.
        self.fpn = SFA(in_channels=self.visu_dim, out_channels=do)

        # Post-neck transformer encoder over the flattened feature map.
        f_dim = 512
        self.fc_2 = nn.Linear(f_dim, f_dim, bias=False)
        self.norm1 = nn.LayerNorm(f_dim)
        self.norm2 = nn.LayerNorm(f_dim)

        self.pos_embedding = PositionEmbeddingSine(f_dim)
        encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim,
                            dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim))

        # Decodes encoder output into 32 mask prototype channels.
        self.mask_decoder = mask_decoder(f_dim, seg_out_stride=2)

        # Language-conditioned coefficient head: cross-attention then convs
        # down to 32 per-prototype mixing coefficients.
        self.lang_tf_enc = lang_tf_enc(do, do, do, head_num=8)
        self.proj1 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = nn.Conv2d(do, 32, 3, 1, 1, 1)
        self.projout = nn.Linear(26*26*32, 32, bias=False)

        # Learned per-layer selectors over the stacked agreement scores.
        self.feature_selector_l = nn.Linear(do, 1, bias=True)
        self.feature_selector_m = nn.Linear(do, 1, bias=True)

    def resize_pos_embed(self, posemb, posemb_new, hight, width):
        """Bilinearly resize a ViT positional embedding grid.

        Keeps the class-token embedding unchanged and interpolates the grid
        part from its original (sqrt(N) x sqrt(N)) layout to hight x width.
        Returns a (1, hight*width + 1, dim) tensor.
        """
        ntok_new = posemb_new.shape[1]

        # Split off the class-token embedding from the grid embeddings.
        posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1

        gs_old = int(math.sqrt(len(posemb_grid)))
        print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
        # (N, dim) -> (1, dim, gs_old, gs_old) so interpolate can resize it.
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
        posemb = torch.cat([posemb_token, posemb_grid], dim=1)
        return posemb

    def build_attention_mask(self):
        """Return a (length, length) causal mask: -inf above the diagonal.

        Matches CLIP's text-transformer masking, sized for self.length.
        """
        mask = torch.empty(self.length, self.length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, image, word_id, word_mask):
        """Predict a segmentation mask for `image` given token ids `word_id`.

        Returns a (B, 1, 208, 208) mask tensor.
        NOTE(review): `word_mask` is accepted but never used in this body.
        """
        batch_size = image.size(0)

        # --- Visual stem: patchify, prepend cls token, add pos-embed ---
        x = self.visumodel.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([self.visumodel.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.visumodel.ln_pre(x)
        # (B, N, C) -> (N, B, C) for the transformer blocks.
        x = x.permute(1, 0, 2)

        # --- Text stem: token embedding + truncated positional embedding ---
        raw_fword = self.textmodel_token_embedding(word_id).squeeze(1)
        raw_fword = raw_fword + self.textmodel_pos_embed
        raw_fword = raw_fword.permute(1, 0, 2)

        visu_list_l = []
        visu_list_m = []

        scores_l = []
        scores_m = []

        # Run both transformers layer-by-layer; after each layer compute an
        # agreement score between the visual cls token and the text EOS token
        # (argmax over token ids locates EOS — CLIP convention, presumably).
        for i, [blk_visu, blk_lang] in enumerate(zip(self.visu_resblocks, self.textmodel.resblocks)):
            # lang_tokens is not passed, so the plain attention path runs.
            x = blk_visu(x)
            raw_fword = blk_lang(raw_fword)

            img_cls = self.vis_select(x[0, :, :])
            tex_cls = raw_fword[word_id.argmax(dim=-1).reshape(-1), torch.arange(raw_fword.shape[1]), :]
            score = img_cls * tex_cls
            score = score.unsqueeze(1)

            # Collect candidate layers 3-5 ("low") and 6-8 ("mid").
            if i >=3 and i <= 5:
                visu_list_l.append(x)
                scores_l.append(score)

            if i>=6 and i <=8:
                visu_list_m.append(x)
                scores_m.append(score)

        # Stack per-layer scores: (B, 3, do).
        scores_l = torch.cat(scores_l, dim=1)
        scores_m = torch.cat(scores_m, dim=1)

        # Reduce each score vector to a scalar and softmax over the 3 layers.
        scores_l = self.feature_selector_l(scores_l).squeeze(-1)
        scores_l = F.softmax(scores_l, dim=-1)
        scores_m = self.feature_selector_m(scores_m).squeeze(-1)
        scores_m = F.softmax(scores_m, dim=-1)

        # (3, N, B, C) -> (3, B, N, C) for per-sample layer selection.
        visu_list_l = torch.cat(visu_list_l, dim=0).reshape(len(visu_list_l), -1, batch_size, self.visu_dim).permute(0,2,1,3)
        visu_list_m = torch.cat(visu_list_m, dim=0).reshape(len(visu_list_m), -1, batch_size, self.visu_dim).permute(0,2,1,3)

        # Pick, per sample, the layer with the highest selector score.
        x6 = visu_list_l[scores_l.argmax(dim=-1).reshape(-1), torch.arange(visu_list_l.shape[1]), :, :].permute(1,0,2)
        x9 = visu_list_m[scores_m.argmax(dim=-1).reshape(-1), torch.arange(visu_list_m.shape[1]), :, :].permute(1,0,2)

        # Drop the cls token and fold the 676 patches back into a 26x26 map.
        # NOTE(review): this permute undoes the one directly above (net no-op).
        x6 = x6.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x9 = x9.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x12 = x.permute(1, 0, 2)[:, 1:, :]
        x12 = x12.reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)

        # Final text features: back to (B, seq, C) and layer-normed.
        raw_fword = raw_fword.permute(1, 0, 2)
        raw_fword = self.textmodel_ln_final(raw_fword)

        # Optionally freeze gradients into the language branch.
        if not self.tunelang:
            raw_fword = raw_fword.detach()

        # Sentence-level feature at the EOS position.
        eos_token = raw_fword[torch.arange(raw_fword.shape[0]), word_id.argmax(dim=-1).reshape(-1), :]

        # Fuse deepest visual layer with the sentence feature, then
        # aggregate the three levels through the SFA neck.
        F_g = self.fusion(x12, eos_token)
        F_tf = self.fpn([F_g, x9, x6])

        b, c, h, w = F_tf.shape

        # Flatten spatial dims for the transformer encoder.
        flatten_length = h*w
        visu_feat = F_tf.reshape(b, c, flatten_length)
        visu_feat = F.relu(visu_feat)
        lang_feat = F.relu(self.fc_2(raw_fword))

        # (B, HW, C) with sine positional encoding; encoder expects (HW, B, C).
        visu_feat = visu_feat.permute(0, 2, 1)
        pos_embed = self.pos_embedding(visu_feat)
        visu_feat = visu_feat.transpose(0, 1)
        pos_embed = pos_embed.transpose(0, 1)
        visu_feat = self.encoder(visu_feat, pos=pos_embed)

        # Batch-first copy for the coefficient (cross-attention) branch.
        visu_feat_ = visu_feat.permute(1,0,2)

        # Back to (B, C, H, W) and decode into 32 prototype masks
        # (seg_out_stride=2 -> 8x upsample: 26 -> 208).
        visu_feat = visu_feat.reshape(h, w, b, c)
        visu_feat = visu_feat.permute(2,3,0,1)
        proto_masks = self.mask_decoder(visu_feat, 2)

        proto_masks = F.relu(proto_masks)

        # Language-conditioned mixing coefficients over the prototypes.
        coef = self.lang_tf_enc(visu_feat_, lang_feat)
        coef = coef.view(b, h, w, c)
        coef = coef.permute(0, 3, 1, 2)

        coef = self.proj1(coef)
        coef = self.proj2(coef)
        coef = self.proj3(coef)
        coef = coef.permute(0, 2, 3, 1)
        coef = coef.contiguous().view(b, h*w*32)

        # One coefficient per prototype, squashed to [-1, 1].
        # NOTE(review): F.tanh is deprecated; torch.tanh is the modern form.
        coef = self.projout(coef).unsqueeze(-1)
        coef = F.tanh(coef)

        # (B, 208*208, 32) prototype matrix.
        proto_masks = proto_masks.permute(0, 2, 3, 1)
        proto_masks = proto_masks.view(b, -1, 32)

        # Linear combination of prototypes -> (B, 1, 208, 208) mask.
        # (`out=None` is the default and redundant here.)
        mask_out = torch.bmm(proto_masks, coef, out=None)
        mask_out = mask_out.view(b, 208, 208, 1)
        mask_out = mask_out.permute(0, 3, 1, 2)
        return mask_out
| |