| import torch |
| from torch import nn |
|
|
class Loss_VAE(nn.Module):
    """Variational auto-encoder loss: reconstruction error plus KL divergence.

    The reconstruction term is the squared error summed over every element
    (``nn.MSELoss(reduction='sum')``). The regularisation term is the
    closed-form KL divergence between the diagonal Gaussian described by
    ``mu`` / ``log_var`` and a standard normal, summed over every element.

    Args:
        beta: weight applied to the KL term. The default of 1.0 gives the
            plain ``mse + kld`` objective; other values give a beta-VAE
            style trade-off between reconstruction and regularisation.

    Raises:
        ValueError: if ``beta`` is negative.
    """

    def __init__(self, beta: float = 1.0) -> None:
        super().__init__()
        if beta < 0:
            raise ValueError("beta must be non-negative.")
        self.beta = beta
        # Summed (not averaged) so the reconstruction term is on the same
        # scale as the summed KL term below.
        self.mse = nn.MSELoss(reduction='sum')

    def forward(
        self,
        recon_x: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        log_var: torch.Tensor,
    ) -> torch.Tensor:
        """Return the scalar loss ``mse(recon_x, x) + beta * kld(mu, log_var)``.

        Args:
            recon_x: reconstruction produced by the decoder.
            x: target the reconstruction is compared against.
            mu: mean of the approximate posterior.
            log_var: log-variance of the approximate posterior.

        Returns:
            A zero-dimensional tensor holding the total loss.
        """
        mse = self.mse(recon_x, x)
        # KL(N(mu, exp(log_var)) || N(0, 1)) in closed form.
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return mse + self.beta * kld
| |
|
|
def DiceScore(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
) -> torch.Tensor:
    """Compute the Dice score for every sample and every class.

    Args:
        y_pred: input data to compute, typically a segmentation model output.
            It must be in one-hot format with the batch as first dim and the
            classes as second dim, example shape: [16, 3, 32, 32]. The values
            should be binarized.
        y: ground truth, same one-hot layout and shape as ``y_pred``. The
            values should be binarized.
        include_background: whether to include the first channel (by
            convention the background class) in the computation. When False,
            channel 0 is dropped from both tensors before scoring, unless it
            is the only channel. Defaults to True.

    Returns:
        Dice scores per sample and per class, shape [batch_size, num_classes]
        (one class fewer when the background channel is dropped). A class
        that is empty in both ``y_pred`` and ``y`` is scored 1.0.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes.
    """
    if y.shape != y_pred.shape:
        raise ValueError("y_pred and y should have same shapes.")

    # Drop the background channel only when other channels remain, so the
    # result is never empty for single-channel inputs.
    if not include_background and y_pred.dim() > 1 and y_pred.shape[1] > 1:
        y_pred = y_pred[:, 1:]
        y = y[:, 1:]

    y = y.float()
    y_pred = y_pred.float()

    # Reduce only the spatial dimensions, keeping batch and channel.
    reduce_axis = list(range(2, y_pred.dim()))
    if reduce_axis:
        intersection = torch.sum(y * y_pred, dim=reduce_axis)
        y_o = torch.sum(y, dim=reduce_axis)
        y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    else:
        # No spatial dims: an empty dim list would trigger a full reduction
        # on some PyTorch versions, so keep the tensors as they are.
        intersection = y * y_pred
        y_o = y
        y_pred_o = y_pred
    denominator = y_o + y_pred_o

    # A zero denominator means the class is absent from both tensors; score
    # that as a perfect match instead of dividing by zero.
    return torch.where(
        denominator > 0,
        (2.0 * intersection) / denominator,
        torch.ones_like(denominator),
    )
|
|