| import math |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class KeyPhraseAlignmentLoss(nn.Module):
    """Contrastive alignment between per-image key-phrase text features and
    vision tokens.

    Each image in the batch carries a variable-length list of key phrases.
    Every phrase is encoded by the text model, scored against every image via
    ``SimilarityLogit`` (attention pooling over that image's tokens), and a
    multi-positive NCE loss pulls each phrase toward its source image.
    """

    def __init__(
        self,
        hidden_dim=768,
        use_vision_cls_token=True,
        attn_temperature=None,
        loss_temperature=0.07,
        text_features_l2_norm=False,
        mpnce_row_sum=False,
        mpnce_col_sum=False,
        sim_op="cos",
        use_layer_norm=True,
        **kwargs,
    ):
        """
        Args:
            hidden_dim: common feature width of text and vision tokens.
            use_vision_cls_token: if True, attend over all vision tokens
                (including index 0); otherwise drop index 0 and use only the
                remaining (patch) tokens.
            attn_temperature: optional initial value of a separate learnable
                temperature for the attention inside ``SimilarityLogit``;
                when None the loss temperature is reused.
            loss_temperature: initial value of the learnable NCE temperature.
                Both temperatures are stored as their log so that ``exp()``
                at the use sites keeps the effective value positive.
            text_features_l2_norm: pick the L2-normalized text features from
                the text-model output dict instead of the raw ones.
            mpnce_row_sum: row-pooling variant of the NCE loss
                (see ``multi_positive_nce_loss``).
            mpnce_col_sum: column-pooling variant of the NCE loss.
            sim_op: similarity operator, "cos" or "dot".
            use_layer_norm: apply a LayerNorm to both modalities.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        # NOTE: this single LayerNorm instance is shared — it is applied to
        # the vision tokens in forward() and to the text features in
        # compute_text_features().
        self.layer_norm = nn.LayerNorm(hidden_dim) if use_layer_norm else None

        self.use_vision_cls_token = use_vision_cls_token
        # Learnable log-temperature; consumed as .exp() at the use sites.
        self.loss_temperature = nn.Parameter(
            torch.FloatTensor([np.log(loss_temperature)])
        )
        if attn_temperature is not None:
            self.attn_temperature = nn.Parameter(
                torch.FloatTensor([np.log(attn_temperature)])
            )
        else:
            self.attn_temperature = None
        self.text_features_l2_norm = text_features_l2_norm
        self.sim_op = sim_op

        self.similarity_logit = SimilarityLogit(sim_op)

        self.mpnce_row_sum = mpnce_row_sum
        self.mpnce_col_sum = mpnce_col_sum

    def forward(
        self,
        key_phrases,
        vision_tokens,
        forward_text_model,
        ddp_gather=True,
        need_attn_weights=False,
        compute_loss=True,
        **kwargs,
    ):
        """Score every key phrase against every image; optionally add losses.

        Args:
            key_phrases: sequence (length = local batch size) whose i-th entry
                holds the key phrases of image i, in whatever form
                ``forward_text_model`` accepts.
            vision_tokens: (B, T, D) vision tokens; index 0 along T is treated
                as the CLS token and is dropped when ``use_vision_cls_token``
                is False — assumes the vision encoder prepends a CLS token;
                TODO confirm against the encoder.
            forward_text_model: callable encoding one image's phrases into a
                dict containing "text_features" and
                "text_features_wo_l2_norm".
            ddp_gather: gather text features and vision tokens across all DDP
                ranks so the contrastive loss sees the global batch.
            need_attn_weights: also return the attention scores from the
                similarity module.
            compute_loss: if True, compute the multi-positive NCE loss.

        Returns:
            dict with "t2i_logits" (phrases x images), "t2i_attn_weights",
            and, when ``compute_loss``, a "losses" sub-dict holding
            "t2i_loss" and the total "loss".
        """
        outputs = {}

        text_features, group_map = self.compute_text_features(
            key_phrases, forward_text_model, ddp_gather
        )

        if ddp_gather and dist.is_initialized():
            # dist.nn.all_gather keeps the autograd graph across ranks
            # (NOTE(review): requires torch.distributed.nn to be importable
            # as dist.nn — confirm for the torch version in use).
            vision_tokens = torch.cat(dist.nn.all_gather(vision_tokens), dim=0)

        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        # Tokens without the (assumed) leading CLS token.
        vision_patch_tokens = vision_tokens[:, 1:]

        if not self.use_vision_cls_token:
            vision_attn_tokens = vision_patch_tokens
        else:
            vision_attn_tokens = vision_tokens

        t2i_logits, t2i_attn_weights_list = self.compute_t2i_logits(
            text_features, vision_attn_tokens, need_attn_weights
        )
        outputs["t2i_logits"] = t2i_logits
        outputs["t2i_attn_weights"] = t2i_attn_weights_list

        if compute_loss:
            losses = {}
            loss = 0

            # Phrase-to-image contrastive loss; group_map marks the image
            # (positive column) each phrase row belongs to.
            t2i_loss = multi_positive_nce_loss(
                t2i_logits,
                group_map,
                temperature=self.loss_temperature.exp(),
                row_sum=self.mpnce_row_sum,
                col_sum=self.mpnce_col_sum,
            )
            loss += t2i_loss
            losses["t2i_loss"] = t2i_loss

            losses["loss"] = loss
            outputs["losses"] = losses
        return outputs

    def compute_text_features(self, key_phrases, forward_text_model, ddp_gather=True):
        """Encode every key phrase and map each phrase to its source image.

        Returns:
            text_features: (N_phrases, hidden_dim) phrase features, gathered
                across ranks when ``ddp_gather`` and DDP is initialized.
            group_map: (N_phrases,) long tensor; entry j is the global image
                index that phrase j was extracted from.
        """
        key_text_features_list = list()
        group_list = list()

        B_local = len(key_phrases)
        # Rank offset so image indices are unique across DDP ranks.
        local_rank = dist.get_rank() if (ddp_gather and dist.is_initialized()) else 0

        for i, kp in enumerate(key_phrases):
            feats = forward_text_model(kp)

            # Select normalized vs. raw features from the text-model output.
            if self.text_features_l2_norm:
                feat = feats["text_features"]
            else:
                feat = feats["text_features_wo_l2_norm"]

            # If the text model emits a concatenation of two projections
            # (2 * hidden_dim wide), keep only the second half — presumably
            # the alignment head; verify against the text model.
            if feat.shape[-1] == 2 * self.hidden_dim:
                feat = feat[:, self.hidden_dim :]

            key_text_features_list.append(feat)

            # NOTE(review): assumes every rank processes the same local batch
            # size B_local; unequal batches would make global indices collide
            # — confirm with the data loader.
            global_index = i + local_rank * B_local
            group_list.extend([global_index] * feat.size(0))

        text_features = torch.cat(key_text_features_list, dim=0)
        group_map = torch.tensor(group_list, device=text_features.device)

        if ddp_gather and dist.is_initialized():
            # Phrase counts differ per rank: pad to the max, gather, trim.
            text_features = pad_and_gather(text_features)

            group_map = pad_and_gather(group_map)
            # pad_and_gather may not preserve integer dtype; restore it so
            # group_map can be used for indexing.
            group_map = group_map.long()

        if self.layer_norm is not None:
            text_features = self.layer_norm(text_features)

        return text_features, group_map

    def compute_t2i_logits(
        self, text_features, vision_attn_tokens, need_attn_weights, repeat=True
    ):
        """Compute phrase-to-image similarity logits via attention pooling.

        Uses the dedicated attention temperature when configured, otherwise
        falls back to the (exp of the) loss temperature.
        """
        t2i_logits, t2i_attn_weights_list = self.similarity_logit(
            text_features,
            vision_attn_tokens,
            need_attn_weights,
            repeat=repeat,
            temperature=(
                self.attn_temperature.exp()
                if self.attn_temperature is not None
                else self.loss_temperature.exp()
            ),
        )

        return t2i_logits, t2i_attn_weights_list
|
|
|
|
class SimilarityLogit(nn.Module):
    """Attention-pooled similarity between text queries and image token sets.

    For every (query, image) pair the query attends over that image's local
    tokens (softmax over scaled dot products); the attended tokens are pooled
    into one vector, and the logit is the cosine similarity between the
    L2-normalized query and the L2-normalized pooled vector.
    """

    def __init__(self, sim_op="dot", **kwargs):
        """
        Args:
            sim_op: "cos" — normalize queries/tokens before attention and use
                the supplied temperature as the score denominator; "dot" —
                raw dot products scaled by sqrt(feature_dim).
        """
        super().__init__()
        self.sim_op = sim_op

    def forward(
        self,
        queries: torch.Tensor,
        local_tokens: torch.Tensor,
        need_attn_weights: bool = False,
        repeat: bool = True,
        **kwargs,
    ):
        """
        Args:
            queries: (N, D) query features when ``repeat`` is True, otherwise
                an already-broadcast (B, N, D) tensor.
            local_tokens: (B, L, D) per-image token sequences.
            need_attn_weights: if True, also return the raw attention scores.
            repeat: broadcast 2-D ``queries`` across the image batch.
            temperature: required kwarg when ``sim_op == "cos"``.

        Returns:
            logits: (N, B) similarity matrix (query rows, image columns).
            attn_scores: [scores] with scores of shape (B, N, L), or None.
        """
        if repeat:
            # Broadcast every query against every image in the batch.
            query_attn_features = queries.unsqueeze(0).expand(
                local_tokens.shape[0], queries.shape[0], queries.shape[1]
            )
        else:
            assert queries.dim() == 3
            query_attn_features = queries

        if self.sim_op == "cos":
            temperature = kwargs.get("temperature")
            assert temperature is not None
            denominator = temperature
            query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
            local_tokens = F.normalize(local_tokens, p=2, dim=-1)
        elif self.sim_op == "dot":
            # Standard scaled-dot-product denominator.
            denominator = math.sqrt(local_tokens.size(-1))
        else:
            raise NotImplementedError

        # (B, N, L): attention of each query over each image's tokens.
        scores = (
            torch.bmm(query_attn_features, local_tokens.permute(0, 2, 1)) / denominator
        )
        attn_weights = F.softmax(scores, dim=-1)

        # (B, N, D): attention-pooled image representation per query.
        aggregated = torch.matmul(attn_weights, local_tokens)

        query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
        aggregated = F.normalize(aggregated, p=2, dim=-1)

        # Cosine logit per (query, image) pair via a batched (1, D) x (D, 1)
        # product. Bug fix: squeeze only the two trailing singleton dims the
        # unsqueeze/matmul introduced — the previous bare ``.squeeze()`` also
        # collapsed the batch or query dimension whenever either equalled 1,
        # changing the output rank (and making the ``.T`` below a 1-D no-op).
        logits = (
            torch.matmul(query_attn_features.unsqueeze(2), aggregated.unsqueeze(-1))
            .squeeze(-1)
            .squeeze(-1)
        )

        # (B, N) -> (N, B): rows are queries, columns are images.
        logits = logits.T

        if need_attn_weights:
            attn_scores = [scores]
        else:
            attn_scores = None

        return logits, attn_scores
|
|
|
|
def multi_positive_nce_loss(
    logits: torch.Tensor,
    group_map: torch.Tensor,
    temperature: float = 1.0,
    eps: float = 1e-8,
    row_sum: bool = False,
    col_sum: bool = False,
):
    """Symmetric multi-positive NCE loss over a phrase-by-image logit matrix.

    Args:
        logits: (N_total, B_global) tensor; entry (i, j) is the logit between
            key phrase i and candidate image j.
        group_map: (N_total,) tensor; group_map[i] is the index of the image
            that phrase i came from, i.e. its positive column.
        temperature: scaling applied to logits before exponentiation.
        eps: numerical stabilizer added to denominators and log arguments.
        row_sum: pool the row (phrase-side) loss per source image.
        col_sum: pool the column (image-side) positives into a single term.

    For each row the positive is column ``group_map[i]``; all other columns
    are negatives. For each column, each of its positives is treated
    independently (unless ``col_sum``).

    Returns:
        Scalar loss: mean of the row-wise and column-wise terms, averaged.
    """
    exp_logits = torch.exp(logits / temperature)

    row_idx = torch.arange(exp_logits.size(0))
    # Exponentiated logit of each phrase with its own image.
    positives = exp_logits[row_idx, group_map]

    # Ones everywhere except each row's positive column.
    negative_mask = torch.ones_like(exp_logits)
    negative_mask[row_idx, group_map] = 0

    per_row = get_row_loss(
        exp_logits,
        positives,
        group_map,
        eps,
        row_sum,
    )

    per_col = get_col_loss(
        exp_logits,
        positives,
        negative_mask,
        group_map,
        eps,
        col_sum,
    )

    # Symmetric average of the two directions.
    return (per_row.mean() + per_col.mean()) / 2
|
|
|
|
def get_row_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    row_sum: bool = False,
):
    """Phrase-side (row-wise) NCE term.

    Args:
        logits: (N, B) matrix of already-exponentiated logits.
        pos_logits: (N,) exponentiated logit of each row's positive column.
        group_map: (N,) source-image index per row; only used when
            ``row_sum`` is True.
        eps: numerical stabilizer.
        row_sum: pool numerator and denominator over rows that share a
            source image before normalizing.

    Returns:
        (N,) per-row losses, or (B,) per-group losses when ``row_sum``.
    """
    if not row_sum:
        # Each phrase row is its own softmax: positive over full row mass.
        denom = logits.sum(dim=1)
        p_row = pos_logits / (denom + eps)
        return -torch.log(p_row + eps)

    # Accumulate all rows of the same source image, then normalize once.
    num_groups = logits.shape[-1]
    pooled_total = torch.zeros(num_groups, device=logits.device)
    pooled_pos = torch.zeros(num_groups, device=logits.device)
    pooled_total.scatter_add_(0, group_map, logits.sum(dim=1))
    pooled_pos.scatter_add_(0, group_map, pos_logits)
    p_row = pooled_pos / (pooled_total + eps)
    return -torch.log(p_row + eps)
|
|
|
|
def get_col_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    neg_mask: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    col_sum: bool = False,
):
    """Image-side (column-wise) NCE term.

    Args:
        logits: (N, B) matrix of already-exponentiated logits.
        pos_logits: (N,) exponentiated logit of each row's positive column.
        neg_mask: (N, B) mask with 0 at each row's positive entry, 1 elsewhere.
        group_map: (N,) positive column index per row.
        eps: numerical stabilizer.
        col_sum: collapse all positives of a column into a single ratio
            instead of scoring each positive independently.

    Returns:
        (B,) per-column losses when ``col_sum``, else (N,) per-positive losses.
    """
    if col_sum:
        # One softmax per column: summed positives over total column mass.
        col_total = logits.sum(dim=0)
        pos_mask = torch.ones_like(logits) - neg_mask
        col_pos = (logits * pos_mask).sum(dim=0)
        p_column = col_pos / (col_total + eps)
    else:
        # Each positive competes only against its column's negatives.
        masked = logits * neg_mask
        col_neg = masked.sum(dim=0)[group_map]
        p_column = pos_logits / (pos_logits + col_neg + eps)

    return -torch.log(p_column + eps)
|
|
|
|
def pad_and_gather(tensor):
    """All-gather tensors whose first dimension differs across DDP ranks.

    Each rank pads its tensor along dim 0 up to the global maximum length,
    all-gathers the padded tensors (gradient-preserving via
    ``dist.nn.all_gather``), trims each gathered piece back to its true
    length, and concatenates along dim 0.

    Note: assumes all ranks agree on every dimension except dim 0, and that
    a process group is initialized.

    Args:
        tensor: local tensor; dim 0 may differ per rank.

    Returns:
        Concatenation of every rank's (un-padded) tensor along dim 0.
    """
    local_size = torch.tensor(tensor.size(), device=tensor.device)

    # Exchange per-rank sizes so each rank knows how much padding to strip.
    all_sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())]
    dist.all_gather(all_sizes, local_size)

    # Element-wise max over ranks; dims beyond 0 are assumed equal anyway.
    max_size = torch.stack(all_sizes).max(dim=0)[0]

    # Bug fix: preserve the input dtype. torch.zeros defaulted to float32,
    # which silently converted integer tensors (e.g. the phrase group map)
    # to float during the gather — lossy for values beyond 2**24 and the
    # reason callers had to cast back with .long().
    padded_tensor = torch.zeros(
        max_size.tolist(), dtype=tensor.dtype, device=tensor.device
    )
    padded_tensor[: local_size[0]] = tensor

    # Autograd-aware all_gather (gradients flow back to every rank).
    gathered_tensors = dist.nn.all_gather(padded_tensor)

    # Strip each rank's padding back to its true dim-0 length.
    gathered_tensors = [g[: s[0]] for g, s in zip(gathered_tensors, all_sizes)]

    return torch.cat(gathered_tensors, dim=0)
|
|