lolzysiu commited on
Commit
90dd8cd
·
verified ·
1 Parent(s): afa49dd

Create models_conv.py

Browse files
Files changed (1) hide show
  1. models_conv.py +62 -0
models_conv.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ConvGenerator(nn.Module):
5
+ def __init__(self, latent_dim=100, channels=1):
6
+ super(ConvGenerator, self).__init__()
7
+ self.latent_dim = latent_dim
8
+
9
+ self.init_size = 7 # Initial size before upsampling
10
+ self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
11
+
12
+ self.conv_blocks = nn.Sequential(
13
+ nn.BatchNorm2d(128),
14
+ nn.Upsample(scale_factor=2),
15
+ nn.Conv2d(128, 128, 3, stride=1, padding=1),
16
+ nn.BatchNorm2d(128, 0.8),
17
+ nn.LeakyReLU(0.2, inplace=True),
18
+ nn.Upsample(scale_factor=2),
19
+ nn.Conv2d(128, 64, 3, stride=1, padding=1),
20
+ nn.BatchNorm2d(64, 0.8),
21
+ nn.LeakyReLU(0.2, inplace=True),
22
+ nn.Conv2d(64, channels, 3, stride=1, padding=1),
23
+ nn.Tanh()
24
+ )
25
+
26
+ def forward(self, z):
27
+ out = self.l1(z)
28
+ out = out.view(out.shape[0], 128, self.init_size, self.init_size)
29
+ img = self.conv_blocks(out)
30
+ return img
31
+
32
+ class ConvDiscriminator(nn.Module):
33
+ def __init__(self, channels=1):
34
+ super(ConvDiscriminator, self).__init__()
35
+
36
+ def discriminator_block(in_filters, out_filters, bn=True):
37
+ block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
38
+ nn.LeakyReLU(0.2, inplace=True),
39
+ nn.Dropout2d(0.25)]
40
+ if bn:
41
+ block.append(nn.BatchNorm2d(out_filters, 0.8))
42
+ return block
43
+
44
+ self.model = nn.Sequential(
45
+ *discriminator_block(channels, 16, bn=False),
46
+ *discriminator_block(16, 32),
47
+ *discriminator_block(32, 64),
48
+ *discriminator_block(64, 128),
49
+ )
50
+
51
+ # The height and width of downsampled image
52
+ ds_size = 28 // 2**4
53
+ self.adv_layer = nn.Sequential(
54
+ nn.Linear(128 * ds_size ** 2, 1),
55
+ nn.Sigmoid()
56
+ )
57
+
58
+ def forward(self, img):
59
+ out = self.model(img)
60
+ out = out.view(out.shape[0], -1)
61
+ validity = self.adv_layer(out)
62
+ return validity