# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from sapiens.registry import MODELS from torch import Tensor @MODELS.register_module() class NormalHead(nn.Module): def __init__( self, in_channels: int = 768, channels: int = 16, upsample_channels: List[int] = [768, 384, 192, 96], conv_out_channels: Optional[Sequence[int]] = None, conv_kernel_sizes: Optional[Sequence[int]] = None, loss_decode=dict(type="L1Loss", loss_weight=1.0), **kwargs, ): super().__init__(**kwargs) self.in_channels = in_channels self.channels = channels self._build_network(upsample_channels, conv_out_channels, conv_kernel_sizes) # final conv layer to predict normal in_channels = ( conv_out_channels[-1] if conv_out_channels else upsample_channels[-1] ) self.conv_normal = nn.Conv2d(in_channels, 3, kernel_size=1) if isinstance(loss_decode, dict): self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: self.loss_decode.append(MODELS.build(loss)) else: raise TypeError( f"loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}" ) self._init_weights() def _build_network( self, upsample_channels: List[int], conv_out_channels: Optional[Sequence[int]], conv_kernel_sizes: Optional[Sequence[int]], ) -> None: in_channels = self.in_channels self.input_conv = nn.Sequential( nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.InstanceNorm2d(in_channels), # Normalize first nn.SiLU(inplace=True), ) # Progressive upsampling blocks up_blocks = [] cur_ch = in_channels for out_ch in upsample_channels: up_blocks.append( nn.Sequential( nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1), nn.PixelShuffle(2), # ↑ spatial ×2 nn.InstanceNorm2d(out_ch), nn.SiLU(inplace=True), ) ) cur_ch = out_ch self.upsample_blocks = nn.Sequential(*up_blocks) # optional extra conv layers conv_layers = [] if conv_out_channels and conv_kernel_sizes: for out_ch, k in zip(conv_out_channels, conv_kernel_sizes): conv_layers.extend( [ nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2), nn.InstanceNorm2d(out_ch), nn.SiLU(inplace=True), ] ) cur_ch = out_ch self.conv_layers = nn.Sequential(*conv_layers) def _init_weights(self) -> None: """Initialize network weights.""" for m in self.modules(): if isinstance(m, nn.Conv2d): weight_dtype = m.weight.dtype weight = nn.init.kaiming_normal_( m.weight.float(), mode="fan_out", nonlinearity="relu" ) m.weight.data = weight.to(weight_dtype) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): weight_dtype = m.weight.dtype weight = nn.init.kaiming_normal_( m.weight.float(), mode="fan_in", nonlinearity="linear" ) m.weight.data = weight.to(weight_dtype) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor: x_normal = self.input_conv(x) x_normal = self.upsample_blocks(x_normal) x_normal = self.conv_layers(x_normal) normal = self.conv_normal(x_normal) return normal def loss( self, outputs: Tuple[Tensor], data_samples: dict, ) -> dict: pred_normal = outputs gt_normal = data_samples["gt_normal"] ## B x 3 x H x W gt_mask = data_samples["mask"] ## B x 1 x H x W if pred_normal.shape[2:] != gt_normal.shape[2:]: pred_normal = F.interpolate( input=pred_normal, size=gt_normal.shape[2:], mode="bilinear", align_corners=False, antialias=False, ) ##--------------------------------- loss = dict() if not isinstance(self.loss_decode, nn.ModuleList): losses_decode = [self.loss_decode] else: losses_decode = self.loss_decode for loss_decode in losses_decode: this_loss = loss_decode( pred_normal, gt_normal, valid_mask=gt_mask, ) if loss_decode.loss_name not in loss: loss[loss_decode.loss_name] = this_loss else: loss[loss_decode.loss_name] += this_loss return loss, pred_normal