# GLEN-model/src/tevatron/loss.py
import torch
import numpy as np
from torch import Tensor
from torch import distributed as dist
from torch.nn import functional as F

class SimpleContrastiveLoss:
    """In-batch contrastive (InfoNCE) loss over query-passage similarity logits."""

    def __call__(
        self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = "mean"
    ):
        if target is None:
            # Default targets assume the positive for query i sits at index
            # i * target_per_qry among the passages (positives listed first).
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0,
                x.size(0) * target_per_qry,
                target_per_qry,
                device=x.device,
                dtype=torch.long,
            )
        logits = torch.matmul(x, y.transpose(0, 1))
        return F.cross_entropy(logits, target, reduction=reduction)
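
# Illustrative usage sketch (hypothetical tensors, not part of the training code):
# with 4 queries and 2 passages per query, positives first, the default targets
# resolve to [0, 2, 4, 6].
#
#   loss_fn = SimpleContrastiveLoss()
#   q = torch.randn(4, 768)   # query embeddings
#   p = torch.randn(8, 768)   # passage embeddings, 2 per query
#   loss = loss_fn(q, p)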

class GLENLoss:
    """GLEN loss: in-batch contrastive loss over vocabulary-projected passage
    representations, plus optional query-to-docid and cosine pointwise terms."""

    def __call__(
        self,
        x: Tensor,
        y: Tensor,
        target: Tensor = None,
        reduction: str = "mean",
        **kwargs
    ):
        if target is None:
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0,
                x.size(0) * target_per_qry,
                target_per_qry,
                device=x.device,
                dtype=torch.long,
            )
        model = kwargs["model"]
        tokenizer = model.tokenizer
        q_reps = x  # query representations, (B, L, D)
        p_reps_dt = y  # passage representations, (B * train_n_passages, L, D)

        # get representations: project passages onto the LM vocabulary, mask
        # special tokens, apply an (optionally annealed) temperature, and
        # re-embed the resulting token distribution with the shared embeddings.
        lm_logits_p = p_reps_dt @ model.lm_p.shared.weight.T
        if model.model_args.mask_special_tokens_for_decoding:
            special_token_ids = tokenizer.all_special_ids
            special_token_ids = [
                x
                for x in special_token_ids
                if x
                not in [
                    tokenizer.bos_token_id,
                    tokenizer.eos_token_id,
                    tokenizer.pad_token_id,
                ]
            ]
            lm_logits_p[:, :, special_token_ids] = -float("inf")
        if model.model_args.do_docid_temperature_annealing:
            # Exponentially decay the docid temperature over epochs, floored at
            # docid_temperature_min.
            first_temperature = model.model_args.docid_temperature
            cur_epoch = (
                model.trainer.state.epoch
                if hasattr(model, "trainer")
                else model.cur_epoch
            )
            temperature = max(
                model.model_args.docid_temperature_min,
                first_temperature * np.exp(-cur_epoch),
            )
        else:
            temperature = model.model_args.docid_temperature
        lm_logits_p = lm_logits_p / temperature
        lm_attention = torch.softmax(lm_logits_p, dim=-1)
        p_reps = lm_attention @ model.lm_p.shared.weight
        p_attention = lm_attention

        # pairwise loss: in-batch InfoNCE over queries vs. soft passage representations
        scores = model.compute_similarity(q_reps, p_reps)  # (B, B * train_n_passages)
        scores = scores.view(q_reps.size(0), -1)  # (B, B * train_n_passages)
        scores = scores / model.softmax_temperature
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (p_reps.size(0) // q_reps.size(0))
        loss = model.compute_loss(scores, target) * model.model_args.infonce_loss

        # pointwise loss: align the query's vocabulary distribution with the
        # argmax docid tokens of its positive passage
        lm_logits_q = None
        if (
            "q_to_docid_loss" in model.model_args.__dict__
            and model.model_args.q_to_docid_loss > 0
        ):
            lm_logits_q = q_reps @ model.lm_q.shared.weight.T  # (B, L, V)
            if model.model_args.mask_special_tokens_for_decoding:
                special_token_ids = tokenizer.all_special_ids
                special_token_ids = [
                    x
                    for x in special_token_ids
                    if x
                    not in [
                        tokenizer.bos_token_id,
                        tokenizer.eos_token_id,
                        tokenizer.pad_token_id,
                    ]
                ]
                lm_logits_q[:, :, special_token_ids] = -1e9
            lm_logits_q = lm_logits_q / model.softmax_temperature  # (B, L, V)
            # For each query, select the corresponding positive document
            pos_doc = torch.arange(q_reps.size(0), dtype=torch.long)  # (B)
            pos_doc = pos_doc * (p_reps.size(0) // q_reps.size(0))  # (B)
            pos_p_attention = p_attention[
                pos_doc, :, :
            ]  # (B * train_n_passages, L, V) -> (B, L, V)
            lm_targets = pos_p_attention.argmax(dim=-1)  # (B, L)
            lm_loss = model.cross_entropy(
                lm_logits_q.view(-1, lm_logits_q.size(-1)), lm_targets.view(-1)
            )
            loss = loss + lm_loss * model.model_args.q_to_docid_loss

        # pointwise loss: cosine similarity between the query's and the positive
        # passage's logits on the passage's argmax docid tokens
        if (
            "cosine_point_loss" in model.model_args.__dict__
            and model.model_args.cosine_point_loss > 0
        ):
            if lm_logits_q is None:
                lm_logits_q = q_reps @ model.lm_q.shared.weight.T  # (B, L, V)
                if model.model_args.mask_special_tokens_for_decoding:
                    special_token_ids = tokenizer.all_special_ids
                    special_token_ids = [
                        x
                        for x in special_token_ids
                        if x
                        not in [
                            tokenizer.bos_token_id,
                            tokenizer.eos_token_id,
                            tokenizer.pad_token_id,
                        ]
                    ]
                    lm_logits_q[:, :, special_token_ids] = -1e9
                lm_logits_q = lm_logits_q / model.softmax_temperature  # (B, L, V)
                # pos_doc is only set in the q_to_docid branch above, so recompute
                # it here so this branch also works on its own.
                pos_doc = torch.arange(q_reps.size(0), dtype=torch.long)  # (B)
                pos_doc = pos_doc * (p_reps.size(0) // q_reps.size(0))  # (B)
            pos_doc_lm_logits = (
                p_reps_dt[pos_doc] @ model.lm_p.shared.weight.T
            )  # (B, L, V)
            query_lm_logits = lm_logits_q  # (B, L, V)
            pos_doc_id_weight, pos_doc_id = pos_doc_lm_logits.max(dim=-1)  # (B, L)
            query_id_weight = query_lm_logits.gather(
                dim=-1, index=pos_doc_id.unsqueeze(-1)
            ).squeeze(-1)  # (B, L)
            cosine_sim = F.cosine_similarity(
                pos_doc_id_weight, query_id_weight, dim=-1
            )  # (B)
            cosine_loss = torch.mean(1 - cosine_sim)  # ()
            loss = loss + cosine_loss * model.model_args.cosine_point_loss
        return loss
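
# Illustrative call sketch (hypothetical objects, not part of the original file):
# GLENLoss expects the GLEN model to be passed via kwargs and to expose
# lm_p/lm_q shared embeddings, model_args, tokenizer, compute_similarity,
# compute_loss, cross_entropy, and softmax_temperature, as used above.
#
#   loss_fn = GLENLoss()
#   loss = loss_fn(q_reps, p_reps_dt, model=model)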

class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, n_target: int = 0, scale_loss: bool = True):
        assert (
            dist.is_initialized()
        ), "Distributed training has not been properly initialized."
        super().__init__()
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Gather query/passage representations from all ranks so negatives are
        # shared across the whole distributed batch.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            loss = loss * self.word_size
        return loss

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        # all_gather returns detached tensors; re-insert the local tensor so
        # gradients still flow through this rank's slice.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
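
# Illustrative distributed usage sketch (assumes torch.distributed has already
# been initialized by the training launcher; hypothetical tensors):
#
#   loss_fn = DistributedContrastiveLoss(scale_loss=True)
#   loss = loss_fn(local_q_reps, local_p_reps)  # gathers across ranks, then InfoNCE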