| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
def carafe_forward(
    features: torch.Tensor,
    masks: torch.Tensor,
    kernel_size: int,
    group_size: int,
    scale_factor: int
) -> torch.Tensor:
    """
    Pure-PyTorch implementation of the CARAFE upsampling operator.

    For every output location, a ``kernel_size x kernel_size`` patch is
    gathered from the input feature map (around the corresponding
    low-resolution position, via floor division by ``scale_factor``) and
    reassembled as a weighted sum using the location-specific mask weights.

    Args:
        features (Tensor): Input feature map of shape (N, C, H, W).
        masks (Tensor): Reassembly kernel weights of shape
            (N, kernel_size*kernel_size*group_size, H_out, W_out),
            where H_out = H*scale_factor and W_out = W*scale_factor.
        kernel_size (int): The spatial size of the reassembly kernel.
            Must be odd so that "same" padding is well defined.
        group_size (int): The group size to divide channels. Must divide C.
        scale_factor (int): The upsampling factor.

    Returns:
        Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).

    Raises:
        ValueError: If ``group_size`` does not divide C, or ``kernel_size``
            is even.
    """
    N, C, H, W = features.size()
    if C % group_size != 0:
        raise ValueError(
            f"group_size ({group_size}) must divide the number of channels ({C})")
    if kernel_size % 2 == 0:
        raise ValueError(f"kernel_size ({kernel_size}) must be odd")
    out_H, out_W = H * scale_factor, W * scale_factor
    num_channels = C // group_size

    # Extract a k*k patch for every input location, one group at a time
    # (groups are folded into the batch dimension for unfold).
    features_reshaped = features.view(N * group_size, num_channels, H, W)
    patches = F.unfold(features_reshaped, kernel_size=kernel_size,
                       padding=(kernel_size - 1) // 2)
    # -> (N, group_size, C/g, k*k, H*W)
    patches = patches.view(N, group_size, num_channels,
                           kernel_size * kernel_size, H * W)

    # Map each output coordinate to its source coordinate in the input
    # (nearest-neighbour correspondence: floor division by scale_factor).
    device = features.device
    h_idx = torch.div(torch.arange(out_H, device=device), scale_factor,
                      rounding_mode='floor')
    w_idx = torch.div(torch.arange(out_W, device=device), scale_factor,
                      rounding_mode='floor')
    base_idx = (h_idx.unsqueeze(1) * W + w_idx.unsqueeze(0)).view(-1)

    # Gather, for every output location, the patch of its source location.
    base_idx = base_idx.view(1, 1, 1, 1, -1).expand(
        N, group_size, num_channels, kernel_size * kernel_size, -1)
    gathered_patches = torch.gather(patches, -1, base_idx)
    gathered_patches = gathered_patches.view(
        N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)

    # Weighted reassembly: the per-group masks broadcast over the channel
    # dimension (no explicit expand needed), then sum over the kernel dim.
    masks = masks.view(N, group_size, 1, kernel_size * kernel_size, out_H, out_W)
    out = (gathered_patches * masks).sum(dim=3)
    return out.view(N, C, out_H, out_W)
|
|
|
|
class CARAFE(nn.Module):
    """
    Content-Aware ReAssembly of FEatures (CARAFE) upsampling module.

    A thin ``nn.Module`` wrapper around :func:`carafe_forward`: given an
    input feature map and its precomputed reassembly masks, it returns the
    reassembled, upsampled feature map.

    Args:
        kernel_size (int): Reassembly kernel size.
        group_size (int): Group size for channel grouping (must divide the
            number of channels).
        scale_factor (int): Upsample ratio.
    """

    def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        return carafe_forward(
            features,
            masks,
            self.kernel_size,
            self.group_size,
            self.scale_factor,
        )
|
|
|
|
class CARAFEPack(nn.Module):
    """
    Self-contained CARAFE upsampler, modeled after the official package.

    Bundles the three stages of CARAFE:
      1) a 1x1 channel compressor,
      2) a content encoder that predicts the reassembly masks,
      3) the CARAFE reassembly operator itself.

    Args:
        channels (int): Number of input feature channels.
        scale_factor (int): Upsample ratio.
        up_kernel (int): Kernel size for the CARAFE operator.
        up_group (int): Group size for the CARAFE operator.
        encoder_kernel (int): Kernel size of the content encoder.
        encoder_dilation (int): Dilation rate for the content encoder.
        compressed_channels (int): Output channels for the channel compressor.
    """

    def __init__(
        self,
        channels: int,
        scale_factor: int,
        up_kernel: int = 5,
        up_group: int = 1,
        encoder_kernel: int = 3,
        encoder_dilation: int = 1,
        compressed_channels: int = 64,
    ):
        super().__init__()
        self.channels = channels
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        # 1x1 conv that squeezes the channel dimension before mask prediction.
        self.channel_compressor = nn.Conv2d(
            channels, compressed_channels, kernel_size=1)
        # Predicts scale_factor^2 reassembly kernels (up_kernel^2 weights per
        # group) for each low-res position; pixel_shuffle later spreads them
        # spatially in kernel_normalizer.
        self.content_encoder = nn.Conv2d(
            compressed_channels,
            up_kernel * up_kernel * up_group * scale_factor * scale_factor,
            kernel_size=encoder_kernel,
            padding=int((encoder_kernel - 1) * encoder_dilation / 2),
            dilation=encoder_dilation,
        )

        # Xavier-uniform weights, zero biases for both convs.
        for conv in (self.channel_compressor, self.content_encoder):
            nn.init.xavier_uniform_(conv.weight)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Normalize and reshape the predicted mask.

        Pixel-shuffles the predicted kernel weights up to output resolution,
        then softmax-normalizes across the kernel dimension so each
        reassembly kernel sums to 1.

        Args:
            mask (Tensor): Predicted mask of shape (N, out_channels, H, W).

        Returns:
            Tensor: Normalized mask of shape
                (N, up_group * up_kernel^2, H*scale, W*scale).
        """
        mask = F.pixel_shuffle(mask, self.scale_factor)
        n, c, h, w = mask.size()
        groups = c // (self.up_kernel ** 2)
        normalized = F.softmax(
            mask.view(n, groups, self.up_kernel ** 2, h, w), dim=2)
        return normalized.view(n, groups * self.up_kernel ** 2, h, w).contiguous()

    def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the CARAFE operator with this module's configuration."""
        return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compress channels, predict + normalize masks, then reassemble."""
        compressed = self.channel_compressor(x)
        mask = self.kernel_normalizer(self.content_encoder(compressed))
        return self.feature_reassemble(x, mask)
|
|
|
|
| |
if __name__ == '__main__':
    # Smoke test for the CARAFE package.
    # Run on GPU when available, otherwise fall back to CPU — the previous
    # hard-coded .cuda() crashed on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(2, 64, 32, 32, device=device)

    upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).to(device)

    out = upsampler(x)
    print("Input shape: ", x.shape)
    print("Output shape:", out.shape)
|
|