| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, ResNetBackbone |
| from .configuration_conditional_unet import ConditionalUNetConfig |
|
|
class UpSampleBlock(nn.Module):
    """Decoder stage that fuses a feature map with a skip connection and a condition.

    The input is optionally upsampled by 2 (nearest neighbour), concatenated
    along the channel axis with the skip tensor and with the condition vector
    broadcast over every spatial position, and then passed through two
    3x3 conv -> BatchNorm -> ReLU layers.

    Args:
        in_channels: Channel count of the incoming feature map ``x``.
        skip_channels: Channel count of the skip tensor.
        out_channels: Channel count produced by the block.
        condition_size: Number of condition features per sample.
    """

    def __init__(self, in_channels, skip_channels, out_channels, condition_size):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        # The first conv sees x, skip and the broadcast condition stacked on dim=1.
        fused_channels = in_channels + skip_channels + condition_size
        self.conv = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip, condition, upsample=True):
        """Fuse ``x`` with ``skip`` and ``condition`` and apply the conv stack.

        Args:
            x: Feature map of shape ``(batch, in_channels, H, W)``.
            skip: Skip tensor of shape ``(batch, skip_channels, H', W')``.
            condition: Per-sample condition; it is reshaped to
                ``(batch, -1, 1, 1)``, so it must hold ``condition_size``
                values per sample.
            upsample: When true, ``x`` is upsampled by 2 before fusion.

        Returns:
            Tensor of shape ``(batch, out_channels, H', W')``.
        """
        if upsample:
            x = self.up(x)

        # A stride-2 encoder stage rounds odd sizes up, so doubling the deeper
        # map can overshoot the skip map by one pixel. Align to the skip size
        # instead of letting torch.cat fail; this is a no-op for exact 2x.
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(x, size=skip.shape[-2:], mode='nearest')

        b, _, h, w = x.size()

        # reshape (not view) tolerates non-contiguous input; the cast keeps an
        # integer / different-precision condition from changing the dtype that
        # reaches the convolutions.
        condition = condition.reshape(b, -1, 1, 1).to(dtype=x.dtype)
        condition = condition.expand(-1, -1, h, w)

        fused = torch.cat([x, skip, condition], dim=1)
        return self.conv(fused)
|
|
class ConditionalUNet(PreTrainedModel):
    """U-Net style decoder on top of a frozen ResNet backbone, conditioned on a vector.

    The encoder is loaded with ``ResNetBackbone.from_pretrained`` from
    ``config.encoder_rep`` and frozen (eval mode, ``requires_grad=False``).
    The decoder has one ``UpSampleBlock`` per entry of the encoder's
    ``hidden_sizes``; each block receives the condition in addition to the
    skip connection. ``final_conv`` maps the decoder output (again
    concatenated with the condition) to ``num_channels`` channels.
    """

    config_class = ConditionalUNetConfig

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        self.encoder_rep = config.encoder_rep
        self.encoder = ResNetBackbone.from_pretrained(
            self.encoder_rep,
            return_dict=False,
            output_hidden_states=True
        )
        # Freeze the backbone: no gradients, and BatchNorm in inference mode.
        self.encoder.eval()
        self.encoder.requires_grad_(False)

        # Mirror the encoder's label / channel counts onto this model's config.
        self.num_labels = self.encoder.config.num_labels
        self.num_channels = self.encoder.config.num_channels
        self.config.num_labels = self.num_labels
        self.config.num_channels = self.num_channels

        hidden_sizes = self.encoder.config.hidden_sizes
        embedding_size = self.encoder.config.embedding_size

        # Skip width for stage i is the width of the stage before it; the
        # first stage falls back to the embedder width. The decoder walks the
        # stages from deepest to shallowest and each block outputs the width
        # of its skip connection.
        skip_widths = [embedding_size, *hidden_sizes[:-1]]

        self.up_blocks = nn.ModuleList()
        in_channels = hidden_sizes[-1]
        for skip_channels in reversed(skip_widths):
            self.up_blocks.append(
                UpSampleBlock(
                    in_channels=in_channels,
                    skip_channels=skip_channels,
                    out_channels=skip_channels,
                    condition_size=self.num_labels
                )
            )
            in_channels = skip_channels

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels + self.num_labels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, self.num_channels, kernel_size=1)
        )

    def train(self, mode=True):
        """Set the training mode while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would undo the
        ``self.encoder.eval()`` from ``__init__`` and let the encoder's
        BatchNorm running statistics drift even though its parameters are
        frozen. The encoder is therefore put back into eval mode after the
        default behaviour has run.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x, condition):
        """Encode ``x`` and decode it back under ``condition``.

        Args:
            x: Input batch passed unchanged to the encoder.
            condition: Per-sample condition; it is reshaped to
                ``(batch, -1, 1, 1)`` — presumably ``(batch, num_labels)``,
                confirm against the caller.

        Returns:
            Tensor with ``num_channels`` channels whose spatial size is 4x
            that of the last decoder feature map.
        """
        # With return_dict=False and output_hidden_states=True the last tuple
        # element is used as the sequence of hidden states; it is reversed so
        # the deepest map comes first and shallower maps act as skips.
        hidden_states = self.encoder(x)[-1]
        x_stages = hidden_states[::-1]
        x = x_stages[0]

        last_index = len(self.up_blocks) - 1
        for i, up_block in enumerate(self.up_blocks):
            # Defensive fallback kept from the original: if the encoder returned
            # fewer maps than there are blocks, use zeros shaped like x as skip.
            if i + 1 < len(x_stages):
                skip = x_stages[i + 1]
            else:
                skip = torch.zeros_like(x)
            # Every block but the last doubles the resolution.
            x = up_block(x, skip, condition, upsample=i < last_index)

        # NOTE(review): scale_factor=4 presumably undoes the ResNet stem's
        # downsampling; the output matches the input resolution only when the
        # input size is compatible with that stride — confirm.
        x_upsampled = nn.functional.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        b, _, h, w = x_upsampled.size()

        # reshape tolerates non-contiguous conditions; the cast keeps the
        # concatenated tensor in the dtype final_conv's weights expect.
        condition_expanded = condition.reshape(b, -1, 1, 1).to(dtype=x_upsampled.dtype)
        condition_expanded = condition_expanded.expand(-1, -1, h, w)

        final_input = torch.cat([x_upsampled, condition_expanded], dim=1)
        return self.final_conv(final_input)
|
|