LanXiaoPang613 commited on
Commit
9294feb
·
unverified ·
1 Parent(s): f9e6118

Add files via upload

Browse files

fix cnn network error

Files changed (1) hide show
  1. models/CNN.py +63 -1
models/CNN.py CHANGED
@@ -28,11 +28,73 @@ class VNet(nn.Module):
28
  x = self.output_layer(x)
29
  return torch.sigmoid(x)
30
 
 
 
31
 
32
  class CNN(nn.Module):
33
- def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25):
34
  self.dropout_rate = dropout_rate
 
35
  super(CNN, self).__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  #block1
38
  self.conv1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
 
28
  x = self.output_layer(x)
29
  return torch.sigmoid(x)
30
 
31
+ def call_bn(bn, x):
32
+ return bn(x)
33
 
34
  class CNN(nn.Module):
35
+ def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25, top_bn=False):
36
  self.dropout_rate = dropout_rate
37
+ self.top_bn = top_bn
38
  super(CNN, self).__init__()
39
+ self.c1=nn.Conv2d(input_channel,128,kernel_size=3,stride=1, padding=1)
40
+ self.c2=nn.Conv2d(128,128,kernel_size=3,stride=1, padding=1)
41
+ self.c3=nn.Conv2d(128,128,kernel_size=3,stride=1, padding=1)
42
+ self.c4=nn.Conv2d(128,256,kernel_size=3,stride=1, padding=1)
43
+ self.c5=nn.Conv2d(256,256,kernel_size=3,stride=1, padding=1)
44
+ self.c6=nn.Conv2d(256,256,kernel_size=3,stride=1, padding=1)
45
+ self.c7=nn.Conv2d(256,512,kernel_size=3,stride=1, padding=0)
46
+ self.c8=nn.Conv2d(512,256,kernel_size=3,stride=1, padding=0)
47
+ self.c9=nn.Conv2d(256,128,kernel_size=3,stride=1, padding=0)
48
+ self.l_c1=nn.Linear(128,n_outputs)
49
+ self.bn1=nn.BatchNorm2d(128)
50
+ self.bn2=nn.BatchNorm2d(128)
51
+ self.bn3=nn.BatchNorm2d(128)
52
+ self.bn4=nn.BatchNorm2d(256)
53
+ self.bn5=nn.BatchNorm2d(256)
54
+ self.bn6=nn.BatchNorm2d(256)
55
+ self.bn7=nn.BatchNorm2d(512)
56
+ self.bn8=nn.BatchNorm2d(256)
57
+ self.bn9=nn.BatchNorm2d(128)
58
+
59
+ def forward(self, x,):
60
+ h=x
61
+ h=self.c1(h)
62
+ h=F.leaky_relu(call_bn(self.bn1, h), negative_slope=0.01)
63
+ h=self.c2(h)
64
+ h=F.leaky_relu(call_bn(self.bn2, h), negative_slope=0.01)
65
+ h=self.c3(h)
66
+ h=F.leaky_relu(call_bn(self.bn3, h), negative_slope=0.01)
67
+ h=F.max_pool2d(h, kernel_size=2, stride=2)
68
+ h=F.dropout2d(h, p=self.dropout_rate)
69
+
70
+ h=self.c4(h)
71
+ h=F.leaky_relu(call_bn(self.bn4, h), negative_slope=0.01)
72
+ h=self.c5(h)
73
+ h=F.leaky_relu(call_bn(self.bn5, h), negative_slope=0.01)
74
+ h=self.c6(h)
75
+ h=F.leaky_relu(call_bn(self.bn6, h), negative_slope=0.01)
76
+ h=F.max_pool2d(h, kernel_size=2, stride=2)
77
+ h=F.dropout2d(h, p=self.dropout_rate)
78
+
79
+ h=self.c7(h)
80
+ h=F.leaky_relu(call_bn(self.bn7, h), negative_slope=0.01)
81
+ h=self.c8(h)
82
+ h=F.leaky_relu(call_bn(self.bn8, h), negative_slope=0.01)
83
+ h=self.c9(h)
84
+ h=F.leaky_relu(call_bn(self.bn9, h), negative_slope=0.01)
85
+ h=F.avg_pool2d(h, kernel_size=h.data.shape[2])
86
+
87
+ h = h.view(h.size(0), h.size(1))
88
+ logit=self.l_c1(h)
89
+ if self.top_bn:
90
+ logit=call_bn(self.bn_c1, logit)
91
+ return logit
92
+
93
+
94
+ class CNN_bak(nn.Module):
95
+ def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25):
96
+ self.dropout_rate = dropout_rate
97
+ super(CNN_bak, self).__init__()
98
 
99
  #block1
100
  self.conv1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)