"""Streamlit demo: CycleGAN translation between summer and winter Yosemite photos.

Loads a pre-trained checkpoint (generators + discriminators) and exposes a
simple upload -> transform -> display UI. Inference runs on CPU.
"""

import os

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms as tr


class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs + instance norm,
    summed with the input through an identity skip connection."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        # Identity skip: output = input + residual transform of input.
        return x + self.block(x)


class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Architecture: 7x7 reflection-padded stem -> two stride-2 downsampling
    convs (channels double each stage) -> `num_residual_blocks` residual
    blocks -> two transposed-conv upsampling stages (channels halve each
    stage) -> 7x7 output conv with Tanh, producing values in [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64, num_residual_blocks=9):
        super(Generator, self).__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, num_features, kernel_size=7, padding=0),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        ]

        # Downsample twice, doubling the channel count at each stage.
        ch_in, ch_out = num_features, num_features * 2
        for _ in range(2):
            layers += [
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
            ch_in, ch_out = ch_out, ch_out * 2

        # Transformation stage at the bottleneck resolution.
        layers += [ResidualBlock(ch_in) for _ in range(num_residual_blocks)]

        # Upsample twice, halving the channel count at each stage.
        ch_out = ch_in // 2
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(ch_in, ch_out, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
            ch_in, ch_out = ch_out, ch_out // 2

        # Output head maps back to image space; Tanh bounds output to [-1, 1].
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(num_features, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    """PatchGAN discriminator: a stack of stride-2 4x4 convs ending in a
    single-channel patch-level real/fake score map."""

    def __init__(self, in_channels=3, num_features=64, num_layers=3):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(in_channels, num_features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        ch_in, ch_out = num_features, num_features * 2
        for _ in range(1, num_layers):
            layers += [
                nn.Conv2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(ch_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            ch_in = ch_out
            ch_out = min(ch_in * 2, 512)  # cap channel growth at 512

        # Penultimate stride-1 conv, then 1-channel patch score map.
        layers += [
            nn.Conv2d(ch_in, ch_out, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(ch_out),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [nn.Conv2d(ch_out, 1, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class CycleGAN(nn.Module):
    """Container for the two generators (A<->B) and two discriminators
    that make up a CycleGAN. Only the generators are used at inference."""

    def __init__(self, in_channels=3, num_features_g=64, num_residual_blocks=9,
                 num_features_d=64, num_layers_d=3):
        super(CycleGAN, self).__init__()
        self.generators = nn.ModuleDict({
            "a_to_b": Generator(in_channels, in_channels, num_features_g, num_residual_blocks),
            "b_to_a": Generator(in_channels, in_channels, num_features_g, num_residual_blocks),
        })
        self.discriminators = nn.ModuleDict({
            "a": Discriminator(in_channels, num_features_d, num_layers_d),
            "b": Discriminator(in_channels, num_features_d, num_layers_d),
        })


@st.cache_resource
def load_model():
    """Load the CycleGAN checkpoint next to this file (cached by Streamlit).

    Returns:
        (model, checkpoint): the eval-mode CycleGAN and the raw checkpoint
        dict, which also carries per-domain normalization statistics.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    model_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(model_dir, "cyclegan_model_v2.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    checkpoint = torch.load(model_path, map_location="cpu")
    model = CycleGAN(**checkpoint["model_params"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint


def preprocess(image, mean, std, size=256):
    """Resize/crop a PIL image to `size` x `size`, normalize with the given
    per-channel stats, and return a (1, C, H, W) float tensor."""
    pipeline = tr.Compose([
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
    return pipeline(image).unsqueeze(0)


def postprocess(tensor, mean, std):
    """Invert `preprocess`: de-normalize a (1, C, H, W) tensor and convert
    it back to an 8-bit PIL image."""
    chw = tensor.squeeze(0).detach().cpu()
    mean_t = torch.tensor(mean).view(3, 1, 1)
    std_t = torch.tensor(std).view(3, 1, 1)
    chw = chw * std_t + mean_t  # undo Normalize
    hwc = chw.permute(1, 2, 0).numpy()
    pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


st.set_page_config(page_title="CycleGAN: Summer <-> Winter", layout="wide")
st.title("CycleGAN: Summer ↔ Winter (Yosemite)")
st.markdown("Upload an image and transform it between summer and winter styles!")

# Persist the last generated result across Streamlit reruns.
if "result_image" not in st.session_state:
    st.session_state.result_image = None
if "result_label" not in st.session_state:
    st.session_state.result_label = None

direction = st.selectbox(
    "Choose transformation direction:",
    ["Summer → Winter (A→B)", "Winter → Summer (B→A)"]
)
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
generate_clicked = st.button("Generate")

if uploaded_file is not None:
    source_image = Image.open(uploaded_file).convert("RGB")

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original")
        st.image(source_image, width=512)

    if generate_clicked:
        with st.spinner("Loading model..."):
            model, checkpoint = load_model()

        # Per-domain channel statistics saved alongside the weights.
        mean_a = checkpoint["channel_mean_a"]
        std_a = checkpoint["channel_std_a"]
        mean_b = checkpoint["channel_mean_b"]
        std_b = checkpoint["channel_std_b"]

        with torch.no_grad():
            if "A→B" in direction:
                # Summer -> winter: normalize with A stats, de-normalize with B.
                input_tensor = preprocess(source_image, mean_a, std_a)
                output_tensor = model.generators["a_to_b"](input_tensor)
                result = postprocess(output_tensor, mean_b, std_b)
                label = "Winter (Generated)"
            else:
                # Winter -> summer: normalize with B stats, de-normalize with A.
                input_tensor = preprocess(source_image, mean_b, std_b)
                output_tensor = model.generators["b_to_a"](input_tensor)
                result = postprocess(output_tensor, mean_a, std_a)
                label = "Summer (Generated)"

        st.session_state.result_image = result
        st.session_state.result_label = label

    if st.session_state.result_image is not None:
        with col2:
            st.subheader(st.session_state.result_label)
            st.image(st.session_state.result_image, width=512)

st.markdown("---")
st.markdown("Built with CycleGAN (Zhu et al., 2017). Dataset: summer2winter_yosemite.")