# Hugging Face Space source snapshot (commit b3a58bf, 2,900 bytes).
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvEncoder(nn.Module):
    """Convolutional encoder producing latent ``mu`` and ``logvar`` heads.

    Four convolutions with kernel 4, stride 2 and padding 1 each halve the
    spatial size. An ``img_size x img_size`` input therefore becomes a
    ``256 x (img_size/16) x (img_size/16)`` feature map. That map is flattened
    and projected by two linear heads.

    Args:
        z_dim: Dimensionality of the latent space (width of each head).
        img_channels: Number of channels in the input images.
        img_size: Height/width of the square input images. Must be a positive
            multiple of 16 so every halving is exact. Defaults to 64, the size
            the layer widths were originally hard-coded for.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3, img_size: int = 64):
        super().__init__()
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 16, got {img_size}"
            )
        feat_size = img_size // 16
        self.net = nn.Sequential(
            # kernel 4 / stride 2 / padding 1 maps an even H to exactly H/2.
            nn.Conv2d(img_channels, 32, 4, stride=2, padding=1),  # H -> H/2
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # H/2 -> H/4
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # H/4 -> H/8
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # H/8 -> H/16
            nn.ReLU(inplace=True),
        )
        flat_dim = 256 * feat_size * feat_size
        self.fc_mu = nn.Linear(flat_dim, z_dim)
        self.fc_logvar = nn.Linear(flat_dim, z_dim)

    def forward(self, x: torch.Tensor):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``.

        ``x`` is expected to be ``(batch, img_channels, img_size, img_size)``.
        """
        h = self.net(x).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h)
class ConvDecoder(nn.Module):
    """Convolutional decoder mapping latent vectors to images in ``[0, 1]``.

    A linear layer expands ``z`` into a ``256 x (img_size/16) x (img_size/16)``
    feature map. Four transposed convolutions with kernel 4, stride 2 and
    padding 1 each double the spatial size. A final sigmoid bounds the output
    to ``[0, 1]``.

    Args:
        z_dim: Dimensionality of the latent input.
        img_channels: Number of channels in the generated images.
        img_size: Height/width of the square output images. Must be a positive
            multiple of 16. Defaults to 64, the size the layer widths were
            originally hard-coded for.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3, img_size: int = 64):
        super().__init__()
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 16, got {img_size}"
            )
        # Side length of the seed feature map reshaped from the fc output.
        self._feat_size = img_size // 16
        self.fc = nn.Linear(z_dim, 256 * self._feat_size * self._feat_size)
        self.net = nn.Sequential(
            # kernel 4 / stride 2 / padding 1 maps H to exactly 2H.
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # H/16 -> H/8
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # H/8 -> H/4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # H/4 -> H/2
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, img_channels, 4, stride=2, padding=1),  # H/2 -> H
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode ``z`` to a ``(batch, img_channels, img_size, img_size)`` image.

        The ``view(-1, ...)`` folds any leading dimensions into the batch axis.
        An unbatched ``(z_dim,)`` vector therefore yields a batch of one.
        """
        side = self._feat_size
        h = self.fc(z).view(-1, 256, side, side)
        return self.net(h)
class VAE(nn.Module):
    """Variational autoencoder composed of a ``ConvEncoder`` and ``ConvDecoder``.

    Args:
        z_dim: Latent dimensionality shared by the encoder and decoder.
        img_channels: Number of image channels handled by both halves.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3):
        super().__init__()
        self.z_dim = z_dim
        # Encoder is registered before decoder; this fixes the order of
        # parameters() and the state_dict keys.
        self.encoder = ConvEncoder(z_dim, img_channels)
        self.decoder = ConvDecoder(z_dim, img_channels)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.

        Sampling only happens in training mode. In eval mode ``mu`` is
        returned as-is, without drawing from the RNG.
        """
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def encode(self, x: torch.Tensor):
        """Encode ``x`` and return ``(z, mu, logvar)``."""
        mean, log_var = self.encoder(x)
        latent = self.reparameterize(mean, log_var)
        return latent, mean, log_var

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent codes ``z`` back to image space via the decoder."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        """Run encode + decode and return ``(recon, mu, logvar)``."""
        latent, mean, log_var = self.encode(x)
        return self.decode(latent), mean, log_var
def elbo_loss(
    recon: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
):
    """Compute the beta-weighted negative ELBO for a batch.

    Args:
        recon: Reconstructed batch, same shape as ``x``.
        x: Target batch; ``x.size(0)`` is used as the batch size.
        mu: Latent means; the KL term sums over dim 1.
        logvar: Latent log-variances, same shape as ``mu``.
        beta: Weight applied to the KL term.

    Returns:
        ``(total, recon_loss, kl)``, where:
        - ``recon_loss`` is the summed squared error divided by the batch size.
        - ``kl`` is the closed-form ``KL(N(mu, exp(logvar)) || N(0, I))``,
          summed over dim 1 and averaged over the batch.
        - ``total = recon_loss + beta * kl``.
    """
    batch_size = x.size(0)
    reconstruction = F.mse_loss(recon, x, reduction="sum") / batch_size
    per_sample = (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)
    divergence = -0.5 * per_sample.mean()
    return reconstruction + beta * divergence, reconstruction, divergence
|