| import os |
|
|
| import numpy as np |
| import streamlit as st |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from torchvision import transforms as tr |
|
|
|
|
class ResidualBlock(nn.Module):
    """A residual unit of two reflection-padded 3x3 convolutions.

    Channel count and spatial size are preserved, so the input can be
    added back onto the transformed features (skip connection).
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # Identity plus the learned residual.
        return x + self.block(x)
|
|
|
|
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Architecture: 7x7 reflection-padded stem, two stride-2 downsampling
    convs (doubling channels), a trunk of residual blocks, two stride-2
    transposed convs (halving channels), and a 7x7 head with Tanh so the
    output lands in [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64, num_residual_blocks=9):
        super(Generator, self).__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, num_features, kernel_size=7, padding=0),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: two stride-2 convs, doubling channels each step.
        channels = num_features
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2

        # Residual trunk at the bottleneck resolution.
        layers += [ResidualBlock(channels) for _ in range(num_residual_blocks)]

        # Upsampling: two stride-2 transposed convs, halving channels each step.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2

        # After two halvings `channels` is back to num_features.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(num_features, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
|
|
|
class Discriminator(nn.Module):
    """PatchGAN discriminator: scores overlapping patches, not whole images.

    A stack of stride-2 4x4 convs (channels doubling, capped at 512) is
    followed by two stride-1 convs, the last projecting to a one-channel
    map of patch-level real/fake logits.
    """

    def __init__(self, in_channels=3, num_features=64, num_layers=3):
        super(Discriminator, self).__init__()
        # First layer: no normalization, per the original PatchGAN design here.
        layers = [
            nn.Conv2d(in_channels, num_features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        channels = num_features
        next_channels = channels * 2
        for _ in range(1, num_layers):
            layers += [
                nn.Conv2d(channels, next_channels, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(next_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = next_channels
            next_channels = min(channels * 2, 512)

        # Penultimate stride-1 conv, then the 1-channel logit head.
        layers += [
            nn.Conv2d(channels, next_channels, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(next_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers.append(nn.Conv2d(next_channels, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
|
|
|
class CycleGAN(nn.Module):
    """Bundle of the two CycleGAN generators and their two discriminators.

    generators["a_to_b"] / generators["b_to_a"] translate between the two
    image domains; discriminators["a"] / discriminators["b"] judge realism
    in each domain. Both generators map in_channels -> in_channels.
    """

    def __init__(self, in_channels=3, num_features_g=64, num_residual_blocks=9,
                 num_features_d=64, num_layers_d=3):
        super(CycleGAN, self).__init__()

        def make_generator():
            return Generator(in_channels, in_channels, num_features_g, num_residual_blocks)

        def make_discriminator():
            return Discriminator(in_channels, num_features_d, num_layers_d)

        self.generators = nn.ModuleDict(
            {name: make_generator() for name in ("a_to_b", "b_to_a")}
        )
        self.discriminators = nn.ModuleDict(
            {name: make_discriminator() for name in ("a", "b")}
        )
|
|
|
|
@st.cache_resource
def load_model():
    """Load the CycleGAN checkpoint stored next to this script.

    Returns:
        (model, checkpoint): the CycleGAN model in eval mode on CPU, and
        the raw checkpoint dict (which also carries the per-domain
        normalization statistics used by pre/postprocessing).

    Raises:
        FileNotFoundError: if the weights file is missing.

    Cached via st.cache_resource so the checkpoint is read once per process.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(script_dir, "cyclegan_model_v2.pt")

    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model file not found: {weights_path}")

    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # you trust (this one ships alongside the app).
    checkpoint = torch.load(weights_path, map_location="cpu")
    model = CycleGAN(**checkpoint["model_params"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference only — disable dropout/norm training behavior
    return model, checkpoint
|
|
|
|
def preprocess(image, mean, std, size=256):
    """Turn a PIL image into a normalized 1xCxHxW float tensor.

    Resizes the shorter side to `size`, center-crops to size x size,
    converts to a tensor, normalizes with the given per-channel mean/std,
    and prepends a batch dimension.
    """
    steps = [
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ]
    tensor = tr.Compose(steps)(image)
    return tensor.unsqueeze(0)  # add batch axis
|
|
|
|
def postprocess(tensor, mean, std):
    """Convert a normalized 1xCxHxW model output back into a PIL image.

    Undoes the per-channel mean/std normalization, reorders to HWC,
    scales to 0-255, and clamps before the uint8 conversion.
    """
    chw = tensor.squeeze(0).detach().cpu()
    mean_t = torch.tensor(mean).view(3, 1, 1)
    std_t = torch.tensor(std).view(3, 1, 1)
    denormalized = chw * std_t + mean_t
    hwc = denormalized.permute(1, 2, 0).numpy()
    pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
# --- Page chrome -----------------------------------------------------------
st.set_page_config(page_title="CycleGAN: Summer <-> Winter", layout="wide")
st.title("CycleGAN: Summer ↔ Winter (Yosemite)")
st.markdown("Upload an image and transform it between summer and winter styles!")

# Persist the last generated result across Streamlit reruns.
for _key in ("result_image", "result_label"):
    if _key not in st.session_state:
        st.session_state[_key] = None

# --- Controls --------------------------------------------------------------
direction = st.selectbox(
    "Choose transformation direction:",
    ["Summer → Winter (A→B)", "Winter → Summer (B→A)"],
)

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
generate_clicked = st.button("Generate")
|
|
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original")
        st.image(image, width=512)

    if generate_clicked:
        with st.spinner("Loading model..."):
            model, checkpoint = load_model()

        # Per-domain normalization statistics saved with the checkpoint.
        mean_a = checkpoint["channel_mean_a"]
        std_a = checkpoint["channel_std_a"]
        mean_b = checkpoint["channel_mean_b"]
        std_b = checkpoint["channel_std_b"]

        # Select generator and input/output statistics for the chosen direction.
        if "A→B" in direction:
            gen_key = "a_to_b"
            src_stats, dst_stats = (mean_a, std_a), (mean_b, std_b)
            label = "Winter (Generated)"
        else:
            gen_key = "b_to_a"
            src_stats, dst_stats = (mean_b, std_b), (mean_a, std_a)
            label = "Summer (Generated)"

        with torch.no_grad():
            input_tensor = preprocess(image, *src_stats)
            output_tensor = model.generators[gen_key](input_tensor)
            result = postprocess(output_tensor, *dst_stats)

        # Stash the result so it survives Streamlit's rerun on interaction.
        st.session_state.result_image = result
        st.session_state.result_label = label

    if st.session_state.result_image is not None:
        with col2:
            st.subheader(st.session_state.result_label)
            st.image(st.session_state.result_image, width=512)

st.markdown("---")
st.markdown("Built with CycleGAN (Zhu et al., 2017). Dataset: summer2winter_yosemite.")
|
|