from math import log2
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


# Discriminator model ported from Paella
# https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    """Strided-convolution discriminator with optional conditioning.

    The encoder is a stack of spectrally normalised 3x3 convolutions with
    stride 2. A final 1x1 convolution (``shuffle``) maps the encoder features,
    optionally concatenated with a broadcast conditioning vector, to a single
    channel that is averaged into one score per sample.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        cond_channels: int = 0,
        hidden_channels: int = 512,
        img_resolution: int = 256,
    ):
        """Build the encoder and the 1x1 scoring convolution.

        Args:
            in_channels: Number of channels of the input passed to ``forward``.
            cond_channels: Number of channels of the conditioning tensor; values
                ``<= 0`` mean the scoring convolution takes encoder features only.
            hidden_channels: Upper bound on the encoder width. The encoder starts
                at ``hidden_channels // 2**d`` and doubles per layer up to this.
            img_resolution: Resolution used to derive the number of stride-2
                layers: ``depth = int(log2(img_resolution) - 2)``.
        """
        super().__init__()
        depth = int(log2(img_resolution) - 2)
        # Exponent of the initial channel reduction; never less than 3, so the
        # first layer has at most hidden_channels // 8 channels.
        d = max(depth - 3, 3)

        # Track the width of the most recently added conv so the scoring head
        # can be sized from what the encoder really produces.
        out_channels = hidden_channels // (2**d)
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = out_channels
            # Double the width each layer until it saturates at hidden_channels.
            c_out = hidden_channels // (2 ** max(d - 1 - i, 0))
            layers.append(
                nn.utils.spectral_norm(
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
                )
            )
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            out_channels = c_out
        self.encoder = nn.Sequential(*layers)

        # The encoder only reaches hidden_channels when depth >= 4; for smaller
        # resolutions it ends narrower, so use the tracked width here instead of
        # assuming hidden_channels (which made forward() fail on a channel
        # mismatch). For depth >= 4 this is exactly hidden_channels as before.
        self.shuffle = nn.Conv2d(out_channels + max(cond_channels, 0), 1, kernel_size=1)

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Score a batch of inputs.

        Args:
            x: Input fed to the convolutional encoder
                (assumed ``(batch, in_channels, height, width)`` - confirm with callers).
            cond: Optional conditioning tensor. It is viewed as
                ``(cond.size(0), cond.size(1), 1, 1)`` and expanded over the
                spatial size of the encoder output before concatenation on the
                channel axis, so its second dimension must match the
                ``cond_channels`` the model was built with.

        Returns:
            The single-channel score map averaged over all non-batch
            dimensions, i.e. one value per sample. No sigmoid is applied.
        """
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the per-sample conditioning vector to every spatial
            # location of the feature map.
            cond = cond.view(cond.size(0), cond.size(1), 1, 1).expand(
                -1, -1, x.size(-2), x.size(-1)
            )
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = x.flatten(1).mean(-1)
        return x