| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from typing import Tuple |
|
|
def compute_img_bkg_seg(
    attentions,
    feats,
    featmap_dims,
    th_bkg,
    dim=64,
    epsilon: float = 1e-10,
    apply_weights: bool = True,
) -> torch.Tensor:
    """Build a binary background mask from attention maps and patch features.

    For each image, the patch with the lowest attention from query token 0
    (summed over dim 1 of ``attentions``, optionally weighted per head) is
    taken as the background reference. Every patch whose descriptor has
    cosine similarity strictly greater than ``th_bkg`` with that reference
    is marked as background.

    Args:
        attentions: 4-D tensor ``[B, nh, Q, K]``. Only query row 0 and keys
            ``1:`` are used, so ``K - 1`` must equal
            ``w_featmap * h_featmap``. (Token 0 is presumably a CLS token;
            confirm with the caller.)
        feats: Token features whose first token (index 0 along dim 1) is
            dropped. The remaining tokens must reshape to
            ``[B, w_featmap * h_featmap, nh, dim]``, e.g. an input of shape
            ``[B, 1 + w*h, nh * dim]``.
        featmap_dims: ``(w_featmap, h_featmap)``, the size of the patch grid.
        th_bkg: Cosine-similarity threshold; values strictly above it count
            as background.
        dim: Number of feature channels per head in ``feats``.
        epsilon: Added to each head's ratio before the log, so a head with
            no above-mean patch does not cause a division by zero.
        apply_weights: If True, scale each head's attention and feature
            channels by ``beta = log(sum_k Q_k / Q_h)``. Heads that exceed
            the mean attention on fewer patches get larger weights.

    Returns:
        Float tensor ``[B, w_featmap, h_featmap]`` holding 1.0 for background
        patches and 0.0 elsewhere.

    Raises:
        ValueError: If ``feats`` does not yield exactly one ``nh * dim``
            descriptor per patch of the ``featmap_dims`` grid.
    """
    w_featmap, h_featmap = featmap_dims
    n_patches = w_featmap * h_featmap
    nb, nh = attentions.shape[:2]

    # Attention from query token 0 to each patch token: [B, nh, n_patches].
    att = attentions[:, :, 0, 1:].reshape(nb, nh, n_patches)

    # Descriptors without token 0, split by head: [B, n_patches, nh, dim].
    descs = feats[:, 1:].reshape(nb, -1, nh, dim)
    if descs.shape[1] != n_patches:
        # A wrong `dim` or token count would otherwise be folded silently
        # into the token axis and yield a meaningless mask.
        raise ValueError(
            f"feats provides {descs.shape[1]} descriptors of size {nh}x{dim}, "
            f"expected {n_patches} (featmap_dims={tuple(featmap_dims)})"
        )

    if apply_weights:
        # Per-image threshold: mean attention over all heads and patches.
        threshold = att.reshape(nb, -1).mean(dim=1)
        # q[b, h]: fraction of patches where head h exceeds the threshold.
        q = (att > threshold[:, None, None]).sum(dim=2) / n_patches + epsilon
        # beta[b, h] = log(sum_k q[b, k] / q[b, h]); smaller q -> larger weight.
        beta = torch.log(q.sum(dim=1, keepdim=True) / q)
        descs = descs * beta[:, None, :, None]
        att = att * beta[:, :, None]

    # Pairwise cosine similarity between patch descriptors: [B, N, N].
    descs = F.normalize(descs.reshape(nb, n_patches, nh * dim), dim=-1, p=2)
    cos_sim = torch.bmm(descs, descs.transpose(1, 2))

    # Reference patch per image: lowest (head-summed) attention.
    id_pixel_ref = att.sum(dim=1).argmin(dim=-1)

    # Similarity of every patch to its image's reference patch. The batch
    # index is created on cos_sim's device to avoid mixing CPU and device
    # index tensors.
    batch_idx = torch.arange(nb, device=cos_sim.device)
    ref_sim = cos_sim[batch_idx, id_pixel_ref].reshape(nb, w_featmap, h_featmap)

    return (ref_sim > th_bkg).float()