| import torch |
| from torchmetrics import Metric |
|
|
class MyAccuracy(Metric):
    """
    Accuracy metric customized for handling sequences with padding.

    Positions whose label equals ``ignore_index`` (``-100`` by default, the
    same sentinel ``torch.nn.CrossEntropyLoss`` ignores by default) are
    excluded from both the numerator and the denominator.

    Methods:
        update(logits, labels, num_labels): Accumulate correct/total counts
            from one batch of model predictions and ground-truth labels.

        compute(): Return the accuracy accumulated so far.

    Attributes:
        total (torch.Tensor): Total number of non-padding elements.
        correct (torch.Tensor): Number of correctly predicted non-padding
            elements.
        ignore_index (int): Label value that marks padding positions.
    """

    def __init__(self, ignore_index: int = -100, **kwargs) -> None:
        """
        Args:
            ignore_index (int): Label value marking positions to skip.
                Defaults to -100, matching the original hard-coded value.
            **kwargs: Forwarded unchanged to ``torchmetrics.Metric``.
        """
        super().__init__(**kwargs)
        self.ignore_index = ignore_index
        # Integer counters; dist_reduce_fx='sum' combines them by summation
        # when states are synchronized across processes.
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(
        self, logits: torch.Tensor, labels: torch.Tensor, num_labels: int
    ) -> None:
        """
        Accumulate correct and total counts from one batch.

        Args:
            logits (torch.Tensor): Model scores. Reshaped to
                ``(-1, num_labels)``, so the class-score dimension is assumed
                to be the last one.
            labels (torch.Tensor): Ground-truth class indices. After
                flattening, expected to have one entry per row of the reshaped
                logits. Entries equal to ``ignore_index`` are skipped.
            num_labels (int): Number of classes (size of the score dimension).
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted as well.
        flattened_targets = labels.reshape(-1)
        flattened_predictions = torch.argmax(
            logits.reshape(-1, num_labels), dim=1
        )

        # Keep only non-padding positions before comparing.
        active_mask = flattened_targets != self.ignore_index
        active_labels = torch.masked_select(flattened_targets, active_mask)
        active_predictions = torch.masked_select(flattened_predictions, active_mask)

        self.correct += torch.sum(active_labels == active_predictions)
        self.total += torch.numel(active_labels)

    def compute(self) -> torch.Tensor:
        """
        Return the accuracy ``correct / total`` as a float tensor.

        If no non-padding elements have been counted yet, ``total`` is 0 and
        the result is NaN (0.0 / 0.0).
        """
        return self.correct.float() / self.total.float()