File size: 10,250 Bytes
ba23d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from sapiens.registry import MODELS
from torch import Tensor


@MODELS.register_module()
class PointmapHead(nn.Module):
    """Dense head that predicts a 3-channel pointmap and, optionally, a scale.

    Pointmap branch: a 3x3 conv + InstanceNorm + SiLU on the input features,
    then one block per entry of ``upsample_channels`` (conv producing
    ``out_ch * 4`` channels followed by ``PixelShuffle(2)``, i.e. each block
    doubles the spatial resolution), then optional extra conv layers, and
    finally a 1x1 conv projecting to 3 channels.

    Scale branch (optional): stride-2 convs applied to the *input* features,
    flattened and passed through an MLP (``scale_final_layer``).

    Args:
        in_channels: Channel count of the input feature map.
        channels: Stored as ``self.channels``; not read by this class.
        upsample_channels: Output channels of each x2 upsampling block.
        conv_out_channels: Output channels of optional extra conv layers run
            after upsampling. Only used when ``conv_kernel_sizes`` is also
            given; paired element-wise with it (extra entries are ignored).
        conv_kernel_sizes: Kernel sizes of the extra conv layers.
        scale_conv_out_channels: Output channels of the stride-2 convs of the
            scale branch. ``None`` disables the scale branch entirely.
        scale_conv_kernel_sizes: Kernel sizes of the scale-branch convs.
        scale_final_layer: Feature sizes of the scale MLP. The first entry is
            the MLP input size and must equal the flattened size of the
            scale-branch conv output (depends on the input resolution).
        loss_decode: A loss config dict, or a list/tuple of them, each built
            with ``MODELS.build``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        in_channels: int = 768,
        channels: int = 16,
        upsample_channels: Sequence[int] = (768, 384, 192, 96),
        conv_out_channels: Optional[Sequence[int]] = None,
        conv_kernel_sizes: Optional[Sequence[int]] = None,
        scale_conv_out_channels: Optional[Sequence[int]] = (1536, 512, 128),
        scale_conv_kernel_sizes: Optional[Sequence[int]] = (1, 1, 1),
        scale_final_layer: Optional[Sequence[int]] = (48 * 128, 512, 64, 1),
        loss_decode=dict(type="L1Loss", loss_weight=1.0),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.channels = channels

        # _build_network reports the channel count its last layer actually
        # produces. Using it (instead of re-deriving it from the arguments)
        # keeps the final projection consistent with the trunk, e.g. when
        # conv_out_channels is given without conv_kernel_sizes (no extra conv
        # layers are built) or when upsample_channels is empty.
        decoder_out_channels = self._build_network(
            upsample_channels, conv_out_channels, conv_kernel_sizes
        )

        # NOTE: module creation order is kept as-is; it determines the order in
        # which parameters are initialized from the global RNG.
        if scale_conv_out_channels is not None:
            self.scale_conv_layers = self._make_regression_conv_layers(
                in_channels=self.in_channels,
                layer_out_channels=scale_conv_out_channels,
                layer_kernel_sizes=scale_conv_kernel_sizes,
            )
            self.scale_final_layer = self._make_final_layer(scale_final_layer)
        else:
            self.scale_conv_layers = None
            self.scale_final_layer = None

        # Final 1x1 conv predicting the 3 pointmap channels.
        self.conv_pointmap = nn.Conv2d(decoder_out_channels, 3, kernel_size=1)

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList(
                MODELS.build(loss) for loss in loss_decode
            )
        else:
            raise TypeError(
                "loss_decode must be a dict or sequence of dict, "
                f"but got {type(loss_decode)}"
            )

        self._init_weights()

    def _build_network(
        self,
        upsample_channels: Sequence[int],
        conv_out_channels: Optional[Sequence[int]],
        conv_kernel_sizes: Optional[Sequence[int]],
    ) -> int:
        """Build the pointmap trunk.

        Registers ``input_conv``, ``upsample_blocks`` and ``conv_layers``.

        Returns:
            The number of channels output by the last layer of the trunk.
        """
        in_channels = self.in_channels

        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(in_channels),  # Normalize first
            nn.SiLU(inplace=True),
        )

        # Progressive upsampling: conv to out_ch * 4 channels, then
        # PixelShuffle(2) folds the factor 4 into a x2 spatial upsample.
        up_blocks = []
        cur_ch = in_channels
        for out_ch in upsample_channels:
            up_blocks.append(
                nn.Sequential(
                    nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1),
                    nn.PixelShuffle(2),  # spatial x2
                    nn.InstanceNorm2d(out_ch),
                    nn.SiLU(inplace=True),
                )
            )
            cur_ch = out_ch
        self.upsample_blocks = nn.Sequential(*up_blocks)

        # Optional extra conv layers ("same" padding for odd kernel sizes).
        conv_layers = []
        if conv_out_channels and conv_kernel_sizes:
            for out_ch, k in zip(conv_out_channels, conv_kernel_sizes):
                conv_layers.extend(
                    [
                        nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2),
                        nn.InstanceNorm2d(out_ch),
                        nn.SiLU(inplace=True),
                    ]
                )
                cur_ch = out_ch
        self.conv_layers = nn.Sequential(*conv_layers)

        return cur_ch

    def _make_final_layer(self, final_layer: Sequence[int]) -> nn.Module:
        """Create the scale MLP: Flatten, then Linear layers between the sizes.

        ``final_layer[0]`` is the input feature size; every following entry
        adds a ``Linear``. SiLU is inserted between Linears but not after the
        last one.
        """
        layers = [nn.Flatten()]
        in_features = final_layer[0]

        for i in range(1, len(final_layer)):
            layers.append(nn.Linear(in_features, final_layer[i]))
            if i < len(final_layer) - 1:  # No activation after the last layer
                layers.append(nn.SiLU())
            in_features = final_layer[i]

        return nn.Sequential(*layers)

    def _make_regression_conv_layers(
        self,
        in_channels: int,
        layer_out_channels: Sequence[int],
        layer_kernel_sizes: Sequence[int],
    ) -> nn.Module:
        """Create the scale-branch convs (Conv2d + InstanceNorm + SiLU each).

        Every conv uses stride 2, so each layer roughly halves the spatial
        resolution. Channels and kernel sizes are paired element-wise.
        """
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
            stride = 2  # Set stride to 2 to reduce resolution by half
            padding = (kernel_size - 1) // 2
            layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.SiLU(inplace=True))

            in_channels = out_channels

        return nn.Sequential(*layers)

    def _init_weights(self) -> None:
        """Initialize Conv2d/Linear weights (Kaiming normal) and zero biases.

        Conv2d uses ``mode="fan_out"`` with the ReLU gain. Linear uses
        ``mode="fan_in"`` with the linear gain. Initialization runs on a
        float32 view/copy of the weight, and the result is cast back to the
        parameter's original dtype.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                weight_dtype = m.weight.dtype
                weight = nn.init.kaiming_normal_(
                    m.weight.float(), mode="fan_out", nonlinearity="relu"
                )
                m.weight.data = weight.to(weight_dtype)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                weight_dtype = m.weight.dtype
                weight = nn.init.kaiming_normal_(
                    m.weight.float(), mode="fan_in", nonlinearity="linear"
                )
                m.weight.data = weight.to(weight_dtype)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Predict the pointmap and (if enabled) the scale.

        Args:
            x: Input feature map fed to 2D convs (presumably B x C x H x W
                with C == ``in_channels`` -- confirm with the backbone).
                Tuples of feature maps are not unpacked here.

        Returns:
            ``(pointmap, scale)``. ``pointmap`` has 3 channels and
            ``2 ** len(upsample_channels)`` times the input resolution.
            ``scale`` is the scale-MLP output, or ``None`` when the scale
            branch is disabled.
        """
        x_pointmap = self.input_conv(x)
        x_pointmap = self.upsample_blocks(x_pointmap)
        x_pointmap = self.conv_layers(x_pointmap)
        pointmap = self.conv_pointmap(x_pointmap)

        scale = None
        if self.scale_conv_layers is not None:
            # Author's note: B x 1; scale = f_c / f_actual, in pixel space of fx.
            scale = self.scale_final_layer(self.scale_conv_layers(x))

        return pointmap, scale

    def loss(
        self,
        outputs: Tuple[Tensor, Optional[Tensor]],
        data_samples: dict,
    ) -> Tuple[dict, Tuple[Tensor, Optional[Tensor]]]:
        """Compute every configured loss for one batch.

        Args:
            outputs: ``(pred_pointmap, pred_scale)`` as returned by ``forward``.
            data_samples: Must provide ``"gt_pointmap"``, ``"gt_mean_depth"``,
                ``"mask"`` and ``"meta"`` with ``"original_K"`` and ``"scale"``.
                All of them are read regardless of which losses are configured.

        Returns:
            ``(losses, (pred_pointmap, pred_scale))``. ``losses`` maps each loss
            module's ``loss_name`` to its value, summed when several modules
            share a name. ``pred_pointmap`` is bilinearly resized to the
            ground-truth resolution if the sizes differ.

        Raises:
            NotImplementedError: If a loss module has an unsupported
                ``loss_name``.
        """
        pred_pointmap, pred_scale = outputs
        gt_pointmap = data_samples["gt_pointmap"]  ## B x 3 x H x W
        gt_mean_depth = data_samples["gt_mean_depth"]  ## B x 1 x 1 x 1
        gt_original_K = data_samples["meta"]["original_K"]  ## B x 3 x 3
        # reshape (not view) so a non-contiguous scale tensor also works.
        gt_scale = data_samples["meta"]["scale"].reshape(-1, 1)  ## B x 1
        gt_mask = data_samples["mask"]  ## B x 1 x H x W

        if pred_pointmap.shape[2:] != gt_pointmap.shape[2:]:
            print(
                "Warning: this is not recommended in pointmap, you may get artifacts!"
            )
            print(
                f"pred_pointmap size: {pred_pointmap.shape}, gt_pointmap size: {gt_pointmap.shape}"
            )
            pred_pointmap = F.interpolate(
                input=pred_pointmap,
                size=gt_pointmap.shape[2:],
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )

        loss = dict()
        losses_decode = (
            self.loss_decode
            if isinstance(self.loss_decode, nn.ModuleList)
            else [self.loss_decode]
        )

        # Channel 2 of the pointmap is used as depth.
        pred_depth = pred_pointmap[:, 2].unsqueeze(dim=1)  ## B x 1 x H x W
        gt_depth = gt_pointmap[:, 2].unsqueeze(dim=1)  ## B x 1 x H x W

        for loss_decode in losses_decode:
            name = loss_decode.loss_name
            if name == "loss_K_consistency":
                # Uses the original K: the X/Z and Y/Z ratios are unchanged.
                this_loss = loss_decode(
                    pred_pointmap,
                    gt_pointmap,
                    valid_mask=gt_mask,
                    intrinsics=gt_original_K,
                )
            elif name == "loss_silog":
                this_loss = loss_decode(
                    pred_depth,
                    gt_depth,
                    valid_mask=gt_mask,
                )
            elif name == "loss_normal":
                this_loss = loss_decode(
                    pred_pointmap,
                    gt_pointmap,
                    valid_mask=gt_mask,
                    scale=gt_scale,
                )
            elif name == "loss_scale_l1":
                # NOTE(review): pred_scale is None when the scale branch is
                # disabled; this loss presumably requires it.
                this_loss = loss_decode(pred_scale, gt_scale)
            elif name in [
                "loss_l1",
                "loss_shift_invariant",
                "loss_multiscale_l1_2",
                "loss_multiscale_l1_4",
            ]:
                # Normalize both pointmaps by the GT mean depth, then cap the
                # loss value at 4.0.
                this_loss = loss_decode(
                    pred_pointmap / gt_mean_depth,
                    gt_pointmap / gt_mean_depth,
                    valid_mask=gt_mask,
                )
                this_loss = torch.clamp(this_loss, max=4.0)
            else:
                raise NotImplementedError(f"loss {name} is not implemented")

            # Out-of-place accumulation: an in-place `+=` would modify a tensor
            # autograd may still need for the backward pass of an earlier loss.
            if name in loss:
                loss[name] = loss[name] + this_loss
            else:
                loss[name] = this_loss

        return loss, (pred_pointmap, pred_scale)