# Cycle-GAN / src / app.py
# Sckwoky's picture
# Updated model loading to use new checkpoint file name for CycleGAN model
# 21a8940
import os
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms as tr
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip path.

    Each convolution is followed by instance normalisation, with a ReLU
    between the two halves. Reflection padding of 1 combined with an unpadded
    3x3 kernel keeps the spatial size, so the branch output can be added to
    the input element-wise.
    """

    def __init__(self, channels):
        super().__init__()

        def padded_conv():
            # pad -> conv -> norm. The reflection layer does the padding,
            # hence padding=0 on the convolution itself.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3, padding=0),
                nn.InstanceNorm2d(channels),
            ]

        # Attribute name and layer order are part of the checkpoint format:
        # parameters are stored as "block.1.*" and "block.5.*".
        self.block = nn.Sequential(
            *padded_conv(),
            nn.ReLU(inplace=True),
            *padded_conv(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image network.

    Layer layout, in order:
      1. reflection-padded 7x7 convolution to ``num_features`` channels,
      2. two stride-2 3x3 convolutions, each doubling the channel count,
      3. ``num_residual_blocks`` residual blocks at that width,
      4. two stride-2 transposed convolutions, each halving the channel count,
      5. reflection-padded 7x7 convolution to ``out_channels`` followed by tanh.

    Every convolution except the last is followed by instance normalisation
    and ReLU. The final tanh bounds the output to [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64, num_residual_blocks=9):
        super().__init__()

        # Stem: the reflection layer pads by 3 so the unpadded 7x7 kernel
        # keeps the spatial size.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, num_features, kernel_size=7, padding=0),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: halve resolution, double channels (twice).
        width = num_features
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Residual trunk at the bottleneck width.
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Upsampling: double resolution, halve channels (twice).
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, narrower, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Output head; its input width is num_features, as in the stem output.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(num_features, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        # The attribute name and the flat layer order determine the
        # state_dict keys ("model.<index>.*"), so both must stay as they are.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.model(x)
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a one-channel score map.

    Layer layout, in order:
      1. stride-2 4x4 convolution to ``num_features`` channels + LeakyReLU,
      2. ``num_layers - 1`` stride-2 4x4 convolutions with instance norm and
         LeakyReLU,
      3. one stride-1 4x4 convolution with instance norm and LeakyReLU,
      4. a stride-1 4x4 convolution down to a single output channel.

    The first widening is ``num_features * 2``; every later widening is capped
    at 512 channels.
    """

    def __init__(self, in_channels=3, num_features=64, num_layers=3):
        super().__init__()

        # No normalisation on the very first convolution.
        layers = [
            nn.Conv2d(in_channels, num_features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        width = num_features
        next_width = width * 2
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Conv2d(width, next_width, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(next_width),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width = next_width
            next_width = min(width * 2, 512)

        # Same building block as above, but without further downsampling.
        layers.extend([
            nn.Conv2d(width, next_width, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(next_width),
            nn.LeakyReLU(0.2, inplace=True),
        ])

        # Projection to one score per output location.
        layers.append(nn.Conv2d(next_width, 1, kernel_size=4, stride=1, padding=1))

        # Attribute name and layer order define the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced for ``x``."""
        return self.model(x)
class CycleGAN(nn.Module):
    """Container bundling the two generators and two discriminators.

    The sub-networks are reachable through ``self.generators`` (keys
    ``"a_to_b"`` and ``"b_to_a"``) and ``self.discriminators`` (keys ``"a"``
    and ``"b"``). No ``forward`` is defined here; callers invoke the
    sub-network they need directly.
    """

    def __init__(self, in_channels=3, num_features_g=64, num_residual_blocks=9,
                 num_features_d=64, num_layers_d=3):
        super().__init__()

        # Both generators map in_channels -> in_channels. Key names and the
        # attribute names are part of the checkpoint's state_dict layout.
        self.generators = nn.ModuleDict({
            name: Generator(in_channels, in_channels, num_features_g, num_residual_blocks)
            for name in ("a_to_b", "b_to_a")
        })
        self.discriminators = nn.ModuleDict({
            name: Discriminator(in_channels, num_features_d, num_layers_d)
            for name in ("a", "b")
        })
@st.cache_resource
def load_model():
    """Load the CycleGAN checkpoint that sits next to this file.

    The checkpoint is read onto the CPU. Its ``"model_params"`` entry is
    passed as keyword arguments to ``CycleGAN`` and its
    ``"model_state_dict"`` entry is loaded into the resulting model, which is
    then switched to eval mode.

    Returns:
        A ``(model, checkpoint)`` tuple; ``checkpoint`` is the full object
        returned by ``torch.load``.

    Raises:
        FileNotFoundError: if ``cyclegan_model_v2.pt`` does not exist in the
            directory containing this file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(here, "cyclegan_model_v2.pt")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model file not found: {weights_path}")

    # NOTE(review): torch.load unpickles the file, so it must only ever point
    # at a trusted checkpoint. The default of `weights_only` differs between
    # PyTorch releases -- confirm which types the checkpoint contains before
    # pinning it explicitly.
    # NOTE(review): the returned checkpoint still holds "model_state_dict",
    # so the weights stay in memory twice while this result is cached.
    checkpoint = torch.load(weights_path, map_location="cpu")

    net = CycleGAN(**checkpoint["model_params"])
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net, checkpoint
def preprocess(image, mean, std, size=256):
    """Turn an image into a normalised, batched tensor for the generator.

    Steps: resize with ``tr.Resize(size)``, centre-crop to ``size`` x
    ``size``, convert to a tensor, normalise with the given per-channel
    ``mean`` / ``std``, then add a leading batch axis.

    Args:
        image: input image accepted by the torchvision transforms
            (the caller in this file passes an RGB PIL image).
        mean: per-channel mean handed to ``tr.Normalize``.
        std: per-channel standard deviation handed to ``tr.Normalize``.
        size: target side length in pixels.

    Returns:
        The transformed image with an extra dimension of size 1 in front.
    """
    steps = [
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ]
    pipeline = tr.Compose(steps)
    normalized = pipeline(image)
    # The generators are called with a batch, so add a batch axis of size 1.
    return normalized.unsqueeze(0)
def postprocess(tensor, mean, std):
    """Convert a normalised generator output back into a PIL image.

    Undoes the per-channel normalisation (``x * std + mean``), scales to
    0..255, clips, and converts to ``uint8``.

    Compared with the previous version, the statistics are no longer assumed
    to describe exactly three channels, and they may be passed as lists,
    NumPy arrays or tensors (tensors are reused instead of being
    copy-constructed, which avoids PyTorch's copy-construct warning).

    Args:
        tensor: generator output; a leading batch axis of size 1 is removed.
        mean: per-channel mean used when the target domain was normalised.
        std: per-channel standard deviation used for that normalisation.

    Returns:
        A ``PIL.Image.Image`` built from the de-normalised pixels.
    """
    img = tensor.squeeze(0).detach().cpu()

    # reshape(-1, 1, 1) broadcasts one value per channel over (C, H, W)
    # without hard-coding C. detach()/to() keep the arithmetic free of
    # autograd state and on the same device as `img`.
    mean_t = torch.as_tensor(mean).detach().to(img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std).detach().to(img.device).reshape(-1, 1, 1)

    img = img * std_t + mean_t

    # (C, H, W) -> (H, W, C), the layout PIL expects.
    pixels = img.permute(1, 2, 0).numpy()
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)

    # PIL wants a 2-D array for single-channel images, not (H, W, 1).
    if pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)
# ---------------------------------------------------------------------------
# Streamlit page: upload an image, choose a direction, run the generator.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="CycleGAN: Summer <-> Winter", layout="wide")
st.title("CycleGAN: Summer ↔ Winter (Yosemite)")
st.markdown("Upload an image and transform it between summer and winter styles!")

# The generated image lives in session state so it is still shown on reruns
# that were not triggered by the "Generate" button. "result_source" records
# which upload produced it, so a result is never displayed next to a
# different image.
for state_key in ("result_image", "result_label", "result_source"):
    if state_key not in st.session_state:
        st.session_state[state_key] = None

direction = st.selectbox(
    "Choose transformation direction:",
    ["Summer → Winter (A→B)", "Winter → Summer (B→A)"],
)
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
generate_clicked = st.button("Generate")

image = None
source_key = None
if uploaded_file is not None:
    source_key = (uploaded_file.name, uploaded_file.size)
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, ValueError) as exc:
        # Corrupt or mislabeled uploads: report the problem instead of
        # crashing the page with a traceback.
        st.error(f"Could not read the uploaded image: {exc}")

# Discard a result that belongs to another upload (or to one that has since
# been removed); otherwise it would reappear beside the next image.
if st.session_state.result_source != source_key:
    st.session_state.result_image = None
    st.session_state.result_label = None
    st.session_state.result_source = None

if image is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original")
        st.image(image, width=512)

    if generate_clicked:
        model = checkpoint = None
        try:
            with st.spinner("Loading model..."):
                model, checkpoint = load_model()
        except FileNotFoundError as exc:
            # Missing checkpoint is a deployment problem; show it as a
            # readable message and keep the rest of the page usable.
            st.error(str(exc))

        if model is not None:
            stats_a = (checkpoint["channel_mean_a"], checkpoint["channel_std_a"])
            stats_b = (checkpoint["channel_mean_b"], checkpoint["channel_std_b"])

            # Normalise with the source domain's statistics and de-normalise
            # with the target domain's.
            if "A→B" in direction:
                generator_name = "a_to_b"
                source_stats, target_stats = stats_a, stats_b
                label = "Winter (Generated)"
            else:
                generator_name = "b_to_a"
                source_stats, target_stats = stats_b, stats_a
                label = "Summer (Generated)"

            with st.spinner("Generating image..."), torch.no_grad():
                input_tensor = preprocess(image, *source_stats)
                output_tensor = model.generators[generator_name](input_tensor)
                result = postprocess(output_tensor, *target_stats)

            st.session_state.result_image = result
            st.session_state.result_label = label
            st.session_state.result_source = source_key

    if st.session_state.result_image is not None:
        with col2:
            st.subheader(st.session_state.result_label)
            st.image(st.session_state.result_image, width=512)

st.markdown("---")
st.markdown("Built with CycleGAN (Zhu et al., 2017). Dataset: summer2winter_yosemite.")