| """Contrastive loss used during SAM2 + fusion training (config from Hydra `contrastive_learning`, tmp.code style).""" |
| import torch |
| from abc import ABC |
| import torch.nn as nn |
|
|
|
|
class ContrastLoss(nn.Module, ABC):
    """Audio-visual InfoNCE-style contrastive loss.

    Per frame and per embedding layer, pixel embeddings are sampled class-wise
    (split into pixels where the prediction disagrees / agrees with the mask)
    and contrasted against that frame's audio embedding.

    Hyper-parameters are read from ``hyp_param.contrastive_learning`` (any
    mapping); missing keys fall back to the defaults in ``__init__``.
    ``hyp_param.image_size`` is additionally read by ``forward``.
    """

    # Number of embedding layers consumed per frame (indices 0, 1, 2).
    _NUM_LAYERS = 3

    def __init__(self, hyp_param):
        """Store ``hyp_param`` and resolve the contrastive-learning settings.

        Args:
            hyp_param: configuration object. Its optional
                ``contrastive_learning`` attribute may override any default.
        """
        super().__init__()
        self.param = hyp_param
        defaults = {
            "temperature": 0.10,
            "ignore_idx": 255,
            "ood_idx": 254,
            "max_views": 512,
            "proj_dim": 512,
            "sample_limits": 64,
            "total_limits": 15240,
        }
        # A missing / None / empty config section means "use the defaults".
        overrides = getattr(hyp_param, "contrastive_learning", None) or {}
        cfg = {**defaults, **overrides}
        self.temperature = cfg["temperature"]
        self.ignore_idx = cfg["ignore_idx"]
        self.ood_idx = cfg["ood_idx"]
        self.max_views = cfg["max_views"]
        self.proj_dim = cfg["proj_dim"]
        self.sample_limits = cfg["sample_limits"]
        self.total_limits = cfg["total_limits"]

    def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx):
        """Sample pixel embeddings per class and pair the audio embedding with labels.

        Args:
            embeddings: 2-D tensor, transposed with ``permute(1, 0)`` so that
                rows can be indexed by flattened pixel position
                (i.e. ``(channels, pixels)`` on input).
            audio_embeddings: audio embedding of the same frame / layer.
            predictions: flattened prediction, compared element-wise to ``masks``.
            masks: flattened ground-truth mask, one value per pixel.
            batch_idx: every label is offset by ``batch_idx * 1e3``.

        Returns:
            ``(visual_samples, visual_labels, audio_samples, audio_labels)``,
            four lists of tensors. For each class value in ``masks`` the visual
            lists receive two entries: the "hard" samples (mask != prediction)
            followed by the "easy" samples (mask == prediction). Entries may
            be empty.
        """
        label_offset = batch_idx * 1e3
        class_index_list = torch.unique(masks)

        # Audio side: one copy of the audio embedding per class except the
        # first (smallest) value returned by torch.unique; when only a single
        # class is present the audio embedding is labelled 0.
        embedding_sample_list_a = []
        label_list_a = []
        if len(class_index_list) > 1:
            for class_index in class_index_list[1:]:
                embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
                label_list_a.append(class_index.unsqueeze(0) + label_offset)
        else:
            embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
            label_list_a.append(torch.zeros([1], device=embeddings.device) + label_offset)

        sample_limits = self.sample_limits
        pixel_embeddings = embeddings.permute(1, 0)

        # NOTE(review): exact (in)equality against `masks` only separates
        # hard/easy pixels if `predictions` is already binarised; `forward`
        # passes a bilinearly resized tensor -- confirm upstream thresholding.
        disagree = masks != predictions
        agree = masks == predictions

        embedding_sample_list = []
        label_list = []
        for class_index in class_index_list:
            in_class = masks == class_index
            # Indexing with the (K, 1) output of nonzero() yields (K, 1, C).
            hard_samples = pixel_embeddings[(disagree & in_class).nonzero()]
            easy_samples = pixel_embeddings[(agree & in_class).nonzero()]
            hard_num, easy_num = hard_samples.shape[0], easy_samples.shape[0]

            take_hard = min(sample_limits, hard_num)
            take_easy = min(sample_limits, easy_num)
            # If one pool is short, raise the quota of the larger pool.
            # NOTE(review): `+=` overshoots 2 * sample_limits in total; kept
            # as-is because the slice below clips to the pool size anyway.
            if take_hard + take_easy < sample_limits * 2:
                if take_hard > take_easy:
                    take_hard += sample_limits * 2 - take_easy
                else:
                    take_easy += sample_limits * 2 - take_hard

            # Hard pool first, then easy pool (same RNG call order as before).
            for pool, pool_size, quota in ((hard_samples, hard_num, take_hard),
                                           (easy_samples, easy_num, take_easy)):
                chosen = torch.randperm(pool_size)[:quota]
                embedding_sample_list.append(pool[chosen])
                # Every sample in `pool` belongs to `class_index` by
                # construction. `chosen` indexes into the pool, NOT into
                # `masks`, so the label must come from the class value.
                label_list.append(masks.new_zeros(chosen.shape[0]) + class_index + label_offset)
        return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a

    def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions):
        """Collect samples over all frames / layers and compute the InfoNCE loss.

        Args:
            visual_embeddings: indexed as ``[frame][layer]`` after its last two
                dims are flattened; each ``[frame][layer]`` slice must be 2-D.
            audio_embeddings: indexed as ``[frame][layer]``.
            masks: ground-truth masks, first dim is the frame.
            predictions: predicted masks, first dim is the frame.

        Returns:
            The scalar loss tensor, or the float ``0.`` when nothing was sampled.
        """
        masks = masks.flatten(start_dim=1)
        predictions = predictions.flatten(start_dim=1)
        visual_embeddings = visual_embeddings.flatten(start_dim=-2)

        visual_embedding_sample_list = []
        visual_label_list = []
        audio_embedding_sample_list = []
        audio_label_list = []

        for frame_idx in range(masks.shape[0]):
            frame_vision = visual_embeddings[frame_idx]
            frame_audio = audio_embeddings[frame_idx]
            frame_masks = masks[frame_idx]
            frame_predictions = predictions[frame_idx]
            for layer_idx in range(self._NUM_LAYERS):
                # batch_idx is fixed to 0: labels are not offset per frame.
                v_samples, v_labels, a_samples, a_labels = self.select_class_wise_samples(
                    frame_vision[layer_idx],
                    frame_audio[layer_idx],
                    frame_predictions,
                    frame_masks,
                    0,
                )
                visual_embedding_sample_list += v_samples
                visual_label_list += v_labels
                audio_embedding_sample_list += a_samples
                audio_label_list += a_labels

        if len(visual_embedding_sample_list) == 0:
            return 0.

        # flatten(start_dim=1) collapses the singleton dims like the previous
        # bare squeeze() did, but never drops the sample axis when N == 1.
        visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).flatten(start_dim=1)
        visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1)
        audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).flatten(start_dim=1)
        audio_label_list = torch.cat(audio_label_list).unsqueeze(1)

        # Cap the number of visual anchors with a uniform random subset.
        total_limits = self.total_limits
        num_samples = visual_embedding_sample_list.shape[0]
        if num_samples > total_limits:
            keep = torch.randperm(num_samples)[:total_limits]
            visual_embedding_sample_list = visual_embedding_sample_list[keep]
            visual_label_list = visual_label_list[keep]

        return self.info_nce(visual_embedding_sample_list, visual_label_list,
                             audio_embedding_sample_list, audio_label_list)

    @staticmethod
    def _stack_layers(per_layer, num_frames):
        """Stack ``per_layer[k][i]`` for k in 0..2 into a ``(num_frames, 3, ...)`` tensor."""
        return torch.stack([
            torch.stack([per_layer[layer_idx][frame_idx]
                         for layer_idx in range(ContrastLoss._NUM_LAYERS)])
            for frame_idx in range(num_frames)
        ])

    def forward(self, embeddings, output_dicts, masks):
        """Resize predictions / masks to the embedding grid and compute the loss.

        Args:
            embeddings: pair ``(visual_embeddings, audio_embeddings)``; each is
                indexed as ``[layer][frame]`` for layers 0, 1, 2.
            output_dicts: iterable of dicts holding ``'multistep_pred_masks'``;
                the tensors are concatenated along dim 0.
            masks: ground-truth masks; a channel dim is inserted before resizing.

        Returns:
            Whatever ``forward_audio_visual`` returns.
        """
        side = int(self.param.image_size / 16)
        grid_size = (side, side)

        predictions = torch.cat([step['multistep_pred_masks'] for step in output_dicts])
        predictions = torch.nn.functional.interpolate(
            predictions, size=grid_size, mode='bilinear', align_corners=False).squeeze(1)
        masks = torch.nn.functional.interpolate(
            masks.unsqueeze(1), size=grid_size, mode='nearest').squeeze(1)

        visual_embeddings, audio_embeddings = embeddings
        num_frames = masks.shape[0]
        visual_embeddings = self._stack_layers(visual_embeddings, num_frames)
        audio_embeddings = self._stack_layers(audio_embeddings, num_frames)

        return self.forward_audio_visual(visual_embeddings, audio_embeddings.squeeze(-1),
                                         masks, predictions)

    @staticmethod
    def manipulate_cover_mask(a_label, current_mask):
        """Zero selected anchor-vs-anchor entries of ``current_mask`` (in place).

        With ``a = a_label + 1`` the outer product ``a @ a.T`` is formed and
        every position where it equals 1 or 4 is cleared in the top-left
        square of ``current_mask``.

        NOTE(review): for labels restricted to {0, 1} these are exactly the
        same-label pairs (1 = both 0, 4 = both 1); other label values are not
        recognised by this test -- confirm the intended label range.

        Args:
            a_label: ``(N, 1)`` anchor labels.
            current_mask: mask with at least ``N`` rows and columns; modified
                in place.

        Returns:
            The same ``current_mask`` tensor.
        """
        shifted = a_label + 1
        pair_product = torch.matmul(shifted, torch.transpose(shifted, 0, 1))
        # Slicing yields a view, so the masked assignment writes through.
        anchor_block = current_mask[:pair_product.shape[1], :pair_product.shape[0]]
        anchor_block[pair_product == 1.] = 0
        anchor_block[pair_product == 4.] = 0
        return current_mask

    def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
        """InfoNCE loss of anchors against the concatenation [anchors, contras].

        Args:
            anchors_: ``(N, C)`` anchor embeddings.
            a_labels_: ``(N, 1)`` anchor labels.
            contras_: ``(M, C)`` contrast embeddings.
            c_labels_: ``(M, 1)`` contrast labels.

        Returns:
            Scalar tensor: negated mean over anchors of the average
            log-probability of their positive pairs.

        Raises:
            AssertionError: if the per-anchor log-probability contains NaN.
        """
        # Anchors are also part of the contrast set (first N columns).
        c_labels_ = torch.cat([a_labels_, c_labels_])
        contras_ = torch.cat([anchors_, contras_])

        mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()

        anchor_dot_contrast = torch.div(
            torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), self.temperature)

        # Subtract the row-wise maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Negatives must be taken from the label-equality mask BEFORE the
        # positive mask is edited in place below.
        neg_mask = 1 - mask

        mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask)
        mask = mask.fill_diagonal_(0.)  # an anchor is never its own positive

        exp_logits = torch.exp(logits)
        neg_logits = (exp_logits * neg_mask).sum(1, keepdim=True)

        log_prob = logits - torch.log(exp_logits + neg_logits)

        # Anchors without any positive pair divide by 1 instead of 0.
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6,
                                     torch.ones_like(mask_pos_pairs),
                                     mask_pos_pairs)

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
        if torch.isnan(mean_log_prob_pos).any():
            raise AssertionError(
                "NaN in contrastive loss (log_prob contains NaN: "
                f"{bool(torch.isnan(log_prob).any())})")
        return - mean_log_prob_pos.mean()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
|
|