Spaces:
Running
Running
import torch
from torch import nn
class MLMLoss(nn.Module):
    """Masked Language Modeling loss.

    Computes an unreduced per-position loss, weights it by ``mask`` and
    averages over the total mask weight.

    Args:
        mse_based: If True, use mean squared error on one predicted value per
            position (regression). Otherwise use cross-entropy over logits.
    """

    def __init__(self, mse_based=False):
        super().__init__()
        self.mse_based = mse_based
        # reduction='none' keeps one loss value per position so it can be
        # masked in forward() before averaging.
        if self.mse_based:
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def forward(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mask-weighted mean loss.

        Args:
            predictions: Cross-entropy mode: logits of shape
                (batch_size, seq_len, vocab_size). MSE mode: a trailing
                dimension of size 1, if present, is squeezed away
                (presumably (batch_size, seq_len, 1) -- confirm with caller).
            targets: Cross-entropy mode: class indices (cast to long).
                MSE mode: regression targets (cast to predictions' dtype).
            mask: Per-position weights (presumably 0/1), broadcastable to the
                per-position loss; zero-weight positions do not contribute.

        Returns:
            Scalar tensor ``sum(loss * mask) / sum(mask)``; 0 when the mask
            sums to zero (instead of NaN).
        """
        if self.mse_based:
            predictions = predictions.squeeze(-1)
            # Match the target dtype to the predictions so MSELoss does not
            # hit an integer/float (or float32/float64) dtype mismatch.
            targets = targets.to(predictions.dtype)
        else:
            # CrossEntropyLoss expects the class dimension at index 1.
            predictions = predictions.permute(0, 2, 1)  # (batch_size, vocab_size, seq_len)
            targets = targets.long()

        weights = mask.float()
        masked_loss = self.loss_fn(predictions, targets) * weights

        # With an all-zero mask the numerator is 0 as well; clamping the
        # denominator to the smallest positive float yields 0 rather than
        # 0/0 = NaN, and leaves every positive denominator unchanged.
        denom = weights.sum().clamp(min=torch.finfo(weights.dtype).tiny)
        return masked_loss.sum() / denom