Sckwoky commited on
Commit
ae71a8e
·
1 Parent(s): bc2f8f0

Added initial project structure with Streamlit app and model definitions

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. Dockerfile +1 -1
  3. requirements.txt +6 -1
  4. src/app.py +187 -0
  5. src/cyclegan_model.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ .idea/
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ torch
5
+ torchvision
6
+ streamlit
7
+ numpy
8
+ Pillow
src/app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from PIL import Image
7
+ from torchvision import transforms as tr
8
+ import io
9
+
10
+
11
+ # ===== Model definitions =====
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, channels):
14
+ super(ResidualBlock, self).__init__()
15
+ self.block = nn.Sequential(
16
+ nn.ReflectionPad2d(1),
17
+ nn.Conv2d(channels, channels, kernel_size=3, padding=0),
18
+ nn.InstanceNorm2d(channels),
19
+ nn.ReLU(inplace=True),
20
+ nn.ReflectionPad2d(1),
21
+ nn.Conv2d(channels, channels, kernel_size=3, padding=0),
22
+ nn.InstanceNorm2d(channels),
23
+ )
24
+
25
+ def forward(self, x):
26
+ return x + self.block(x)
27
+
28
+
29
+ class Generator(nn.Module):
30
+ def __init__(self, in_channels=3, out_channels=3, num_features=64, num_residual_blocks=9):
31
+ super(Generator, self).__init__()
32
+ model = [
33
+ nn.ReflectionPad2d(3),
34
+ nn.Conv2d(in_channels, num_features, kernel_size=7, padding=0),
35
+ nn.InstanceNorm2d(num_features),
36
+ nn.ReLU(inplace=True),
37
+ ]
38
+ in_f = num_features
39
+ out_f = in_f * 2
40
+ for _ in range(2):
41
+ model += [
42
+ nn.Conv2d(in_f, out_f, kernel_size=3, stride=2, padding=1),
43
+ nn.InstanceNorm2d(out_f),
44
+ nn.ReLU(inplace=True),
45
+ ]
46
+ in_f = out_f
47
+ out_f = in_f * 2
48
+ for _ in range(num_residual_blocks):
49
+ model += [ResidualBlock(in_f)]
50
+ out_f = in_f // 2
51
+ for _ in range(2):
52
+ model += [
53
+ nn.ConvTranspose2d(in_f, out_f, kernel_size=3, stride=2, padding=1, output_padding=1),
54
+ nn.InstanceNorm2d(out_f),
55
+ nn.ReLU(inplace=True),
56
+ ]
57
+ in_f = out_f
58
+ out_f = in_f // 2
59
+ model += [
60
+ nn.ReflectionPad2d(3),
61
+ nn.Conv2d(num_features, out_channels, kernel_size=7, padding=0),
62
+ nn.Tanh(),
63
+ ]
64
+ self.model = nn.Sequential(*model)
65
+
66
+ def forward(self, x):
67
+ return self.model(x)
68
+
69
+
70
+ class Discriminator(nn.Module):
71
+ def __init__(self, in_channels=3, num_features=64, num_layers=3):
72
+ super(Discriminator, self).__init__()
73
+ model = [
74
+ nn.Conv2d(in_channels, num_features, kernel_size=4, stride=2, padding=1),
75
+ nn.LeakyReLU(0.2, inplace=True),
76
+ ]
77
+ in_f = num_features
78
+ out_f = in_f * 2
79
+ for i in range(1, num_layers):
80
+ model += [
81
+ nn.Conv2d(in_f, out_f, kernel_size=4, stride=2, padding=1),
82
+ nn.InstanceNorm2d(out_f),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ ]
85
+ in_f = out_f
86
+ out_f = min(in_f * 2, 512)
87
+ model += [
88
+ nn.Conv2d(in_f, out_f, kernel_size=4, stride=1, padding=1),
89
+ nn.InstanceNorm2d(out_f),
90
+ nn.LeakyReLU(0.2, inplace=True),
91
+ ]
92
+ model += [nn.Conv2d(out_f, 1, kernel_size=4, stride=1, padding=1)]
93
+ self.model = nn.Sequential(*model)
94
+
95
+ def forward(self, x):
96
+ return self.model(x)
97
+
98
+
99
+ class CycleGAN(nn.Module):
100
+ def __init__(self, in_channels=3, num_features_g=64, num_residual_blocks=9,
101
+ num_features_d=64, num_layers_d=3):
102
+ super(CycleGAN, self).__init__()
103
+ self.generators = nn.ModuleDict({
104
+ "a_to_b": Generator(in_channels, in_channels, num_features_g, num_residual_blocks),
105
+ "b_to_a": Generator(in_channels, in_channels, num_features_g, num_residual_blocks),
106
+ })
107
+ self.discriminators = nn.ModuleDict({
108
+ "a": Discriminator(in_channels, num_features_d, num_layers_d),
109
+ "b": Discriminator(in_channels, num_features_d, num_layers_d),
110
+ })
111
+
112
+
113
+ # ===== Load model =====
114
+ @st.cache_resource
115
+ def load_model():
116
+ checkpoint = torch.load("cyclegan_model.pt", map_location="cpu")
117
+ model = CycleGAN(**checkpoint["model_params"])
118
+ model.load_state_dict(checkpoint["model_state_dict"])
119
+ model.eval()
120
+ return model, checkpoint
121
+
122
+
123
+ def preprocess(image, mean, std, size=256):
124
+ transform = tr.Compose([
125
+ tr.Resize(size),
126
+ tr.CenterCrop(size),
127
+ tr.ToTensor(),
128
+ tr.Normalize(mean=mean, std=std),
129
+ ])
130
+ return transform(image).unsqueeze(0)
131
+
132
+
133
+ def postprocess(tensor, mean, std):
134
+ img = tensor.squeeze(0).detach().cpu()
135
+ mean_t = torch.tensor(mean).view(3, 1, 1)
136
+ std_t = torch.tensor(std).view(3, 1, 1)
137
+ img = img * std_t + mean_t
138
+ img = img.permute(1, 2, 0).numpy()
139
+ img = np.clip(img * 255, 0, 255).astype(np.uint8)
140
+ return Image.fromarray(img)
141
+
142
+
143
+ # ===== Streamlit UI =====
144
+ st.set_page_config(page_title="CycleGAN: Summer <-> Winter", layout="wide")
145
+ st.title("🏔️ CycleGAN: Summer ↔ Winter (Yosemite)")
146
+ st.markdown("Upload an image and transform it between summer and winter styles!")
147
+
148
+ model, checkpoint = load_model()
149
+ mean_a = checkpoint["channel_mean_a"]
150
+ std_a = checkpoint["channel_std_a"]
151
+ mean_b = checkpoint["channel_mean_b"]
152
+ std_b = checkpoint["channel_std_b"]
153
+
154
+ direction = st.selectbox(
155
+ "Choose transformation direction:",
156
+ ["Summer → Winter (A→B)", "Winter → Summer (B→A)"]
157
+ )
158
+
159
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
160
+
161
+ if uploaded_file is not None:
162
+ image = Image.open(uploaded_file).convert("RGB")
163
+
164
+ col1, col2 = st.columns(2)
165
+
166
+ with col1:
167
+ st.subheader("Original")
168
+ st.image(image, use_container_width=True)
169
+
170
+ with torch.no_grad():
171
+ if "A→B" in direction:
172
+ input_tensor = preprocess(image, mean_a, std_a)
173
+ output_tensor = model.generators["a_to_b"](input_tensor)
174
+ result = postprocess(output_tensor, mean_b, std_b)
175
+ label = "Winter (Generated)"
176
+ else:
177
+ input_tensor = preprocess(image, mean_b, std_b)
178
+ output_tensor = model.generators["b_to_a"](input_tensor)
179
+ result = postprocess(output_tensor, mean_a, std_a)
180
+ label = "Summer (Generated)"
181
+
182
+ with col2:
183
+ st.subheader(label)
184
+ st.image(result, use_container_width=True)
185
+
186
+ st.markdown("---")
187
+ st.markdown("Built with CycleGAN (Zhu et al., 2017). Dataset: summer2winter_yosemite.")
src/cyclegan_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a47525227bfddb83260d8e9733e0da1016157bb87cb9cea51eb2e5d741af1de
3
+ size 84861181