riciii7 commited on
Commit
d63970f
·
1 Parent(s): 4c508c2

Upload 7 files

Browse files
Files changed (7) hide show
  1. .dockerignore +7 -0
  2. .gitignore +5 -0
  3. Dockerfile +20 -0
  4. main.py +15 -0
  5. model.py +368 -0
  6. requirements.txt +30 -0
  7. utils.py +26 -0
.dockerignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ env/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.DS_Store
7
+ .vscode/
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ .vscode/
3
+ *.pyc
4
+ env/
5
+ *.pt
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# BUG FIX: system packages must be installed as root, BEFORE switching to
# the unprivileged user — the original ran apt-get after `USER user` and
# would fail with "Permission denied".
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

# Run everything from here on as an unprivileged user.
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

COPY --chown=user . .
RUN pip install --no-cache-dir -r requirements.txt

# BUG FIX: the FastAPI application object lives in main.py, so the ASGI
# target is main:app (the original "app:app" would crash at startup).
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import StreamingResponse
3
+ from utils import load_model, generate_image
4
+
5
+ app = FastAPI()
6
+ model = load_model()
7
+
8
+ @app.get("/generate")
9
+ def generate():
10
+ image_stream = generate_image(model, steps=5, alpha=1.0)
11
+ return StreamingResponse(image_stream, media_type="image/png")
12
+
13
+ @app.get("/ping")
14
+ def ping():
15
+ return {"status": "pong"}
model.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, optim
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from typing import Any, Callable, Optional
5
+ import math
6
+
7
class WSLinear(nn.Module):
    '''
    Weight-scaled linear layer for equalized learning rate.

    Weights are kept unit-normal and the input is multiplied by a
    He-style constant at call time; the bias is detached from the
    underlying nn.Linear and added after the scaled projection.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
    '''

    def __init__(self, in_features: int, out_features: int) -> None:
        super(WSLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.linear = nn.Linear(self.in_features, self.out_features)
        self.scale = math.sqrt(2 / self.in_features)
        # Keep the bias outside the projection so it is not scaled.
        self.bias = self.linear.bias
        self.linear.bias = None

        self._init_weights()

    def _init_weights(self) -> None:
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.scale
        return self.linear(scaled) + self.bias
34
+
35
class WSConv2d(nn.Module):
    """
    Weight-scaled Conv2d layer for equalized learning rate.

    Instead of rescaling the weights, the input is multiplied by a
    He-initialization constant at call time and the bias is applied
    after the convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default: 3.
        stride (int, optional): Stride of the convolution. Default: 1.
        padding (int, optional): Padding added to all sides of the input. Default: 1.
        gain (float, optional): Gain factor for weight initialization. Default: 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Detach the bias so it is added after the scaled convolution.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        out = self.conv(x * self.scale)
        return out + self.bias.view(1, -1, 1, 1)
60
+
61
class Mapping(nn.Module):
    '''
    Mapping network: an MLP that maps a latent vector z to an
    intermediate latent w of the same dimensionality.

    Args:
        features (int): Number of features in the input and output.
        num_styles (int): Number of styles to generate.
        num_layers (int): Number of layers in the feed forward network.
    '''

    def __init__(
        self,
        features: int,
        num_styles: int,
        num_layers: int = 8,
    ) -> None:
        super(Mapping, self).__init__()
        self.features = features
        self.num_layers = num_layers
        self.num_styles = num_styles

        # Each stage is an equalized linear followed by a LeakyReLU;
        # modules are created in the same order they are applied.
        self.fc = nn.Sequential(*[
            module
            for _ in range(self.num_layers)
            for module in (WSLinear(self.features, self.features), nn.LeakyReLU(0.2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, l).

        Returns:
            torch.Tensor: Output tensor with the same shape as input.
        '''

        return self.fc(x)  # (b, l)
100
+
101
class AdaIN(nn.Module):
    '''
    Adaptive Instance Normalization (AdaIN)
    AdaIN(x_i, y) = y_s,i * (x_i - mean(x_i)) / std(x_i) + y_b,i

    Args:
        eps (float, optional): Small value to avoid division by zero. Default value is 0.00001.
    '''

    def __init__(self, eps: float = 1e-5) -> None:
        super(AdaIN, self).__init__()
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        scale: torch.Tensor,
        shift: torch.Tensor
    ) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            scale (torch.Tensor): Scale tensor of shape (b, c).
            shift (torch.Tensor): Shift tensor of shape (b, c).

        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, w).
        '''

        b, c, *_ = x.shape

        # Normalize each (sample, channel) feature map independently.
        mean = x.mean(dim=(2, 3), keepdim=True)  # (b, c, 1, 1)
        std = x.std(dim=(2, 3), keepdim=True)    # (b, c, 1, 1)
        x_norm = (x - mean) / (std ** 2 + self.eps) ** .5

        scale = scale.view(b, c, 1, 1)  # (b, c, 1, 1)
        # BUG FIX: the shift was previously computed as scale.view(...),
        # which discarded the `shift` argument and reused the gain as the
        # bias; reshape the actual shift tensor.
        shift = shift.view(b, c, 1, 1)  # (b, c, 1, 1)
        outputs = scale * x_norm + shift  # (b, c, h, w)

        return outputs
141
+
142
class SynthesisLayer(nn.Module):
    '''
    Synthesis network layer which consist of:
        - Conv2d.
        - AdaIN.
        - Affine transformation.
        - Noise injection.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        latent_features (int): The number of latent features.
        use_conv (bool, optional): Whether to use convolution or not. Default value is True.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_features: int,
        use_conv: bool = True
    ) -> None:
        super(SynthesisLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_features = latent_features
        self.use_conv = use_conv

        # NOTE(review): when use_conv is False the layer effectively assumes
        # in_channels == out_channels, since the noise/affine parameters are
        # sized by out_channels while x keeps its input channel count.
        self.conv = nn.Sequential(
            WSConv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2)
        ) if self.use_conv else nn.Identity()
        self.norm = AdaIN()
        self.scale_transform = WSLinear(self.latent_features, self.out_channels)
        self.shift_transform = WSLinear(self.latent_features, self.out_channels)
        self.noise_factor = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1))

        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Style scale bias starts at 1 so AdaIN initially acts as identity gain.
        nn.init.ones_(self.scale_transform.bias)

    def forward(
        self,
        x: torch.Tensor,
        w: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            w (torch.Tensor): Latent space vector of shape (b, l).
            noise (torch.Tensor, optional): Noise tensor of shape (b, 1, h, w). Default value is None.

        Returns:
            torch.Tensor: Output tensor of shape (b, o_c, h, w).
        '''

        b, _, h, w_ = x.shape
        x = self.conv(x)  # (b, o_c, h, w)
        if noise is None:
            noise = torch.randn(b, 1, h, w_, device=x.device)  # (b, 1, h, w)
        # BUG FIX: use an out-of-place add. The previous in-place `x += ...`
        # mutated the caller's tensor whenever use_conv=False (self.conv is
        # nn.Identity, so x still aliases the input at this point).
        x = x + self.noise_factor * noise  # (b, o_c, h, w)
        y_s = self.scale_transform(w)  # (b, o_c)
        y_b = self.shift_transform(w)  # (b, o_c)
        x = self.norm(x, y_s, y_b)  # (b, o_c, h, w)

        return x
215
+
216
+
217
class SynthesisBlock(nn.Module):
    '''
    Synthesis network block which consist of:
        - Optional upsampling.
        - 2 Synthesis Layers.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        latent_features (int): The number of latent features.
        use_conv (bool, optional): Whether to use convolution or not. Default value is True.
        upsample (bool, optional): Whether to use upsampling or not. Default value is True.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_features: int,
        *,
        use_conv: bool = True,
        upsample: bool = True
    ) -> None:
        super(SynthesisBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_features = latent_features
        self.use_conv = use_conv

        # 2x bilinear upsampling, or a pass-through when disabled.
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.upsample = nn.Identity()

        # First layer keeps the channel count, the second maps to out_channels.
        self.layers = nn.ModuleList([
            SynthesisLayer(self.in_channels, self.in_channels, self.latent_features, use_conv=self.use_conv),
            SynthesisLayer(self.in_channels, self.out_channels, self.latent_features)
        ])

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, w).
            w (torch.Tensor): Latent vector of shape (b, l).

        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, w) if not upsample else (b, c, 2h, 2w).
        '''

        out = self.upsample(x)
        for layer in self.layers:
            out = layer(out, w)
        return out
269
+
270
class Synthesis(nn.Module):
    '''
    Synthesis network which consist of:
        - Constant tensor.
        - Synthesis blocks.
        - ToRGB convolutions.

    Args:
        resolution (int): The resolution of the image.
        const_channels (int): The number of channels in the constant tensor. Default value is 512.
    '''

    def __init__(self, resolution: int, const_channels: int = 512) -> None:
        super(Synthesis, self).__init__()
        self.const_channels = const_channels
        self.resolution = resolution

        # Progressive-growing levels: 4x4 is level 0 and every further
        # level doubles the spatial resolution up to `resolution`.
        self.resolution_levels = int(math.log2(resolution) - 1)

        # Learned constant input, broadcast over the batch in forward().
        self.constant = nn.Parameter(torch.ones(1, self.const_channels, 4, 4))  # (1, c, 4, 4)

        channels = self.const_channels
        blocks = [SynthesisBlock(channels, channels, self.const_channels, use_conv=False, upsample=False)]
        to_rgb = [WSConv2d(channels, 3, kernel_size=1, padding=0)]

        # Each subsequent block upsamples 2x and halves the channel count.
        for _ in range(self.resolution_levels - 1):
            blocks.append(SynthesisBlock(channels, channels // 2, self.const_channels))
            to_rgb.append(WSConv2d(channels // 2, 3, kernel_size=1, padding=0))
            channels //= 2

        self.blocks = nn.ModuleList(blocks)
        self.to_rgb = nn.ModuleList(to_rgb)

    def forward(self, w: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
        '''
        Args:
            w (torch.Tensor): Latent space vector of shape (b, l).
            alpha (float): Fade in alpha value.
            steps (int): The number of steps starting from 0.

        Returns:
            torch.Tensor: Output tensor of shape (b, 3, h, w).
        '''

        batch = w.size(0)
        x = self.constant.expand(batch, -1, -1, -1).clone()  # (b, c, 4, 4)

        if steps == 0:
            # Base resolution: a single block straight to RGB, no fade-in.
            return self.to_rgb[0](self.blocks[0](x, w))

        for i in range(steps):
            x = self.blocks[i](x, w)

        # Fade-in: blend the previous level's RGB (upsampled to the new
        # size) with the current level's RGB, weighted by alpha.
        old_rgb = self.to_rgb[steps - 1](x)  # (b, 3, h/2, w/2)
        x = self.blocks[steps](x, w)
        new_rgb = self.to_rgb[steps](x)  # (b, 3, h, w)
        old_rgb = F.interpolate(old_rgb, scale_factor=2, mode='bilinear', align_corners=False)  # (b, 3, h, w)

        return (1 - alpha) * old_rgb + alpha * new_rgb  # (b, 3, h, w)
334
+
335
class StyleGAN(nn.Module):
    '''
    StyleGAN implementation.

    Args:
        num_features (int): The number of features in the latent space vector.
        resolution (int): The resolution of the image.
        num_blocks (int, optional): The number of blocks in the synthesis network. Default value is 10.
    '''

    def __init__(self, num_features: int, resolution: int, num_blocks: int = 10):
        super(StyleGAN, self).__init__()
        self.num_features = num_features
        self.resolution = resolution
        self.num_blocks = num_blocks

        self.mapping = Mapping(self.num_features, self.num_blocks)
        self.synthesis = Synthesis(self.resolution, self.num_features)

    def forward(self, x: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
        '''
        Args:
            x (torch.Tensor): Random input tensor of shape (b, l).
            alpha (float): Fade in alpha value.
            steps (int): The number of steps starting from 0.

        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, w).
        '''

        # z -> w through the mapping MLP, then w drives the synthesis net.
        latent = self.mapping(x)  # (b, l)
        return self.synthesis(latent, alpha, steps)  # (b, c, h, w)
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML/DL
2
+ torch==2.6.0
3
+ torchvision==0.21.0
4
+ triton==3.2.0
5
+
6
+ # FastAPI & Server
7
+ fastapi==0.115.12
8
+ uvicorn==0.34.0
9
+
10
+ # Scientific stack
11
+ numpy==2.2.4
12
+ pillow==11.1.0
13
+ sympy==1.13.1
14
+ networkx==3.4.2
15
+ fsspec==2025.3.2
16
+
17
+ # Typing & Pydantic
18
+ pydantic==2.11.2
19
+ pydantic_core==2.33.1
20
+ typing_extensions==4.13.1
21
+ typing-inspection==0.4.0
22
+
23
+ # Async tools (used by FastAPI)
24
+ anyio==4.9.0
25
+ sniffio==1.3.1
26
+ h11==0.14.0
27
+ click==8.1.8
28
+ Jinja2==3.1.6
29
+ MarkupSafe==3.0.2
30
+ idna==3.10
utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import StyleGAN
2
+ import torch
3
+ from io import BytesIO
4
+ from torchvision.utils import save_image
5
+
6
# Dimensionality of the latent vector fed to the generator.
LATENT_FEATURES = 512
# Output image resolution — presumably matches the model_128.pt checkpoint;
# TODO confirm against the training configuration.
RESOLUTION = 128

# Prefer GPU when available; model and inputs are created on this device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
def load_model(path='model_128.pt'):
    """Build the StyleGAN generator and restore its weights from *path*.

    The checkpoint is expected to hold the generator state_dict under the
    'generator' key. Loading uses strict=False, so missing/unexpected keys
    are silently ignored. The model is returned in eval mode on DEVICE.
    """
    checkpoint = torch.load(path, map_location=DEVICE)
    net = StyleGAN(LATENT_FEATURES, RESOLUTION).to(DEVICE)
    net.load_state_dict(checkpoint['generator'], strict=False)
    net.eval()
    return net
16
+
17
def generate_image(generator, steps=5, alpha=1.0):
    """Sample one image from the generator and return it as a PNG buffer.

    Args:
        generator: The StyleGAN model returned by `load_model`.
        steps (int): Progressive-growing step to synthesize at.
        alpha (float): Fade-in blend factor forwarded to the generator.

    Returns:
        BytesIO: PNG-encoded image, rewound to the start of the stream.
    """
    with torch.no_grad():
        # BUG FIX: forward the `alpha` argument — it was previously
        # hard-coded to 1.0, silently ignoring the caller's value.
        image = generator(torch.randn(1, LATENT_FEATURES, device=DEVICE), alpha=alpha, steps=steps)
        image = image.tanh()      # squash raw outputs into [-1, 1]
        image = (image + 1) / 2   # rescale to [0, 1] for save_image

    buffer = BytesIO()
    save_image(image, buffer, format='PNG')
    buffer.seek(0)
    return buffer