# sapiens2-normal / sapiens/dense/src/models/heads/normal_head.py
# Author: Rawal Khirodkar
# Initial sapiens2-normal Space (HF download at startup, all 4 sizes)
# Commit: ba23d94
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from sapiens.registry import MODELS
from torch import Tensor
@MODELS.register_module()
class NormalHead(nn.Module):
    """Decoder head that predicts a per-pixel surface-normal map (3 channels).

    Pipeline: a 3x3 input conv, a stack of PixelShuffle x2 upsampling blocks,
    optional extra conv layers, and a final 1x1 conv producing 3 output
    channels. Spatial resolution grows by ``2 ** len(upsample_channels)``.

    Args:
        in_channels: Channel count of the backbone feature map fed to
            :meth:`forward`.
        channels: Stored on the instance for config compatibility; not used
            by the network layers themselves.
        upsample_channels: Output channels of each x2 upsampling block.
        conv_out_channels: Output channels of optional extra conv layers
            applied after upsampling; paired element-wise with
            ``conv_kernel_sizes`` (both must be given for the layers to be
            built).
        conv_kernel_sizes: Kernel sizes of the optional extra conv layers.
        loss_decode: Loss config dict, or a sequence of such dicts, built via
            the ``MODELS`` registry. Defaults to an L1 loss.

    Raises:
        TypeError: If ``loss_decode`` is neither a dict nor a list/tuple.
    """

    def __init__(
        self,
        in_channels: int = 768,
        channels: int = 16,
        upsample_channels: Sequence[int] = (768, 384, 192, 96),
        conv_out_channels: Optional[Sequence[int]] = None,
        conv_kernel_sizes: Optional[Sequence[int]] = None,
        loss_decode: Optional[Union[dict, Sequence[dict]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # None sentinel instead of a mutable dict default: registry builds
        # may mutate the cfg, and a shared default would leak across
        # instantiations.
        if loss_decode is None:
            loss_decode = dict(type="L1Loss", loss_weight=1.0)

        self.in_channels = in_channels
        self.channels = channels

        self._build_network(upsample_channels, conv_out_channels, conv_kernel_sizes)

        # Final 1x1 conv mapping the last feature width to 3 normal channels.
        final_in = (
            conv_out_channels[-1] if conv_out_channels else upsample_channels[-1]
        )
        self.conv_normal = nn.Conv2d(final_in, 3, kernel_size=1)

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss_cfg in loss_decode:
                self.loss_decode.append(MODELS.build(loss_cfg))
        else:
            raise TypeError(
                "loss_decode must be a dict or sequence of dict, "
                f"but got {type(loss_decode)}"
            )

        self._init_weights()

    def _build_network(
        self,
        upsample_channels: Sequence[int],
        conv_out_channels: Optional[Sequence[int]],
        conv_kernel_sizes: Optional[Sequence[int]],
    ) -> None:
        """Build the input conv, upsampling blocks, and optional extra convs."""
        in_channels = self.in_channels
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(in_channels),  # normalize before activation
            nn.SiLU(inplace=True),
        )

        # Progressive upsampling: conv to 4*out_ch, then PixelShuffle(2)
        # trades those channels for a x2 spatial upscale.
        up_blocks = []
        cur_ch = in_channels
        for out_ch in upsample_channels:
            up_blocks.append(
                nn.Sequential(
                    nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1),
                    nn.PixelShuffle(2),  # spatial x2, channels /4
                    nn.InstanceNorm2d(out_ch),
                    nn.SiLU(inplace=True),
                )
            )
            cur_ch = out_ch
        self.upsample_blocks = nn.Sequential(*up_blocks)

        # Optional extra conv layers; an empty Sequential acts as identity.
        conv_layers = []
        if conv_out_channels and conv_kernel_sizes:
            for out_ch, k in zip(conv_out_channels, conv_kernel_sizes):
                conv_layers.extend(
                    [
                        # (k - 1) // 2 padding keeps the spatial size for odd k.
                        nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2),
                        nn.InstanceNorm2d(out_ch),
                        nn.SiLU(inplace=True),
                    ]
                )
                cur_ch = out_ch
        self.conv_layers = nn.Sequential(*conv_layers)

    def _init_weights(self) -> None:
        """Initialize network weights.

        Kaiming init is computed in float32 and cast back to the parameter's
        original dtype, so half-precision modules are initialized correctly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                weight_dtype = m.weight.dtype
                weight = nn.init.kaiming_normal_(
                    m.weight.float(), mode="fan_out", nonlinearity="relu"
                )
                m.weight.data = weight.to(weight_dtype)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                weight_dtype = m.weight.dtype
                weight = nn.init.kaiming_normal_(
                    m.weight.float(), mode="fan_in", nonlinearity="linear"
                )
                m.weight.data = weight.to(weight_dtype)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Predict a normal map from a backbone feature tensor.

        Args:
            x: Feature tensor; expected shape (B, in_channels, H, W) — the
               first conv requires ``in_channels`` input channels.

        Returns:
            Normal prediction of shape (B, 3, H * s, W * s) where
            ``s = 2 ** len(upsample_channels)``.
        """
        x_normal = self.input_conv(x)
        x_normal = self.upsample_blocks(x_normal)
        x_normal = self.conv_layers(x_normal)
        normal = self.conv_normal(x_normal)
        return normal

    def loss(
        self,
        outputs: Tensor,
        data_samples: dict,
    ) -> Tuple[dict, Tensor]:
        """Compute decode losses against ground-truth normals.

        Args:
            outputs: Predicted normal map, shape (B, 3, h, w).
            data_samples: Must contain ``"gt_normal"`` (B x 3 x H x W) and
                ``"mask"`` (B x 1 x H x W).

        Returns:
            Tuple of (loss dict keyed by each loss's ``loss_name``, the
            prediction resized to the ground-truth resolution). Losses
            sharing a ``loss_name`` are summed.
        """
        pred_normal = outputs
        gt_normal = data_samples["gt_normal"]  # B x 3 x H x W
        gt_mask = data_samples["mask"]  # B x 1 x H x W

        # Match the ground-truth resolution before computing the loss.
        if pred_normal.shape[2:] != gt_normal.shape[2:]:
            pred_normal = F.interpolate(
                input=pred_normal,
                size=gt_normal.shape[2:],
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )

        loss = dict()
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            this_loss = loss_decode(
                pred_normal,
                gt_normal,
                valid_mask=gt_mask,
            )
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = this_loss
            else:
                loss[loss_decode.loss_name] += this_loss
        return loss, pred_normal