# AuralSAM2 / avs.code / v1s.code / loss / training / contrastive_learning.py
# Uploaded by yyliu01 -- "Upload folder using huggingface_hub" -- commit c6dfc69 (verified)
from abc import ABC
import torch
import torch.nn as nn
class ContrastLoss(nn.Module, ABC):
    """Audio-visual contrastive (InfoNCE-style) loss on per-pixel embeddings.

    Per frame and per feature level, visual pixel embeddings are sampled class
    by class (split into pixels where the prediction differs from the mask and
    pixels where it matches) and contrasted against that frame's audio
    embedding.

    Hyper-parameters are read from ``hyp_param.contrastive_learning`` (a
    mapping) when that attribute exists and is non-empty; missing keys fall
    back to the defaults listed in ``__init__``.
    """

    def __init__(self, hyp_param):
        """Store ``hyp_param`` and resolve the contrastive-learning settings.

        Args:
            hyp_param: Configuration object. ``forward`` reads
                ``hyp_param.image_size``; an optional ``contrastive_learning``
                mapping overrides the defaults below.
        """
        super().__init__()
        self.param = hyp_param
        defaults = {
            "temperature": 0.10,
            "ignore_idx": 255,
            "ood_idx": 254,
            "max_views": 512,
            "proj_dim": 512,
            "sample_limits": 128,
            "total_limits": 15240,
        }
        overrides = getattr(hyp_param, "contrastive_learning", None) or {}
        cfg = {**defaults, **overrides}
        self.temperature = cfg["temperature"]
        self.ignore_idx = cfg["ignore_idx"]
        self.ood_idx = cfg["ood_idx"]
        self.max_views = cfg["max_views"]
        self.proj_dim = cfg["proj_dim"]
        self.sample_limits = cfg["sample_limits"]
        self.total_limits = cfg["total_limits"]

    def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx):
        """Sample visual embeddings per class and pair them with audio samples.

        Args:
            embeddings: 2-D tensor; it is permuted to (pixels, channels) before
                indexing, so it is expected as (channels, pixels).
            audio_embeddings: Audio embedding for this frame/level; a leading
                dimension is added before it is stored.
            predictions: Tensor compared element-wise against ``masks``.
            masks: 1-D tensor of per-pixel class values (it is indexed with
                the rows returned by ``nonzero()``).
            batch_idx: Labels are offset by ``batch_idx * 1e3``.

        Returns:
            ``(visual_samples, visual_labels, audio_samples, audio_labels)``,
            four lists of tensors. For every class in ``masks`` two entries are
            appended to the visual lists: first the "hard" pixels (prediction
            differs from the mask), then the "easy" ones (prediction equals the
            mask). Each visual sample tensor has shape (n, 1, channels) and its
            label tensor has shape (n,).
        """
        visual_samples, visual_labels = [], []
        audio_samples, audio_labels = [], []
        label_offset = batch_idx * 1e3

        class_index_list = torch.unique(masks)
        if len(class_index_list) > 1:
            # One audio sample per class except the first (smallest) value
            # returned by torch.unique.
            for class_index in class_index_list[1:]:
                audio_samples.append(audio_embeddings.unsqueeze(0))
                audio_labels.append(class_index.unsqueeze(0) + label_offset)
        else:
            # Only one class present: the audio sample gets label 0.
            audio_samples.append(audio_embeddings.unsqueeze(0))
            audio_labels.append(torch.zeros([1], device=embeddings.device) + label_offset)

        sample_limits = self.sample_limits
        embeddings = embeddings.permute(1, 0)
        # NOTE(review): this is an exact equality test; it only separates
        # hard/easy pixels meaningfully if `predictions` holds class values
        # rather than raw interpolated scores -- confirm with the caller.
        agree = masks == predictions
        for class_index in class_index_list:
            in_class = masks == class_index
            hard_positions = (~agree & in_class).nonzero()
            easy_positions = (agree & in_class).nonzero()
            hard_total, easy_total = hard_positions.shape[0], easy_positions.shape[0]

            hard_quota = min(sample_limits, hard_total)
            easy_quota = min(sample_limits, easy_total)
            if (hard_quota + easy_quota) < sample_limits * 2:
                # NOTE(review): the larger group is topped up with `+=`, so its
                # quota can exceed 2 * sample_limits; the slice below clamps it
                # to the pixels actually available. Kept as in the original.
                if hard_quota > easy_quota:
                    hard_quota += sample_limits * 2 - easy_quota
                else:
                    easy_quota += sample_limits * 2 - hard_quota

            for positions, quota in ((hard_positions, hard_quota), (easy_positions, easy_quota)):
                chosen = torch.randperm(positions.shape[0])[:quota]
                picked = positions[chosen]
                visual_samples.append(embeddings[picked])
                # Fix: read the label at the sampled pixel position. The
                # previous code indexed `masks` with subset-relative indices,
                # which labelled samples with values from unrelated pixels.
                visual_labels.append(masks[picked.squeeze(1)] + label_offset)

        return visual_samples, visual_labels, audio_samples, audio_labels

    def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions):
        """Collect samples over all frames and 3 feature levels, then score them.

        Args:
            visual_embeddings: Indexed as ``[frame][level]``; the last two
                dimensions are flattened into one pixel axis.
            audio_embeddings: Indexed as ``[frame][level]``.
            masks: Per-frame masks; everything after dim 0 is flattened.
            predictions: Per-frame predictions, flattened like ``masks``.

        Returns:
            The scalar loss from ``info_nce``, or the float ``0.0`` when no
            samples were collected.
        """
        masks = masks.flatten(start_dim=1)
        predictions = predictions.flatten(start_dim=1)
        visual_embeddings = visual_embeddings.flatten(start_dim=-2)

        visual_samples, visual_labels = [], []
        audio_samples, audio_labels = [], []
        for frame_idx in range(masks.shape[0]):
            frame_visual = visual_embeddings[frame_idx]
            frame_audio = audio_embeddings[frame_idx]
            frame_masks = masks[frame_idx]
            frame_predictions = predictions[frame_idx]
            for layer_idx in range(3):
                picked_v, picked_v_labels, picked_a, picked_a_labels = self.select_class_wise_samples(
                    frame_visual[layer_idx],
                    frame_audio[layer_idx],
                    frame_predictions,
                    frame_masks,
                    0,
                )
                visual_samples += picked_v
                visual_labels += picked_v_labels
                audio_samples += picked_a
                audio_labels += picked_a_labels

        if len(visual_samples) == 0:
            return 0.0

        # squeeze() drops the singleton axis introduced by nonzero()-indexing;
        # a single remaining sample collapses to 1-D and is restored to 2-D.
        visual_samples = torch.cat(visual_samples, dim=0).squeeze()
        if visual_samples.dim() == 1:
            visual_samples = visual_samples.unsqueeze(0)
        visual_labels = torch.cat(visual_labels, dim=0).unsqueeze(-1)

        audio_samples = torch.cat(audio_samples, dim=0).squeeze()
        if audio_samples.dim() == 1:
            audio_samples = audio_samples.unsqueeze(0)
        audio_labels = torch.cat(audio_labels).unsqueeze(1)

        total_limits = self.total_limits
        num_visual = visual_samples.shape[0]
        if num_visual > total_limits:
            # Fix: keep a uniformly random subset of exactly `total_limits`
            # rows. The previous code took one random scalar and used it as a
            # prefix length, which could keep zero rows or exceed the limit.
            keep = torch.randperm(num_visual, device=visual_samples.device)[:total_limits]
            visual_samples = visual_samples[keep]
            visual_labels = visual_labels[keep]

        return self.info_nce(visual_samples, visual_labels, audio_samples, audio_labels)

    def forward(self, embeddings, output_dicts, masks):
        """Resize predictions/masks to the feature grid and compute the loss.

        Args:
            embeddings: Pair ``(visual_embeddings, audio_embeddings)``; each is
                indexed as ``[level][frame]`` for levels 0, 1 and 2.
            output_dicts: Iterable of dicts whose ``"multistep_pred_masks"``
                tensors are concatenated along dim 0.
            masks: Ground-truth masks; a channel axis is inserted at dim 1
                before interpolation.

        Returns:
            The result of ``forward_audio_visual``.
        """
        side = int(self.param.image_size / 16)
        predictions = torch.cat([entry["multistep_pred_masks"] for entry in output_dicts])
        predictions = torch.nn.functional.interpolate(
            predictions,
            size=(side, side),
            mode="bilinear",
            align_corners=False,
        ).squeeze(1)
        masks = torch.nn.functional.interpolate(
            masks.unsqueeze(1),
            size=(side, side),
            mode="nearest",
        ).squeeze(1)

        visual_embeddings, audio_embeddings = embeddings
        num_frames = masks.shape[0]
        # Regroup from [level][frame] to a tensor indexed as [frame][level].
        visual_embeddings = torch.stack(
            [
                torch.stack([visual_embeddings[level][i] for level in range(3)])
                for i in range(num_frames)
            ]
        )
        audio_embeddings = torch.stack(
            [
                torch.stack([audio_embeddings[level][i] for level in range(3)])
                for i in range(num_frames)
            ]
        )
        return self.forward_audio_visual(
            visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions
        )

    @staticmethod
    def manipulate_cover_mask(a_label, current_mask):
        """Clear selected anchor-anchor entries of ``current_mask`` in place.

        ``(a_label + 1) @ (a_label + 1).T`` is computed and every position of
        the leading square block of ``current_mask`` whose product equals 1.0
        or 4.0 is set to 0. A product of 1.0 means both labels are 0; 4.0
        arises when both labels are 1 (and also for a 0/3 label pair).

        NOTE(review): this presumably removes visual-visual positives so only
        audio-visual pairs remain for labels 0 and 1 -- confirm intent.

        Args:
            a_label: Column tensor of anchor labels, shape (N, 1).
            current_mask: Mask with at least N rows and N columns; modified
                in place.

        Returns:
            ``current_mask`` (the same tensor object).
        """
        a_label = a_label + 1
        visual_mask = torch.matmul(a_label, torch.transpose(a_label, 0, 1))
        # visual_mask is square, so the row/column extents are interchangeable.
        current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 1.0] = 0
        current_mask[: visual_mask.shape[1], : visual_mask.shape[0]][visual_mask == 4.0] = 0
        return current_mask

    def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
        """Compute the InfoNCE-style loss of anchors against anchors + contrasts.

        Args:
            anchors_: (N, D) anchor embeddings.
            a_labels_: (N, 1) anchor labels.
            contras_: (M, D) contrast embeddings; the anchors are prepended.
            c_labels_: (M, 1) contrast labels.

        Returns:
            Scalar tensor: the negated mean, over anchors, of the average
            log-probability of each anchor's positive pairs.

        Raises:
            AssertionError: If the per-anchor result contains NaN.
        """
        c_labels_ = torch.cat([a_labels_, c_labels_])
        contras_ = torch.cat([anchors_, contras_])

        mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()
        anchor_dot_contrast = torch.div(
            torch.matmul(anchors_, torch.transpose(contras_, 0, 1)),
            self.temperature,
        )
        # Subtract the row-wise maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Negatives are fixed before the positive mask is edited in place.
        neg_mask = 1 - mask
        mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask)
        mask = mask.fill_diagonal_(0.0)

        exp_logits = torch.exp(logits)
        neg_logits = (exp_logits * neg_mask).sum(1, keepdim=True)
        log_prob = logits - torch.log(exp_logits + neg_logits)

        mask_pos_pairs = mask.sum(1)
        # Anchors without positives divide by 1 and so contribute 0.
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        assert not torch.isnan(mean_log_prob_pos).any(), (
            "NaN in mean_log_prob_pos "
            f"(log_prob contains NaN: {bool(torch.isnan(log_prob).any())})"
        )
        return -mean_log_prob_pos.mean()