| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
|
|
class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
    """RDM-derived weighted cross-entropy criterion.

    Instead of returning a scalar loss, ``forward`` returns a logging dict
    with the masked weighted loss, the full-sequence loss, and a perplexity
    computed from the masked loss.
    """

    def __init__(self, ignore_index: int):
        """
        Args:
            ignore_index: target value treated as padding; positions whose
                target equals this value are excluded from the full-sequence
                normalization.
        """
        # Bug fix: the original never called super().__init__(), leaving the
        # nn.Module internals (_parameters, _buffers, hooks) uninitialized,
        # which breaks __call__, .to(), and registration in a parent module.
        # CrossEntropyLoss.__init__ also stores self.ignore_index for us.
        super().__init__(ignore_index=ignore_index)

    def forward(self,
                scores: torch.Tensor,
                target: torch.Tensor,
                label_mask: torch.Tensor,
                weights: torch.Tensor,
                ) -> dict:
        """Compute the RDM-derived loss (weighted cross-entropy).

        Args:
            scores: unnormalized logits; softmax is taken over the last dim.
            target: target indices, compared against ``self.ignore_index``
                to count non-padding positions.
            label_mask: mask selecting the labelled positions; must be
                broadcastable against the per-element loss.
            weights: per-element weights multiplied onto the log-probs;
                must be broadcastable against ``log_softmax(scores)``.

        Returns:
            dict with detached tensors ``'ppl'``, ``'fullseq_loss'`` and
            ``'weight_diff_loss'``.
        """
        # Number of non-padding target positions (full-sequence normalizer).
        sample_size = target.ne(self.ignore_index).float().sum()

        lprobs = F.log_softmax(scores, dim=-1)

        # Element-wise weighting of log-probabilities.
        # NOTE(review): no negative sign here, so `weights` presumably
        # carries the sign (or this is intentionally a log-likelihood
        # style objective) — confirm against the caller.
        loss = lprobs * weights
        fullseq_loss = loss.sum() / sample_size

        # Restrict to labelled positions and renormalize by their count.
        # NOTE(review): an all-zero mask (or all-ignored target above)
        # yields a 0/0 = NaN rather than an error — confirm callers
        # guarantee at least one labelled position.
        label_mask = label_mask.float()
        sample_size = label_mask.sum()
        loss = (loss * label_mask).sum() / sample_size

        # Perplexity of the masked weighted loss.
        ppl = torch.exp(loss)

        return {
            'ppl': ppl.data,
            'fullseq_loss': fullseq_loss.data,
            'weight_diff_loss': loss.data,
        }