| import torch.nn as nn |
| import torch |
|
|
|
|
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block with an optional 1x1 projection shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where
    ``shortcut`` is the identity when input and output shapes already match,
    and a 1x1 convolution otherwise.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut. Defaults to 1, which preserves spatial size and
            reproduces the block's original behaviour.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # The identity can only be added directly when channels and spatial
        # size are unchanged; otherwise project the input with a 1x1 conv.
        # Kept as a bare conv (with bias, no BatchNorm) so parameter names and
        # shapes stay compatible with checkpoints saved from the original block.
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, in_channels, H, W)``.

        Returns:
            A tensor of shape ``(N, out_channels, H', W')`` with
            ``H' = (H - 1) // stride + 1`` (likewise for ``W'``), i.e. the
            spatial size is unchanged when ``stride == 1``.
        """
        # Explicit None check: do not rely on nn.Module truthiness.
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # In-place add is safe here: `out` is a fresh BatchNorm output, not `x`.
        out += residual
        return self.relu(out)
|
|