"""Contrastive loss used during SAM2 + fusion training (config from Hydra `contrastive_learning`, tmp.code style).""" import torch from abc import ABC import torch.nn as nn class ContrastLoss(nn.Module, ABC): def __init__(self, hyp_param): super(ContrastLoss, self).__init__() self.param = hyp_param _defaults = { "temperature": 0.10, "ignore_idx": 255, "ood_idx": 254, "max_views": 512, "proj_dim": 512, "sample_limits": 64, "total_limits": 15240, } _raw = getattr(hyp_param, "contrastive_learning", None) or {} _cfg = {**_defaults, **_raw} self.temperature = _cfg["temperature"] self.ignore_idx = _cfg["ignore_idx"] self.ood_idx = _cfg["ood_idx"] self.max_views = _cfg["max_views"] self.proj_dim = _cfg["proj_dim"] self.sample_limits = _cfg["sample_limits"] self.total_limits = _cfg["total_limits"] def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx): embedding_sample_list = [] label_list = [] embedding_sample_list_a = [] label_list_a = [] class_index_list = torch.unique(masks) # means not silence if len(class_index_list) > 1: for class_index in class_index_list[1:]: embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) label_list_a.append(class_index.unsqueeze(0) + batch_idx * 1e3) else: embedding_sample_list_a.append(audio_embeddings.unsqueeze(0)) label_list_a.append(torch.zeros([1], device=embeddings.device) + batch_idx * 1e3) # contras_list = [] # contras_label_list = [] sample_limits = self.sample_limits # we only have 0, 1 embeddings = embeddings.permute(1, 0) for class_index in class_index_list: hard_indices = embeddings[((masks != predictions) & (masks == class_index)).nonzero()] easy_indices = embeddings[((masks == predictions) & (masks == class_index)).nonzero()] hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] # the number that is selected to the contrastive learning. selective_num_hard = min(sample_limits, hard_indices_num) selective_num_easy = min(sample_limits, easy_indices_num) if (selective_num_hard + selective_num_easy) < sample_limits * 2: if selective_num_hard > selective_num_easy: selective_num_hard += sample_limits * 2 - selective_num_easy else: selective_num_easy += sample_limits * 2 - selective_num_hard # skip if contains too limited samples. # if selective_num_hard < 10 and selective_num_easy < 10: # continue hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] embedding_sample_list.append(hard_indices[hard_chosen_indices]) label_list.append(masks[hard_chosen_indices] + batch_idx * 1e3) # add negative features to list. easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] embedding_sample_list.append(easy_indices[easy_chosen_indices]) label_list.append(masks[easy_chosen_indices] + batch_idx * 1e3) return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions): masks = masks.flatten(start_dim=1) predictions = predictions.flatten(start_dim=1) visual_embeddings = visual_embeddings.flatten(start_dim=-2) visual_embedding_sample_list = [] visual_label_list = [] audio_embedding_sample_list = [] audio_label_list = [] for frame_idx in range(masks.shape[0]): current_vision_feats = visual_embeddings[frame_idx] current_masks = masks[frame_idx] current_predictions = predictions[frame_idx] current_audio_feats = audio_embeddings[frame_idx] for layer_idx in range(3): (selected_vision_embeddings, selected_vision_labels, selected_audio_embeddings, selected_audio_labels) = self.select_class_wise_samples(current_vision_feats[layer_idx], current_audio_feats[layer_idx], current_predictions, current_masks, 0) visual_embedding_sample_list += selected_vision_embeddings visual_label_list += selected_vision_labels audio_embedding_sample_list += selected_audio_embeddings audio_label_list += selected_audio_labels if len(visual_embedding_sample_list) == 0: return 0. visual_embedding_sample_list = torch.cat(visual_embedding_sample_list, dim=0).squeeze() visual_label_list = torch.cat(visual_label_list, dim=0).unsqueeze(-1) audio_embedding_sample_list = torch.cat(audio_embedding_sample_list, dim=0).squeeze() audio_label_list = torch.cat(audio_label_list).unsqueeze(1) # print(visual_embedding_sample_list.shape, visual_label_list.shape) # print(audio_embedding_sample_list.shape, audio_label_list.shape) # exit(1) total_limits = self.total_limits if visual_embedding_sample_list.shape[0] > total_limits: rand_index = torch.randperm(visual_embedding_sample_list.shape[0])[total_limits] visual_embedding_sample_list = visual_embedding_sample_list[:rand_index] visual_label_list = visual_label_list[:rand_index] loss = self.info_nce(visual_embedding_sample_list, visual_label_list, audio_embedding_sample_list, audio_label_list) return loss # proof the q-project CAN BE the projector head of the contrastive learning. # At the moment, I do believe the ATTENTION is the another format of the contrastive learning. # First experiment: ignore the sound, only work on the projected vision mask. def forward(self, embeddings, output_dicts, masks): predictions = torch.cat([i['multistep_pred_masks'] for i in output_dicts]) predictions = torch.nn.functional.interpolate(predictions, size=(int(self.param.image_size/16), int(self.param.image_size/16)), mode='bilinear', align_corners=False).squeeze(1) masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(int(self.param.image_size/16), int(self.param.image_size/16)), mode='nearest').squeeze(1) visual_embeddings, audio_embeddings = embeddings # if len(predictions.shape) < 3 and len(masks.shape) < 3: # predictions = predictions.unsqueeze(0) # masks = masks.unsqueeze(0) visual_embeddings = torch.cat([torch.cat([visual_embeddings[0][i].unsqueeze(0), visual_embeddings[1][i].unsqueeze(0), visual_embeddings[2][i].unsqueeze(0)]).unsqueeze(0) for i in range(masks.shape[0])]) audio_embeddings = torch.cat([torch.cat([audio_embeddings[0][i].unsqueeze(0), audio_embeddings[1][i].unsqueeze(0), audio_embeddings[2][i].unsqueeze(0)]).unsqueeze(0) for i in range(masks.shape[0])]) # dict_keys(['point_inputs', 'mask_inputs', 'multistep_pred_masks', 'multistep_pred_masks_high_res', # 'multistep_pred_multimasks', 'multistep_pred_multimasks_high_res', 'multistep_pred_ious', # 'multistep_point_inputs', 'multistep_object_score_logits', 'pred_masks', 'pred_masks_high_res', # 'maskmem_features', 'maskmem_pos_enc']) return self.forward_audio_visual(visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions) # def forward_visual_only(self, visual_embeddings, masks, predictions): # masks = masks.flatten(start_dim=1) # predictions = predictions.flatten(start_dim=1) # visual_embeddings = visual_embeddings.flatten(start_dim=-2) # # visual_embedding_sample_list = [] # visual_label_list = [] # audio_embedding_sample_list = [] # audio_label_list = [] # # for frame_idx in range(masks.shape[0]): # current_vision_feats = visual_embeddings[frame_idx] # current_masks = masks[frame_idx] # current_predictions = predictions[frame_idx] # for layer_idx in range(3): # current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats[layer_idx], # None, # current_predictions, # current_masks, # frame_idx) # visual_embedding_sample_list += current_select_embeddings # visual_label_list += current_select_labels # # # # if len(embedding_sample_list) == 0: return 0. # embedding_sample_list = torch.cat(embedding_sample_list, dim=0).squeeze() # label_list = torch.cat(label_list, dim=0).unsqueeze(-1) # total_limits = 15240 # if embedding_sample_list.shape[0] > total_limits: # rand_index = torch.randperm(embedding_sample_list.shape[0])[total_limits] # embedding_sample_list = embedding_sample_list[:rand_index] # label_list = label_list[:rand_index] # loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, # label_list) # return loss """ # embeddings_size = (int(self.param.image_size/16), int(self.param.image_size/16)) # masks = torch.nn.functional.interpolate(masks.float(), embeddings_size, mode='nearest') # masks = masks.flatten(start_dim=1) # predictions = torch.nn.functional.interpolate(predictions.float(), embeddings_size, mode='nearest') # predictions = predictions.flatten(start_dim=1) # # embedding_sample_list = [] # label_list = [] # contras_sample_list = [] # contras_label_list = [] # temp3. # embedding_visual, embedding_audio = embeddings # embedding_visual = torch.nn.functional.normalize(embedding_visual, p=2, dim=1) # embedding_audio = torch.nn.functional.normalize(embedding_audio, p=2, dim=1) # embedding_visual = embedding_visual.reshape(self.param.batch_size, int(embedding_visual.shape[0]/self.param.batch_size), # *embedding_visual.shape[-2:]) # # embedding_audio = embedding_audio.reshape(self.param.batch_size, int(embedding_audio.shape[0]/self.param.batch_size), # *embedding_audio.shape[-2:]) # masks = masks.reshape(self.param.batch_size, int(masks.shape[0]/self.param.batch_size), # masks.shape[-1]) # predictions = predictions.reshape(self.param.batch_size, int(predictions.shape[0]/self.param.batch_size), # predictions.shape[-1]) # # for batch_idx in range(masks.shape[0]): # current_video_clip_embed = embedding_visual[batch_idx] # current_video_clip_masks = masks[batch_idx] # current_video_clip_preds = predictions[batch_idx] # current_audio_clip_embed = embedding_audio[batch_idx] # # print(current_video_clip_embed.shape, current_audio_clip_embed.shape, current_video_clip_masks.shape, current_video_clip_preds.shape) # # exit(1) # for sample_idx in range(masks.shape[1]): # current_vision_feats = current_video_clip_embed[batch_idx] # current_audio_feats = current_audio_clip_embed[batch_idx] # current_masks = current_video_clip_masks[batch_idx] # current_predictions = current_video_clip_preds[batch_idx] # current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats, # current_audio_feats, # current_predictions, # current_masks, # batch_idx) # temp2. embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) embeddings = embeddings.reshape(self.param.batch_size, int(embeddings.shape[0]/self.param.batch_size), *embeddings.shape[-2:]) masks = masks.reshape(self.param.batch_size, int(masks.shape[0]/self.param.batch_size), masks.shape[-1]) predictions = predictions.reshape(self.param.batch_size, int(predictions.shape[0]/self.param.batch_size), predictions.shape[-1]) for batch_idx in range(masks.shape[0]): current_video_clip_embed = embeddings[batch_idx] current_video_clip_masks = masks[batch_idx] current_video_clip_preds = predictions[batch_idx] # current_audio_clip_feats = for sample_idx in range(masks.shape[1]): current_vision_feats = current_video_clip_embed[batch_idx] current_masks = current_video_clip_masks[batch_idx] current_predictions = current_video_clip_preds[batch_idx] current_select_embeddings, current_select_labels = self.select_class_wise_samples(current_vision_feats, current_predictions, current_masks, batch_idx) embedding_sample_list += current_select_embeddings label_list += current_select_labels # hard_indices = current_vision_feats[(current_masks != current_predictions).nonzero()] # easy_indices = current_vision_feats[(current_masks == current_predictions).nonzero()] # # hard_indices_num, easy_indices_num = hard_indices.shape[0], easy_indices.shape[0] # # # the number that is selected to the contrastive learning. # selective_num_hard = min(sample_limits, hard_indices_num) # selective_num_easy = min(sample_limits, easy_indices_num) # # skip if contains too limited samples. # if selective_num_hard < 10 or selective_num_easy < 10: # continue # # hard_chosen_indices = torch.randperm(hard_indices_num)[:selective_num_hard] # embedding_sample_list.append(hard_indices[hard_chosen_indices]) # label_list.append(current_masks[hard_chosen_indices] + batch_idx * 1e3) # # # add negative features to list. # easy_chosen_indices = torch.randperm(easy_indices_num)[:selective_num_easy] # embedding_sample_list.append(easy_indices[easy_chosen_indices]) # label_list.append(current_masks[easy_chosen_indices] + batch_idx * 1e3) if len(embedding_sample_list) == 0: return 0. embedding_sample_list = torch.cat(embedding_sample_list, dim=0).squeeze() label_list = torch.cat(label_list, dim=0).unsqueeze(-1) total_limits = self.total_limits if embedding_sample_list.shape[0] > total_limits: rand_index = torch.randperm(embedding_sample_list.shape[0])[total_limits] embedding_sample_list = embedding_sample_list[:rand_index] label_list = label_list[:rand_index] loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, label_list) # temp. # sample_limits = 500 # for batch_idx in range(masks.shape[0]): # # go through 3 layers embeddings. # for j in range(len(embeddings)): # current_vision_feats_list = embeddings[j] # current_vision_feats = torch.nn.functional.normalize(current_vision_feats_list[batch_idx], p=2, dim=1) # current_masks = masks[batch_idx] # positive_indices = current_vision_feats[current_masks > 0, ...] # negative_indices = current_vision_feats[current_masks == 0, ...] # positive_indices_num, negative_indices_num = positive_indices.shape[0], negative_indices.shape[0] # # # the number that is selected to the contrastive learning. # selective_num = min(sample_limits, positive_indices_num, negative_indices_num) # if selective_num < 50: continue # skip if contains too limited samples. # # embedding_sample_list.append(positive_indices[torch.randperm(positive_indices_num)[:selective_num]]) # label_list.append(torch.tensor([batch_idx + (self.param.local_rank * 100)] * selective_num, # device=positive_indices.device)) # # # add negative features to list. # negative_sample_list.append(negative_indices[torch.randperm(negative_indices_num)[:selective_num]]) # negative_label_list.append(torch.tensor([-1] * selective_num, device=negative_indices.device)) # # if len(embedding_sample_list) == 0: return 0. # embedding_sample_list = torch.cat(embedding_sample_list, dim=0) # negative_sample_list = torch.cat(negative_sample_list, dim=0) # label_list = torch.cat(label_list) # negative_label_list = torch.cat(negative_label_list) # # loss = self.info_nce(embedding_sample_list, label_list.unsqueeze(-1), # torch.cat([embedding_sample_list, negative_sample_list], dim=0), # torch.cat([label_list, negative_label_list]).unsqueeze(-1)) # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] # # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) # torch.distributed.all_gather(output_list_labels, label_list) # # output_list_embeddings = torch.cat(output_list_embeddings) # output_list_labels = torch.cat(output_list_labels, dim=1) # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) return loss """ # q_max. # def forward(self, embeddings, masks): # # for single-sounding obj. only, with first idx mask. # masks = torch.nn.functional.interpolate(masks.float(), (64, 64), mode='bilinear', align_corners=False) # masks = masks.flatten(start_dim=1) # # embedding_sample_list = torch.zeros([masks.shape[0], 128]).to(self.param.local_rank) # embedding_sample_list = [] # label_list = [] # # negative_sample_list = [] # negative_label_list = [] # sample_limits = 20 # for batch_idx in range(masks.shape[0]): # # go through 3 layers embeddings. # for j in range(len(embeddings)): # current_vision_feats_list, current_audio_feats_list = embeddings[j] # current_audio_feats = torch.nn.functional.normalize(current_audio_feats_list[batch_idx], p=2, dim=1) # current_vision_feats = torch.nn.functional.normalize(current_vision_feats_list[batch_idx], p=2, dim=1) # current_masks = masks[batch_idx] # # # add following features to list. # embedding_sample_list.append(current_vision_feats[current_masks > 0, ...].max(dim=0)[0].unsqueeze(0)) # label_list.append(batch_idx + (self.param.local_rank * 100)) # # embedding_sample_list.append(current_audio_feats) # label_list.append(batch_idx + (self.param.local_rank * 100)) # # # add negative features to list. # negative_num = min(current_vision_feats[current_masks == 0, ...].shape[0], sample_limits) # if negative_num < 5: continue # skip if contains too limited samples. # rand_idx = torch.randperm(current_vision_feats[current_masks == 0, ...].shape[0])[:negative_num] # negative_sample_list.append(current_vision_feats[current_masks == 0, ...][rand_idx]) # negative_label_list.append(torch.tensor([-1] * negative_num, device=current_vision_feats.device)) # # embedding_sample_list = torch.cat(embedding_sample_list) # label_list = torch.tensor(label_list, device=masks.device) # negative_sample_list = torch.cat(negative_sample_list, dim=0) # negative_label_list = torch.cat(negative_label_list) # # loss = self.info_nce(embedding_sample_list, label_list.unsqueeze(-1), # torch.cat([embedding_sample_list, negative_sample_list], dim=0), # torch.cat([label_list, negative_label_list]).unsqueeze(-1)) # # # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] # # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] # # # # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) # # torch.distributed.all_gather(output_list_labels, label_list) # # # # output_list_embeddings = torch.cat(output_list_embeddings) # # output_list_labels = torch.cat(output_list_labels, dim=1) # # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) # return loss # attention mean. # def forward(self, embeddings): # embedding_sample_list = [] # label_list = [] # for layer_embeddings in embeddings: # embedding_sample_list.append(torch.nn.functional.normalize(layer_embeddings, p=2, dim=1)) # # currently we only utilise single frame. # label_list.append(torch.tensor(list(range(0, 1 + 1)) * self.param.batch_size) + (self.param.local_rank * 100)) # embedding_sample_list = torch.cat(embedding_sample_list).cuda(self.param.local_rank) # label_list = torch.cat(label_list).cuda(self.param.local_rank).unsqueeze(0) # # """ # all gather implementation. # """ # """ # output_list_embeddings = [torch.zeros_like(embedding_sample_list) for _ in range(torch.distributed.get_world_size())] # output_list_labels = [torch.zeros_like(label_list) for _ in range(torch.distributed.get_world_size())] # # torch.distributed.all_gather(output_list_embeddings, embedding_sample_list) # torch.distributed.all_gather(output_list_labels, label_list) # # output_list_embeddings = torch.cat(output_list_embeddings) # output_list_labels = torch.cat(output_list_labels, dim=1) # loss = self.info_nce(output_list_embeddings, output_list_labels, output_list_embeddings, output_list_labels) # """ # loss = self.info_nce(embedding_sample_list, label_list, embedding_sample_list, label_list) # # frame_token_semantic_attn = torch.nn.functional.normalize(frame_token_semantic_attn.squeeze(), p=2, dim=1) # # audio_token_attn = torch.nn.functional.normalize(audio_token_attn, p=2, dim=1) # # city_gt = torch.nn.functional.interpolate(city_gt.unsqueeze(1).float(), size=city_proj.shape[2:], # # mode='nearest').squeeze().long() # # # # ood_gt = torch.nn.functional.interpolate(ood_gt.unsqueeze(1).float(), size=ood_proj.shape[2:], # # mode='nearest').squeeze().long() # # # # # normalise the embed results # # city_proj = torch.nn.functional.normalize(city_proj, p=2, dim=1) # # ood_proj = torch.nn.functional.normalize(ood_proj, p=2, dim=1) # # # randomly extract embed samples within a batch # # anchor_embeds, anchor_labels, contrs_embeds, contrs_labels = self.extraction_samples(city_proj, city_gt, # # ood_proj, ood_gt) # # # # # calculate the CoroCL # # loss = self.info_nce(anchors_=anchor_embeds, a_labels_=anchor_labels.unsqueeze(1), contras_=contrs_embeds, # # c_labels_=contrs_labels.unsqueeze(1)) if anchor_embeds.nelement() > 0 else \ # # torch.tensor([.0], device=city_proj.device) # # return loss @staticmethod def manipulate_cover_mask(a_label, current_mask): # shifting current visual index value # background:=1, foreground:=2. a_label = a_label + 1 visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1)) # kicked out the positive value in same visual class. current_mask[:visual_mask.shape[1], :visual_mask.shape[0]][visual_mask == 1.] = 0 current_mask[:visual_mask.shape[1], :visual_mask.shape[0]][visual_mask == 4.] = 0 return current_mask # The implementation of cross-image contrastive learning is based on: # https://github.com/tfzhou/ContrastiveSeg/blob/287e5d3069ce6d7a1517ddf98e004c00f23f8f99/lib/loss/loss_contrast.py def info_nce(self, anchors_, a_labels_, contras_, c_labels_): c_labels_ = torch.cat([a_labels_, c_labels_]) contras_ = torch.cat([anchors_, contras_]) # calculates the binary mask: same category => 1, different categories => 0 mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float() # calculates the dot product anchor_dot_contrast = torch.div(torch.matmul(anchors_, torch.transpose(contras_, 0, 1)), self.temperature) # for numerical stability logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) logits = anchor_dot_contrast - logits_max.detach() # calculates the negative mask neg_mask = 1 - mask # avoid the self duplicate issue mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask) mask = mask.fill_diagonal_(0.) # sum the negative odot results neg_logits = torch.exp(logits) * neg_mask neg_logits = neg_logits.sum(1, keepdim=True) exp_logits = torch.exp(logits) # log_prob -> log(exp(x))-log(exp(x) + exp(y)) # log_prob -> log{exp(x)/[exp(x)+exp(y)]} log_prob = logits - torch.log(exp_logits + neg_logits) # log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) # calculate the info-nce based on the positive samples (under same categories) mask_pos_pairs = mask.sum(1) mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) # mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs.sum(1) mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs assert not torch.isnan(mean_log_prob_pos).any(), print(torch.isnan(log_prob).any()) return - mean_log_prob_pos.mean() # def extraction_samples(self, city_embd, city_label, ood_embd, ood_label): # # reformat the matrix # city_embd = city_embd.flatten(start_dim=2).permute(0, 2, 1) # city_label = city_label.flatten(start_dim=1) # ood_embd = ood_embd.flatten(start_dim=2).permute(0, 2, 1) # ood_label = ood_label.flatten(start_dim=1) # # # define different types of embeds # city_positive = city_embd[city_label == self.ood_idx] # city_negative = city_embd[(city_label != self.ood_idx) & (city_label != self.ignore_idx)] # ood_positive = ood_embd[ood_label == self.ood_idx] # ood_negative = ood_embd[(ood_label != self.ood_idx) & (ood_label != self.ignore_idx)] # # # define the number of choice # sample_num = int(min(self.max_views, city_positive.shape[0], ood_positive.shape[0], # city_negative.shape[0], ood_negative.shape[0])) # # # randomly extract the anchor set with {city_ood, city_inlier} # city_positive_anchor = city_positive[torch.randperm(city_positive.shape[0])][:sample_num] # city_negative_anchor = city_negative[torch.randperm(city_negative.shape[0])][:sample_num] # # anchor_embed = torch.cat([city_positive_anchor, city_negative_anchor], dim=0) # # anchor_label = torch.cat([torch.empty(city_positive_anchor.shape[0], # device=city_positive_anchor.device).fill_(1.), # torch.empty(city_negative_anchor.shape[0], # device=city_negative_anchor.device).fill_(0.)]) # # # randomly extract the contras set with {city_ood, city_inlier, coco_ood, coco_inlier} # city_positive_contras = city_positive_anchor.clone() # city_negative_contras = city_negative_anchor.clone() # ood_positive_contras = ood_positive[torch.randperm(ood_positive.shape[0])][:sample_num] # ood_negative_contras = ood_negative[torch.randperm(ood_negative.shape[0])][:sample_num] # # contrs_embed = torch.cat([city_positive_contras, city_negative_contras, # ood_positive_contras, ood_negative_contras], dim=0) # # contrs_label = torch.cat([torch.empty(city_positive_contras.shape[0], # device=city_positive_contras.device).fill_(1.), # torch.empty(city_negative_contras.shape[0], # device=city_negative_contras.device).fill_(0.), # torch.empty(ood_positive_contras.shape[0], # device=ood_positive_contras.device).fill_(1.), # torch.empty(ood_negative_contras.shape[0], # device=ood_negative_contras.device).fill_(0.)]) # # return anchor_embed, anchor_label, contrs_embed, contrs_label