# lsr-lang / models/vae.py
# Uploaded by singhanshuman: "Upload models/vae.py with huggingface_hub"
# Commit b3a58bf (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvEncoder(nn.Module):
    """Convolutional encoder that outputs a latent mean and log-variance.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size, 64 -> 32 -> 16 -> 8 -> 4, ending in a 256-channel 4x4 feature map.
    That map is flattened and fed to two independent linear heads.

    Args:
        z_dim: Size of the latent vector produced by each head.
        img_channels: Number of channels in the input image.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3):
        super().__init__()

        # Channel progression of the downsampling stack; layers are created
        # in this order so parameter names and initialisation order are fixed.
        widths = (img_channels, 32, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

        # The linear heads are sized for a 256 x 4 x 4 feature map.
        flat_features = 256 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, z_dim)
        self.fc_logvar = nn.Linear(flat_features, z_dim)

    def forward(self, x: torch.Tensor):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        features = self.net(x)
        flat = torch.flatten(features, start_dim=1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return mu, logvar
class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to a 256-channel 4x4 feature
    map. Four stride-2 transposed convolutions (kernel 4, padding 1) then
    double the spatial size each time, 4 -> 8 -> 16 -> 32 -> 64. A final
    sigmoid keeps the output in the range [0, 1].

    Args:
        z_dim: Size of the latent vector.
        img_channels: Number of channels in the generated image.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3):
        super().__init__()
        self.fc = nn.Linear(z_dim, 256 * 4 * 4)

        # Hidden upsampling stages, each followed by a ReLU.
        hidden = (256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(hidden[:-1], hidden[1:]):
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            layers.append(nn.ReLU(inplace=True))

        # Output stage: project to image channels and squash with a sigmoid.
        layers.append(
            nn.ConvTranspose2d(
                hidden[-1], img_channels, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode the latent batch ``z`` into an image batch."""
        expanded = self.fc(z)
        feature_map = expanded.view(-1, 256, 4, 4)
        return self.net(feature_map)
class VAE(nn.Module):
    """Variational autoencoder pairing ``ConvEncoder`` with ``ConvDecoder``.

    Args:
        z_dim: Size of the latent vector; also stored as ``self.z_dim``.
        img_channels: Number of image channels passed to both sub-networks.
    """

    def __init__(self, z_dim: int = 4, img_channels: int = 3):
        super().__init__()
        self.encoder = ConvEncoder(z_dim=z_dim, img_channels=img_channels)
        self.decoder = ConvDecoder(z_dim=z_dim, img_channels=img_channels)
        self.z_dim = z_dim

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample a latent in training mode; return ``mu`` unchanged otherwise."""
        if not self.training:
            # Outside training the latent is the posterior mean, with no noise.
            return mu
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def encode(self, x: torch.Tensor):
        """Return ``(z, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encoder(x)
        return self.reparameterize(mu, logvar), mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the decoder on the latent batch ``z``."""
        reconstruction = self.decoder(z)
        return reconstruction

    def forward(self, x: torch.Tensor):
        """Encode then decode ``x``; return ``(recon, mu, logvar)``."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar
def elbo_loss(
    recon: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
):
    """Compute the beta-weighted negative ELBO for a batch.

    The reconstruction term is the summed squared error between ``recon`` and
    ``x`` divided by the batch size ``x.size(0)``. The KL term is the
    closed-form divergence between a diagonal Gaussian with mean ``mu`` and
    log-variance ``logvar`` and a standard normal, summed over dimension 1
    and averaged over the batch.

    Args:
        recon: Reconstructed batch; must have exactly the same shape as ``x``.
        x: Target batch.
        mu: Posterior means, with latent features on dimension 1.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Weight applied to the KL term.

    Returns:
        A tuple ``(total, recon_loss, kl)`` of scalar tensors, where
        ``total = recon_loss + beta * kl``.

    Raises:
        ValueError: If ``recon`` and ``x`` have different shapes.
    """
    # F.mse_loss broadcasts mismatched shapes (emitting only a warning), which
    # would silently produce a wrong loss, e.g. for a 3-channel reconstruction
    # against a 1-channel target. Fail loudly instead.
    if recon.shape != x.shape:
        raise ValueError(
            "recon and x must have the same shape, got "
            f"{tuple(recon.shape)} and {tuple(x.shape)}"
        )

    recon_loss = F.mse_loss(recon, x, reduction="sum") / x.size(0)
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
    return recon_loss + beta * kl, recon_loss, kl