from typing import Optional

import numpy as np
import torch
from torch import Tensor
from torch import distributed as dist
from torch.nn import functional as F
|
|
|
|
class SimpleContrastiveLoss:
    """In-batch softmax contrastive loss over query and passage representations."""

    def __call__(
        self,
        x: Tensor,
        y: Tensor,
        target: Optional[Tensor] = None,
        reduction: str = "mean",
    ):
        if target is None:
            # Queries and passages are laid out in groups: the positive for query i
            # is passage i * target_per_qry; every other row of y is a negative.
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0,
                x.size(0) * target_per_qry,
                target_per_qry,
                device=x.device,
                dtype=torch.long,
            )
        logits = torch.matmul(x, y.transpose(0, 1))
        return F.cross_entropy(logits, target, reduction=reduction)
|
|
|
|
class GLENLoss:
    """GLEN loss: an InfoNCE term over soft-docid passage representations, plus
    optional query-to-docid and pointwise cosine terms weighted via model_args."""

    def __call__(
        self,
        x: Tensor,
        y: Tensor,
        target: Optional[Tensor] = None,
        reduction: str = "mean",
        **kwargs
    ):
| if target is None: |
| target_per_qry = y.size(0) // x.size(0) |
| target = torch.arange( |
| 0, |
| x.size(0) * target_per_qry, |
| target_per_qry, |
| device=x.device, |
| dtype=torch.long, |
| ) |
| model = kwargs["model"] |
| tokenizer = model.tokenizer |
|
|
        q_reps = x
        p_reps_dt = y

        # Lexical ("docid") logits for each passage: project the representations
        # onto the tied vocabulary embedding of the passage LM.
        lm_logits_p = p_reps_dt @ model.lm_p.shared.weight.T
| if model.model_args.mask_special_tokens_for_decoding: |
| special_token_ids = tokenizer.all_special_ids |
            special_token_ids = [
                tid
                for tid in special_token_ids
                if tid
                not in [
                    tokenizer.bos_token_id,
                    tokenizer.eos_token_id,
                    tokenizer.pad_token_id,
                ]
            ]
| lm_logits_p[:, :, special_token_ids] = -float("inf") |
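        # Docid softmax temperature: optionally annealed exponentially over epochs
        # and floored at docid_temperature_min.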
| if model.model_args.do_docid_temperature_annealing: |
| first_temperature = model.model_args.docid_temperature |
| cur_epoch = ( |
| model.trainer.state.epoch |
| if hasattr(model, "trainer") |
| else model.cur_epoch |
| ) |
| temperature = max( |
| model.model_args.docid_temperature_min, |
| first_temperature * np.exp(-cur_epoch), |
| ) |
| else: |
| temperature = model.model_args.docid_temperature |
|
|
        # Soft docid lookup: temperature-scaled softmax over the vocabulary, then
        # re-embed through the same embedding matrix to obtain differentiable
        # passage representations.
        lm_logits_p = lm_logits_p / temperature
        lm_attention = torch.softmax(lm_logits_p, dim=-1)
        p_reps = lm_attention @ model.lm_p.shared.weight
        p_attention = lm_attention
|
|
| |
        # InfoNCE scores: query vs. soft-docid passage similarity, scaled by the
        # softmax temperature.
        scores = model.compute_similarity(q_reps, p_reps)
        scores = scores.view(q_reps.size(0), -1)
        scores = scores / model.softmax_temperature
|
|
        # In-batch targets (rebuilt here, superseding the `target` handling above):
        # the positive passage for query i sits at row i * (passages per query).
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (p_reps.size(0) // q_reps.size(0))

        # InfoNCE term, weighted by model_args.infonce_loss.
        loss = model.compute_loss(scores, target) * model.model_args.infonce_loss
|
|
| |
        # Optional query-to-docid term: push the query's vocabulary logits toward
        # the positive passage's docid tokens.
        lm_logits_q = None
        if (
            "q_to_docid_loss" in model.model_args.__dict__
            and model.model_args.q_to_docid_loss > 0
        ):
| lm_logits_q = q_reps @ model.lm_q.shared.weight.T |
| if model.model_args.mask_special_tokens_for_decoding: |
| special_token_ids = tokenizer.all_special_ids |
                special_token_ids = [
                    tid
                    for tid in special_token_ids
                    if tid
                    not in [
                        tokenizer.bos_token_id,
                        tokenizer.eos_token_id,
                        tokenizer.pad_token_id,
                    ]
                ]
| lm_logits_q[:, :, special_token_ids] = -1e9 |
|
|
| lm_logits_q = lm_logits_q / model.softmax_temperature |
|
|
| |
            # Index of each query's positive passage within the passage batch.
            pos_doc = torch.arange(
                q_reps.size(0), device=q_reps.device, dtype=torch.long
            )
            pos_doc = pos_doc * (p_reps.size(0) // q_reps.size(0))
            pos_p_attention = p_attention[pos_doc, :, :]

            # Use the positive passage's argmax docid tokens as hard targets and
            # push the query's logits toward them with token-level cross-entropy.
            lm_targets = pos_p_attention.argmax(dim=-1)
            lm_loss = model.cross_entropy(
                lm_logits_q.view(-1, lm_logits_q.size(-1)), lm_targets.view(-1)
            )
|
|
| loss = loss + lm_loss * model.model_args.q_to_docid_loss |
|
|
| |
        # Optional pointwise cosine term: align the query's logits at the positive
        # passage's top docid tokens with the passage's own peak logits.
        if (
            "cosine_point_loss" in model.model_args.__dict__
            and model.model_args.cosine_point_loss > 0
        ):
| if lm_logits_q is None: |
| lm_logits_q = q_reps @ model.lm_q.shared.weight.T |
| if model.model_args.mask_special_tokens_for_decoding: |
| special_token_ids = tokenizer.all_special_ids |
                    special_token_ids = [
                        tid
                        for tid in special_token_ids
                        if tid
                        not in [
                            tokenizer.bos_token_id,
                            tokenizer.eos_token_id,
                            tokenizer.pad_token_id,
                        ]
                    ]
| lm_logits_q[:, :, special_token_ids] = -1e9 |
|
|
| lm_logits_q = lm_logits_q / model.softmax_temperature |
|
|
            # `pos_doc` may be undefined if the q_to_docid block above was skipped,
            # so (re)build the positive-passage indices here.
            pos_doc = torch.arange(
                q_reps.size(0), device=q_reps.device, dtype=torch.long
            ) * (p_reps.size(0) // q_reps.size(0))
            pos_doc_lm_logits = p_reps_dt[pos_doc] @ model.lm_p.shared.weight.T
            query_lm_logits = lm_logits_q
|
|
            # For each docid position, take the positive passage's strongest token
            # and look up the query's logit for that same token.
            pos_doc_id_weight, pos_doc_id = pos_doc_lm_logits.max(dim=-1)
            query_id_weight = query_lm_logits.gather(
                dim=-1, index=pos_doc_id.unsqueeze(-1)
            ).squeeze(-1)

            # Penalize misalignment of the two logit profiles across docid positions.
            cosine_sim = F.cosine_similarity(
                pos_doc_id_weight, query_id_weight, dim=-1
            )
| cosine_loss = torch.mean(1 - cosine_sim) |
|
|
| loss = loss + cosine_loss * model.model_args.cosine_point_loss |
|
|
| return loss |
|
|
|
|
| class DistributedContrastiveLoss(SimpleContrastiveLoss): |
| def __init__(self, n_target: int = 0, scale_loss: bool = True): |
| assert ( |
| dist.is_initialized() |
| ), "Distributed training has not been properly initialized." |
| super().__init__() |
        self.world_size = dist.get_world_size()
| self.rank = dist.get_rank() |
| self.scale_loss = scale_loss |
|
|
    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Gather representations from every rank so each process scores its
        # queries against the full cross-device batch of passages.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # Compensate for gradient averaging across ranks under DDP.
            loss = loss * self.world_size
        return loss
|
|
    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather returns tensors with no autograd history; put the local
        # tensor back so gradients still flow for this rank's shard.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
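

# Minimal smoke test for the in-batch contrastive loss; the shapes below are
# hypothetical (2 queries, 2 passages per query, 8-dim embeddings). GLENLoss and
# DistributedContrastiveLoss need a full model / initialized process group and
# are not exercised here.
if __name__ == "__main__":
    queries = torch.randn(2, 8)
    passages = torch.randn(4, 8)  # passage 2 * i is the positive for query i
    print(SimpleContrastiveLoss()(queries, passages))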
|
|