import pytorch_lightning as pl

import torch
import torch.nn as nn

from typing import Dict, List, Tuple


class Discriminator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        negative_slope: float = 0.2,
        bias: bool = False,
    ):
        """
        Initializes the discriminator.

        Parameters
        ----------
        hidden_size : int, optional
            The number of feature maps in the first convolutional layer. (default: 64)
        channels : int, optional
            The number of input channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        negative_slope : float, optional
            The negative slope of the LeakyReLU activations. (default: 0.2)
        bias : bool, optional
            Whether the convolutions use a bias term. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.negative_slope = negative_slope
        self.bias = bias

        # Shape comments below assume the default kernel_size=4, stride=2,
        # padding=1 and 64x64 inputs; each strided conv halves the resolution.
        self.model = nn.Sequential(
            # (N, channels, 64, 64) -> (N, hidden_size, 32, 32)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.channels, self.hidden_size,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
                ),
            ),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # (N, hidden_size, 32, 32) -> (N, hidden_size * 2, 16, 16)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size, self.hidden_size * 2,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 2),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # (N, hidden_size * 2, 16, 16) -> (N, hidden_size * 4, 8, 8)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size * 2, self.hidden_size * 4,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 4),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # (N, hidden_size * 4, 8, 8) -> (N, hidden_size * 8, 4, 4)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size * 4, self.hidden_size * 8,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 8),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # (N, hidden_size * 8, 4, 4) -> (N, 1, 1, 1), flattened to (N, 1)
            nn.utils.spectral_norm(
                nn.Conv2d(self.hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias),
            ),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The batch of input images.

        Returns
        -------
        torch.Tensor
            The predicted probability that each image is real, of shape (N, 1).
        """
        preds = self.model(input_img)
        return preds
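
    # A minimal shape check, assuming the default configuration and 64x64 RGB
    # inputs (the batch size 8 is arbitrary):
    #
    #   disc = Discriminator()
    #   probs = disc(torch.randn(8, 3, 64, 64))  # -> shape (8, 1), values in (0, 1)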


class Generator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
    ):
        """
        Initializes the generator.

        Parameters
        ----------
        hidden_size : int, optional
            The number of feature maps in the last transposed-convolutional layer. (default: 64)
        latent_size : int, optional
            The size of the latent vector. (default: 128)
        channels : int, optional
            The number of output channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        bias : bool, optional
            Whether the transposed convolutions use a bias term. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # Shape comments below assume the default kernel_size=4, stride=2,
        # padding=1; each strided transposed conv doubles the resolution.
        self.model = nn.Sequential(
            # (N, latent_size, 1, 1) -> (N, hidden_size * 8, 4, 4)
            nn.ConvTranspose2d(
                self.latent_size, self.hidden_size * 8,
                kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 8),
            nn.ReLU(inplace=True),

            # (N, hidden_size * 8, 4, 4) -> (N, hidden_size * 4, 8, 8)
            nn.ConvTranspose2d(
                self.hidden_size * 8, self.hidden_size * 4,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 4),
            nn.ReLU(inplace=True),

            # (N, hidden_size * 4, 8, 8) -> (N, hidden_size * 2, 16, 16)
            nn.ConvTranspose2d(
                self.hidden_size * 4, self.hidden_size * 2,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 2),
            nn.ReLU(inplace=True),

            # (N, hidden_size * 2, 16, 16) -> (N, hidden_size, 32, 32)
            nn.ConvTranspose2d(
                self.hidden_size * 2, self.hidden_size,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size),
            nn.ReLU(inplace=True),

            # (N, hidden_size, 32, 32) -> (N, channels, 64, 64)
            nn.ConvTranspose2d(
                self.hidden_size, self.channels,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.Tanh(),
        )

    def forward(self, input_noise: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_noise : torch.Tensor
            The batch of latent noise vectors, of shape (N, latent_size, 1, 1).

        Returns
        -------
        torch.Tensor
            The generated images, with values in [-1, 1].
        """
        fake_img = self.model(input_noise)
        return fake_img
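
    # A minimal shape check, assuming the default configuration (the batch
    # size 8 is arbitrary):
    #
    #   gen = Generator()
    #   imgs = gen(torch.randn(8, 128, 1, 1))  # -> shape (8, 3, 64, 64), values in [-1, 1]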


class DocuGAN(pl.LightningModule):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        num_channel: int = 3,
        learning_rate: float = 0.0002,
        batch_size: int = 128,
        beta1: float = 0.5,
        beta2: float = 0.999,
    ):
        """
        Initializes the DocuGAN Lightning module.

        Parameters
        ----------
        hidden_size : int, optional
            The hidden size of the generator and discriminator. (default: 64)
        latent_size : int, optional
            The size of the latent vector. (default: 128)
        num_channel : int, optional
            The number of image channels. (default: 3)
        learning_rate : float, optional
            The learning rate. (default: 0.0002)
        batch_size : int, optional
            The batch size. (default: 128)
        beta1 : float, optional
            The first Adam beta coefficient. (default: 0.5)
        beta2 : float, optional
            The second Adam beta coefficient. (default: 0.999)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_channel = num_channel
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.beta1 = beta1
        self.beta2 = beta2
        self.criterion = nn.BCELoss()
        # Fixed noise, kept so generated samples stay comparable across epochs.
        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
        self.save_hyperparameters()

        self.generator = Generator(
            latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
        )
        self.generator.apply(self.weights_init)

        self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
        self.discriminator.apply(self.weights_init)

    def weights_init(self, m: nn.Module) -> None:
        """
        Initializes the weights with the DCGAN scheme (Radford et al., 2016):
        conv weights from N(0, 0.02), BatchNorm weights from N(1, 0.02) with
        zero bias.

        Parameters
        ----------
        m : nn.Module
            The module to initialize.
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            # Spectral norm re-registers the raw weight as `weight_orig`, so
            # initialize that tensor when present; initializing `weight`
            # directly would be overwritten on the next forward pass.
            weight = getattr(m, "weight_orig", m.weight)
            nn.init.normal_(weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
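
    # Sanity sketch: `nn.Module.apply` recurses over submodules, so after
    # construction the conv weights should have std close to 0.02, e.g.:
    #
    #   model = DocuGAN()
    #   w = model.generator.model[0].weight  # first ConvTranspose2d
    #   print(w.std())  # ~0.02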

    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
        """
        Configures the optimizers.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], List]
            The optimizers and an empty list of LR schedulers.
        """
        opt_generator = torch.optim.Adam(
            self.generator.parameters(), lr=self.learning_rate, betas=(self.beta1, self.beta2)
        )
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.learning_rate, betas=(self.beta1, self.beta2)
        )
        return [opt_generator, opt_discriminator], []
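
    # Note: this module assumes the pre-2.0 Lightning multi-optimizer API,
    # which calls `training_step` once per optimizer and passes the position
    # in the list above as `optimizer_idx`: 0 steps the generator, 1 the
    # discriminator.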

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        z : torch.Tensor
            The latent vector.

        Returns
        -------
        torch.Tensor
            The generated images.
        """
        return self.generator(z)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int
    ) -> Dict[str, torch.Tensor]:
        """
        Training step.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            The batch; the real images are stored under the "tr_image" key.
        batch_idx : int
            The batch index.
        optimizer_idx : int
            The optimizer index (0 for the generator, 1 for the discriminator).

        Returns
        -------
        Dict[str, torch.Tensor]
            The training loss.
        """
        real_images = batch["tr_image"]

        # Generator step: try to fool the discriminator into labeling
        # generated images as real.
        if optimizer_idx == 0:
            # Use the actual batch size rather than self.batch_size so the
            # last, possibly smaller, batch of an epoch still works.
            fake_random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            fake_random_noise = fake_random_noise.type_as(real_images)
            fake_images = self(fake_random_noise)

            preds = self.discriminator(fake_images)
            loss = self.criterion(preds, torch.ones_like(preds))
            self.log("g_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
            return {"loss": loss}

        # Discriminator step: label real images as real and generated
        # images as fake.
        elif optimizer_idx == 1:
            real_preds = self.discriminator(real_images)
            real_loss = self.criterion(real_preds, torch.ones_like(real_preds))

            real_random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            real_random_noise = real_random_noise.type_as(real_images)
            # Detach so this step does not backpropagate into the generator.
            fake_images = self(real_random_noise).detach()

            fake_preds = self.discriminator(fake_images)
            fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds))

            loss = real_loss + fake_loss
            self.log("d_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
            return {"loss": loss}
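
# A minimal training sketch, assuming a LightningDataModule whose batches are
# dicts with a "tr_image" key, as `training_step` expects (`MyDocDataModule`
# is a hypothetical name):
#
#   model = DocuGAN()
#   trainer = pl.Trainer(max_epochs=50, accelerator="auto", devices=1)
#   trainer.fit(model, datamodule=MyDocDataModule(batch_size=128))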