kaveh's picture
init
ef814bf
raw
history blame contribute delete
813 Bytes
from torch import nn
class MLMLoss(nn.Module):
    """
    Masked Language Modeling loss.

    Computes a per-position loss and averages it over the positions selected
    by ``mask``. Two modes are supported:

    * ``mse_based=False`` (default): token classification with
      ``nn.CrossEntropyLoss``. ``predictions`` are logits laid out as
      ``(batch_size, seq_len, vocab_size)`` and ``targets`` are token ids.
    * ``mse_based=True``: regression with ``nn.MSELoss``. ``predictions`` are
      expected to carry a trailing singleton dimension (presumably
      ``(batch_size, seq_len, 1)`` -- TODO confirm against callers) that is
      squeezed away before comparing with ``targets``.
    """

    def __init__(self, mse_based=False):
        """
        Args:
            mse_based: if True use mean-squared error, otherwise cross-entropy.
        """
        super().__init__()
        self.mse_based = mse_based
        # reduction='none' keeps per-position losses so they can be masked below.
        if self.mse_based:
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def forward(self, predictions, targets, mask):
        """
        Return the mask-weighted mean loss.

        Args:
            predictions: model outputs (see the class docstring for the layout).
            targets: regression targets (MSE mode) or class indices (CE mode);
                must match the per-position loss shape.
            mask: per-position weights, multiplied elementwise with the
                per-position loss; presumably 1 for positions that contribute
                and 0 otherwise.

        Returns:
            Scalar tensor ``sum(loss * mask) / sum(mask)``. If ``mask`` sums to
            zero, the (zero) masked sum is returned instead of 0/0 = NaN, so
            backward still runs and yields zero gradients.
        """
        if self.mse_based:
            predictions = predictions.squeeze(-1)
            # MSELoss backward rejects integer targets ("Found dtype Long but
            # expected Float"); cast them to the prediction dtype up front.
            if not targets.is_floating_point():
                targets = targets.to(predictions.dtype)
        else:
            # CrossEntropyLoss expects the class dimension at index 1.
            predictions = predictions.permute(0, 2, 1)  # (batch_size, vocab_size, seq_len)
            targets = targets.long()

        masked_loss = self.loss_fn(predictions, targets) * mask.float()
        total = masked_loss.sum()
        denom = mask.sum()
        if denom == 0:
            # No positions selected: avoid 0/0 = NaN, which would poison training.
            return total
        return total / denom