# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from sapiens.registry import MODELS
from torch import Tensor
@MODELS.register_module()
class NormalHead(nn.Module):
    """Decoder head that predicts a 3-channel normal map from a feature map.

    The network is, in order: an input 3x3 convolution, a chain of
    PixelShuffle upsampling blocks (each doubles the spatial resolution),
    an optional stack of extra convolutions, and a final 1x1 convolution
    with 3 output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        channels: Stored on the instance as ``self.channels``; it is not
            otherwise used by this class.
        upsample_channels: Output channels of each x2 upsampling block, in
            order. May be empty, in which case no upsampling is performed.
        conv_out_channels: Output channels of the optional extra conv
            layers.
        conv_kernel_sizes: Kernel sizes of the optional extra conv layers.
            Extra layers are only built when both ``conv_out_channels`` and
            ``conv_kernel_sizes`` are non-empty; the two are paired
            element-wise and surplus entries of the longer one are ignored.
        loss_decode: A config dict, or a list/tuple of config dicts, each
            passed to ``MODELS.build`` to create a loss module.
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Raises:
        TypeError: If ``loss_decode`` is neither a dict nor a list/tuple.
    """

    def __init__(
        self,
        in_channels: int = 768,
        channels: int = 16,
        upsample_channels: Sequence[int] = (768, 384, 192, 96),
        conv_out_channels: Optional[Sequence[int]] = None,
        conv_kernel_sizes: Optional[Sequence[int]] = None,
        loss_decode=dict(type="L1Loss", loss_weight=1.0),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.channels = channels

        # Use the channel count the network actually ends with rather than
        # re-deriving it from the arguments: the extra conv layers are
        # skipped when either of their two config sequences is missing, so
        # ``conv_out_channels[-1]`` is not always the real feature width.
        feat_channels = self._build_network(
            upsample_channels, conv_out_channels, conv_kernel_sizes
        )

        # Final conv layer to predict the normal (3 output channels).
        self.conv_normal = nn.Conv2d(feat_channels, 3, kernel_size=1)

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss_cfg in loss_decode:
                self.loss_decode.append(MODELS.build(loss_cfg))
        else:
            raise TypeError(
                "loss_decode must be a dict or sequence of dict, "
                f"but got {type(loss_decode)}"
            )

        self._init_weights()

    def _build_network(
        self,
        upsample_channels: Sequence[int],
        conv_out_channels: Optional[Sequence[int]],
        conv_kernel_sizes: Optional[Sequence[int]],
    ) -> int:
        """Create ``input_conv``, ``upsample_blocks`` and ``conv_layers``.

        Args:
            upsample_channels: Output channels of each upsampling block.
            conv_out_channels: Output channels of the optional extra convs.
            conv_kernel_sizes: Kernel sizes of the optional extra convs.

        Returns:
            The number of channels produced by the last layer that was
            built, i.e. the input width required by the prediction conv.
        """
        in_channels = self.in_channels
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(in_channels),
            nn.SiLU(inplace=True),
        )

        # Progressive upsampling blocks.
        up_blocks = []
        cur_ch = in_channels
        for out_ch in upsample_channels:
            up_blocks.append(
                nn.Sequential(
                    # PixelShuffle(2) folds 4x channels into 2x height and
                    # 2x width, leaving ``out_ch`` channels.
                    nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1),
                    nn.PixelShuffle(2),
                    nn.InstanceNorm2d(out_ch),
                    nn.SiLU(inplace=True),
                )
            )
            cur_ch = out_ch
        self.upsample_blocks = nn.Sequential(*up_blocks)

        # Optional extra conv layers; an empty Sequential is an identity.
        conv_layers = []
        if conv_out_channels and conv_kernel_sizes:
            for out_ch, k in zip(conv_out_channels, conv_kernel_sizes):
                conv_layers.extend(
                    [
                        nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2),
                        nn.InstanceNorm2d(out_ch),
                        nn.SiLU(inplace=True),
                    ]
                )
                cur_ch = out_ch
        self.conv_layers = nn.Sequential(*conv_layers)

        return cur_ch

    def _init_weights(self) -> None:
        """Initialize Conv2d and Linear weights (Kaiming normal), zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                mode, nonlinearity = "fan_out", "relu"
            elif isinstance(m, nn.Linear):
                mode, nonlinearity = "fan_in", "linear"
            else:
                continue

            # Sample in float32, then cast back to the weight's own dtype.
            weight_dtype = m.weight.dtype
            weight = nn.init.kaiming_normal_(
                m.weight.float(), mode=mode, nonlinearity=nonlinearity
            )
            m.weight.data = weight.to(weight_dtype)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor:
        """Run the head and return the 3-channel prediction.

        Args:
            x: Input feature map. NOTE(review): the annotation also allows
                a tuple, but ``x`` is passed directly to the conv stack --
                confirm that callers always pass a single tensor.

        Returns:
            The output of the final 1x1 convolution.
        """
        x_normal = self.input_conv(x)
        x_normal = self.upsample_blocks(x_normal)
        x_normal = self.conv_layers(x_normal)
        return self.conv_normal(x_normal)

    def loss(
        self,
        outputs: Tensor,
        data_samples: dict,
    ) -> Tuple[dict, Tensor]:
        """Compute the configured loss(es) for a prediction.

        Args:
            outputs: Predicted normal map; it is bilinearly resized to the
                ground-truth spatial size when the two differ.
            data_samples: Mapping with keys ``"gt_normal"`` (B x 3 x H x W)
                and ``"mask"`` (B x 1 x H x W). The mask is passed to each
                loss as ``valid_mask``.

        Returns:
            A tuple ``(losses, pred_normal)``: ``losses`` maps each loss
            module's ``loss_name`` to its value (values of modules sharing
            a name are summed), and ``pred_normal`` is the possibly resized
            prediction.
        """
        pred_normal = outputs
        gt_normal = data_samples["gt_normal"]  # B x 3 x H x W
        gt_mask = data_samples["mask"]  # B x 1 x H x W

        if pred_normal.shape[2:] != gt_normal.shape[2:]:
            pred_normal = F.interpolate(
                input=pred_normal,
                size=gt_normal.shape[2:],
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )

        if isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = self.loss_decode
        else:
            losses_decode = [self.loss_decode]

        loss = dict()
        for loss_decode in losses_decode:
            this_loss = loss_decode(
                pred_normal,
                gt_normal,
                valid_mask=gt_mask,
            )
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = this_loss
            else:
                loss[loss_decode.loss_name] += this_loss

        return loss, pred_normal
|