| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class GANLoss(nn.Module):
    """Binary cross-entropy GAN loss computed on raw discriminator logits.

    Predictions are compared against a constant target tensor filled with
    ``target_real_label`` or ``target_fake_label`` and scored with
    ``F.binary_cross_entropy_with_logits`` (default mean reduction). A list
    of predictions (e.g. from a multi-scale discriminator) is averaged over
    its elements.

    Args:
        target_real_label: Value used to fill the target for "real" samples.
        target_fake_label: Value used to fill the target for "fake" samples.
        tensor: Stored as ``self.Tensor``; not used by the code in this class.
        opt: Stored as ``self.opt``; not used by the code in this class.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached-target slots kept for attribute compatibility; nothing in
        # this class populates or reads them.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.opt = opt

    def get_target_tensor(self, input, target_is_real):
        """Return a tensor shaped like ``input`` filled with the real/fake label.

        The result has ``input``'s dtype and device and does not require grad.
        """
        # Use the configured labels rather than hard-coded 1/0 so options
        # such as one-sided label smoothing (target_real_label=0.9) apply.
        label = self.real_label if target_is_real else self.fake_label
        return torch.full_like(input, label).detach()

    def get_zero_tensor(self, input):
        """Return a zero tensor shaped like ``input`` (not requiring grad)."""
        return torch.zeros_like(input).detach()

    def loss(self, inputs, target_is_real, for_discriminator=True):
        """Mean BCE-with-logits between ``inputs`` and the real/fake target.

        ``for_discriminator`` is accepted for interface compatibility and does
        not affect the result.
        """
        target_tensor = self.get_target_tensor(inputs, target_is_real)
        return F.binary_cross_entropy_with_logits(inputs, target_tensor)

    def forward(self, inputs, target_is_real, for_discriminator=True):
        """Compute the GAN loss for a single prediction or a list of them.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module``
        hooks run when the module is called.

        Args:
            inputs: A logits tensor, or a list of predictions. A list element
                that is itself a list contributes only its last entry
                (presumably the final discriminator output after intermediate
                features — confirm with the caller).
            target_is_real: Selects the real label (True) or fake label (False).
            for_discriminator: Passed through to :meth:`loss`; currently unused.

        Returns:
            For a tensor input, a scalar tensor. For a list input, the mean of
            the per-element losses; because each per-element loss is a scalar
            it is viewed as ``(1, 1)`` and averaged over dim 1, so the result
            has shape ``(1,)``.

        Raises:
            ValueError: If ``inputs`` is an empty list.
        """
        if not isinstance(inputs, list):
            return self.loss(inputs, target_is_real, for_discriminator)
        if not inputs:
            raise ValueError("GANLoss received an empty list of predictions")

        total = 0
        for pred_i in inputs:
            if isinstance(pred_i, list):
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            # Reduce to one value per leading index; a 0-d loss is treated
            # as a batch of one.
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            total = total + torch.mean(loss_tensor.view(bs, -1), dim=1)
        return total / len(inputs)
|
|