| import torch |
| import math |
| import torch.nn.functional as F |
|
|
|
|
def log_sum_exp(x, axis=None):
    """
    Numerically stable log-sum-exp.

    Args:
        x (torch.Tensor): Input tensor.
        axis (int, optional): Dimension over which to perform the sum.
            If ``None``, reduces over all elements.
            (Bug fix: the original ``torch.max(x, None)`` raised a
            TypeError, so the documented default was unusable.)

    Returns:
        torch.Tensor: ``log(sum(exp(x)))`` along ``axis`` (a scalar tensor
        when ``axis`` is ``None``).
    """
    if axis is None:
        # Reduce over all elements; subtract the global max before
        # exponentiating so large scores do not overflow.
        x_max = x.max()
        return torch.log(torch.exp(x - x_max).sum()) + x_max
    # keepdim=True so the subtraction broadcasts correctly along `axis`
    # for any axis (the original keepdim=False form only happened to work
    # for axis=0 on 2-D inputs).
    x_max = torch.max(x, axis, keepdim=True)[0]
    y = torch.log(torch.exp(x - x_max).sum(axis)) + x_max.squeeze(axis)
    return y
|
|
|
|
def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Positive-sample term of a divergence / difference measure.

    Args:
        p_samples: Scores of positive samples.
        measure: Name of the measure: one of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: If True, return the mean over samples; otherwise the
            element-wise values.

    Returns:
        torch.Tensor

    Raises:
        ValueError: If ``measure`` is not one of the supported names.
    """
    log_2 = math.log(2.)
    # Dispatch table: measure name -> element-wise positive expectation.
    estimators = {
        'GAN': lambda s: -F.softplus(-s),
        'JSD': lambda s: log_2 - F.softplus(-s),
        'X2': lambda s: s ** 2,
        'KL': lambda s: s + 1.,
        'RKL': lambda s: -torch.exp(-s),
        'DV': lambda s: s,
        'H2': lambda s: torch.ones_like(s) - torch.exp(-s),
        'W1': lambda s: s,
    }
    if measure not in estimators:
        raise ValueError('Unknown measurement {}'.format(measure))
    Ep = estimators[measure](p_samples)
    return Ep.mean() if average else Ep
|
|
|
|
| def get_negative_expectation(q_samples, measure='JSD', average=True): |
| """ |
| Computes the negative part of a divergence / difference. |
| Args: |
| q_samples: Negative samples. |
| measure: Measure to compute for. |
| average: Average the result over samples. |
| Returns: |
| torch.Tensor |
| """ |
| log_2 = math.log(2.) |
| if measure == 'GAN': |
| Eq = F.softplus(-q_samples) + q_samples |
| elif measure == 'JSD': |
| Eq = F.softplus(-q_samples) + q_samples - log_2 |
| elif measure == 'X2': |
| Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2) |
| elif measure == 'KL': |
| Eq = torch.exp(q_samples) |
| elif measure == 'RKL': |
| Eq = q_samples - 1. |
| elif measure == 'DV': |
| Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0)) |
| elif measure == 'H2': |
| Eq = torch.exp(q_samples) - 1. |
| elif measure == 'W1': |
| Eq = q_samples |
| else: |
| raise ValueError('Unknown measurement {}'.format(measure)) |
| if average: |
| return Eq.mean() |
| else: |
| return Eq |
|
|
|
|
def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Computing the Contrastive Loss between the video and query.
    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: per-sample loss L_{qv} of shape (bsz,)
    """
    eps = 1e-12

    # Positive positions come from the match labels; valid positions that
    # are not positives act as negatives.
    pos_mask = match_labels.type(torch.float32)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Clip-query similarity scores: (bsz, Lv, dim) x (bsz, dim, 1) -> (bsz, Lv).
    scores = torch.matmul(video, query.unsqueeze(2)).squeeze(2)

    # Masked mean of the positive expectation over positive positions.
    pos_term = get_positive_expectation(scores * pos_mask, measure, average=False)
    pos_term = (pos_term * pos_mask).sum(dim=1) / (pos_mask.sum(dim=1) + eps)

    # Masked mean of the negative expectation over negative positions.
    neg_term = get_negative_expectation(scores * neg_mask, measure, average=False)
    neg_term = (neg_term * neg_mask).sum(dim=1) / (neg_mask.sum(dim=1) + eps)

    # Loss is the (negative) lower bound on mutual information.
    return neg_term - pos_term
|
|
|
|
def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video
    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: (bsz, 2)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: scalar loss L_{vv}
    """
    eps = 1e-12

    # Positive positions from the labels; valid non-positive positions
    # serve as negatives.
    pos_mask = match_labels.type(torch.float32)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Gather the start and end boundary clip representations, (bsz, dim) each.
    batch_idx = torch.arange(0, video.shape[0]).long()
    boundary_reps = [
        video[batch_idx, st_ed_indices[:, 0], :],  # start clips
        video[batch_idx, st_ed_indices[:, 1], :],  # end clips
    ]

    # Accumulate the positive / negative terms over both boundary clips.
    pos_sum = 0.
    neg_sum = 0.
    for rep in boundary_reps:
        # Similarity of every clip to this boundary clip: (bsz, Lv).
        scores = torch.matmul(video, rep.unsqueeze(2)).squeeze(2)

        pos_e = get_positive_expectation(scores * pos_mask, measure, average=False)
        pos_sum = pos_sum + (pos_e * pos_mask).sum(dim=1) / (pos_mask.sum(dim=1) + eps)

        neg_e = get_negative_expectation(scores * neg_mask, measure, average=False)
        neg_sum = neg_sum + (neg_e * neg_mask).sum(dim=1) / (neg_mask.sum(dim=1) + eps)

    return torch.mean(neg_sum - pos_sum)
|
|