| import torch |
| import torch.nn as nn |
|
|
| from models.init_weight import init_net |
| from models.model_blocks import AdaInResBlock |
| from models.model_blocks import ResBlock |
| from models.semantic_face_fusion_model import SemanticFaceFusionModule |
| from models.shape_aware_identity_model import ShapeAwareIdentityExtractor |
|
|
|
|
class Encoder(nn.Module):
    """
    Hififace encoder part.

    A 3 -> 64 stem convolution (``conv_first``) followed by a chain of
    ``ResBlock`` stages. Stage ``i`` maps ``channel_list[i]`` to
    ``channel_list[i + 1]`` channels and receives ``down_sample[i]``.
    """

    # Index of the stage whose output is also returned as the skip feature
    # ``z_enc`` (the second stage).
    _Z_ENC_STAGE = 1

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

        self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
        self.down_sample = [True, True, True, True, True, False, False]

        # One block per consecutive channel pair. The number of stages is
        # derived from the configuration lists instead of a hard-coded count,
        # so editing the lists cannot leave __init__ and forward out of sync.
        self.block_list = nn.ModuleList(
            [
                ResBlock(in_ch, out_ch, down_sample=down)
                for in_ch, out_ch, down in zip(
                    self.channel_list[:-1], self.channel_list[1:], self.down_sample
                )
            ]
        )

    def forward(self, x):
        """
        Parameters:
        -----------
        x: torch.Tensor, input image batch (the Generator passes its target
           image here; presumably (B, 3, H, W) -- confirm against callers)

        Returns:
        --------
        z_enc: output of stage ``_Z_ENC_STAGE`` (None if there are fewer stages)
        x: output of the final stage
        """
        x = self.conv_first(x)
        z_enc = None
        for stage, block in enumerate(self.block_list):
            x = block(x)
            if stage == self._Z_ENC_STAGE:
                z_enc = x
        return z_enc, x
|
|
|
|
class Decoder(nn.Module):
    """
    Hififace decoder part.

    A chain of ``AdaInResBlock`` stages, each of which also receives the
    identity vector. Stage ``i`` maps ``channel_list[i]`` to
    ``channel_list[i + 1]`` channels and receives ``up_sample[i]``.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.channel_list = [512, 512, 512, 512, 512, 256]
        self.up_sample = [False, False, True, True, True]

        # One block per consecutive channel pair. The stage count follows the
        # configuration lists instead of a hard-coded count that forward had
        # to repeat.
        self.block_list = nn.ModuleList(
            [
                AdaInResBlock(in_ch, out_ch, up_sample=up)
                for in_ch, out_ch, up in zip(
                    self.channel_list[:-1], self.channel_list[1:], self.up_sample
                )
            ]
        )

    def forward(self, x, id_vector):
        """
        Parameters:
        -----------
        x: encoder encoded feature map
        id_vector: 3d shape aware identity vector, passed unchanged to every stage

        Returns:
        --------
        z_dec: output of the final stage
        """
        for block in self.block_list:
            x = block(x, id_vector)
        return x
|
|
|
|
class Generator(nn.Module):
    """
    Hififace Generator.

    Wires a frozen shape-aware identity extractor to the encoder, the
    identity-conditioned decoder and the semantic face fusion (SFF) module.
    """

    def __init__(self, identity_extractor_config):
        super(Generator, self).__init__()
        # The identity extractor's parameters are frozen and it is not passed
        # through init_net; only the three sub-networks below are.
        self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        self.id_extractor.requires_grad_(False)
        self.encoder, self.decoder, self.sff_module = (
            init_net(Encoder()),
            init_net(Decoder()),
            init_net(SemanticFaceFusionModule()),
        )

    def _synthesize(self, i_target, shape_aware_id_vector):
        """
        Encode the target, decode it under the identity vector, then fuse.

        Returns the 4-tuple (i_r, i_low, m_r, m_low) produced by sff_module.
        """
        z_enc, feat = self.encoder(i_target)
        z_dec = self.decoder(feat, shape_aware_id_vector)
        i_r, i_low, m_r, m_low = self.sff_module(
            i_target, z_enc, z_dec, shape_aware_id_vector
        )
        return i_r, i_low, m_r, m_low

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Generate outputs from an interpolated identity vector, without autograd.

        The identity vector comes from
        ``self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)``;
        presumably shape_rate / id_rate weight the shape and identity parts --
        confirm against ShapeAwareIdentityExtractor.interp.

        Returns:
        --------
        (i_r, i_low, m_r, m_low), same layout as ``forward``
        """
        id_vec = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        return self._synthesize(i_target, id_vec)

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r: torch.Tensor
        i_low: torch.Tensor
        m_r: torch.Tensor
        m_low: torch.Tensor
        """
        # With need_id_grad the extractor runs in whatever grad mode the
        # caller is already in; otherwise autograd is switched off for this
        # call only (the previous mode is restored on exit).
        keep_grad = bool(need_id_grad) and torch.is_grad_enabled()
        with torch.set_grad_enabled(keep_grad):
            id_vec = self.id_extractor(i_source, i_target)
        return self._synthesize(i_target, id_vec)
|
|