| |
| |
| |
| |
| |
| |
| |
| import torch |
| import torch.nn.functional as F |
|
|
| |
def loss_dcgan_dis(dis_fake, dis_real):
    """Compute the DCGAN (binary cross-entropy) discriminator loss terms.

    The identities ``softplus(-x) == -log(sigmoid(x))`` and
    ``softplus(x) == -log(1 - sigmoid(x))`` let the loss be computed directly
    from the raw (pre-sigmoid) discriminator scores, without an explicit sigmoid.

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        A ``(real_term, fake_term)`` tuple of scalar tensors. Each term is the
        mean over every element of the corresponding input.
    """
    real_term = F.softplus(-dis_real).mean()
    fake_term = F.softplus(dis_fake).mean()
    return real_term, fake_term
|
|
|
|
def loss_dcgan_gen(dis_fake):
    """Compute the non-saturating DCGAN generator loss.

    Equivalent to ``mean(-log(sigmoid(dis_fake)))``, written with
    ``softplus(-x)`` so the raw discriminator scores never pass through an
    explicit sigmoid.

    Args:
        dis_fake: Discriminator scores for generated samples.

    Returns:
        A scalar tensor: the mean over every element of ``softplus(-dis_fake)``.
    """
    return F.softplus(-dis_fake).mean()
|
|
|
|
| |
def loss_hinge_dis(dis_fake, dis_real):
    """Compute the hinge discriminator loss terms.

    Real scores are penalised when they fall below +1, and fake scores are
    penalised when they rise above -1.

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        A ``(loss_real, loss_fake)`` tuple of scalar tensors. Each term is the
        mean over every element of its margin violation.
    """
    real_violation = F.relu(1.0 - dis_real)
    fake_violation = F.relu(1.0 + dis_fake)
    return real_violation.mean(), fake_violation.mean()
|
|
|
|
| |
| |
| |
| |
|
|
|
|
def loss_hinge_gen(dis_fake):
    """Compute the hinge-GAN generator loss.

    The generator is rewarded for raising the discriminator's scores on its
    samples, so the loss is the negated mean score.

    Args:
        dis_fake: Discriminator scores for generated samples.

    Returns:
        A scalar tensor equal to ``-mean(dis_fake)`` (mean over all elements).
    """
    mean_score = dis_fake.mean()
    return -mean_score
|
|
|
|
| |
# Losses selected for training: the hinge variants are exported under the
# generic names ``generator_loss`` / ``discriminator_loss``.
generator_loss, discriminator_loss = loss_hinge_gen, loss_hinge_dis
|
|