File size: 4,717 Bytes
84b67b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Compact U-Net architecture for the released WildFIRE-FM checkpoints."""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def make_norm(norm_type: str, num_channels: int, norm_groups: int) -> nn.Module:
    """Build a 2-D normalization layer for ``num_channels`` feature maps.

    Args:
        norm_type: One of ``"batch"``, ``"group"``, ``"instance"``,
            ``"none"``/``"identity"``.
        num_channels: Channel count the layer normalizes over.
        norm_groups: Requested group count (``"group"`` only); it is clamped
            to ``[1, num_channels]`` and then reduced until it evenly divides
            the channel count, so ``nn.GroupNorm`` never raises.

    Returns:
        The corresponding ``nn.Module`` (an ``nn.Identity`` for "none").

    Raises:
        ValueError: For any unrecognized ``norm_type``.
    """
    if norm_type == "group":
        groups = max(1, min(int(norm_groups), num_channels))
        # Walk down to the nearest divisor of the channel count.
        while groups > 1 and num_channels % groups:
            groups -= 1
        return nn.GroupNorm(groups, num_channels)
    if norm_type == "batch":
        return nn.BatchNorm2d(num_channels)
    if norm_type == "instance":
        return nn.InstanceNorm2d(num_channels, affine=True)
    if norm_type in ("none", "identity"):
        return nn.Identity()
    raise ValueError(f"Unsupported norm_type: {norm_type}")


class ConvBlock(nn.Module):
    """Two successive 3x3 conv -> norm -> ReLU stages at a fixed width.

    Padding of 1 keeps the spatial size unchanged; convolutions carry no
    bias term. Layers live in ``self.net`` in the same order as before, so
    state-dict keys are stable.
    """

    def __init__(self, in_ch: int, out_ch: int, norm_type: str, norm_groups: int):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second refines out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(src_ch, out_ch, 3, padding=1, bias=False))
            stages.append(make_norm(norm_type, out_ch, norm_groups))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages to ``x`` (N, in_ch, H, W) -> (N, out_ch, H, W)."""
        return self.net(x)


class UNetSmallFlex(nn.Module):
    """Compact 4-level U-Net producing single-channel spatial logits.

    Encoder channels grow ``base * {1, 2, 4, 8}`` with a ``base * 16``
    bottleneck; the decoder mirrors them with transpose-conv upsampling and
    skip concatenation. Attribute names (``enc1``..``dec1``, ``head``,
    ``aux_head``) are kept exactly as in the released checkpoints.

    Args:
        in_ch: Number of input channels.
        base: Channel width of the first encoder level.
        dropout: ``Dropout2d`` probability applied before the head(s).
        norm_type / norm_groups: Forwarded to :func:`make_norm`.
        prior_prob: If given, the head bias is set to ``logit(prior_prob)``
            so that ``sigmoid(bias) == prior_prob`` at initialization.
        use_aux_spatial_head: Adds a second 1x1-conv logit head over the
            same features.
        aux_prior_prob: Same bias initialization for the auxiliary head.
    """

    def __init__(
        self,
        in_ch: int,
        base: int = 32,
        dropout: float = 0.1,
        norm_type: str = "group",
        norm_groups: int = 8,
        prior_prob: float | None = None,
        use_aux_spatial_head: bool = False,
        aux_prior_prob: float | None = None,
    ):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base, norm_type, norm_groups)
        self.enc2 = ConvBlock(base, base * 2, norm_type, norm_groups)
        self.enc3 = ConvBlock(base * 2, base * 4, norm_type, norm_groups)
        self.enc4 = ConvBlock(base * 4, base * 8, norm_type, norm_groups)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(base * 8, base * 16, norm_type, norm_groups)
        # Decoder blocks take 2x channels: upsampled features + skip connection.
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
        self.dec4 = ConvBlock(base * 16, base * 8, norm_type, norm_groups)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2)
        self.dec3 = ConvBlock(base * 8, base * 4, norm_type, norm_groups)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = ConvBlock(base * 4, base * 2, norm_type, norm_groups)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = ConvBlock(base * 2, base, norm_type, norm_groups)
        self.drop = nn.Dropout2d(p=dropout)
        self.head = nn.Conv2d(base, 1, kernel_size=1)
        self.use_aux_spatial_head = bool(use_aux_spatial_head)
        self.aux_head = nn.Conv2d(base, 1, kernel_size=1) if self.use_aux_spatial_head else None

        def _init_bias_to_logit(conv: nn.Conv2d, p: float) -> None:
            # sigmoid(log(p / (1 - p))) == p; clamp keeps log() finite.
            p = float(min(max(p, 1e-6), 1.0 - 1e-6))
            nn.init.constant_(conv.bias, math.log(p / (1.0 - p)))

        if prior_prob is not None:
            _init_bias_to_logit(self.head, prior_prob)
        if self.aux_head is not None and aux_prior_prob is not None:
            _init_bias_to_logit(self.aux_head, aux_prior_prob)

    @staticmethod
    def _match_hw(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Pad and/or center-crop ``x`` so its H/W match ``ref``'s.

        Each spatial dimension is handled independently: dims where ``x`` is
        smaller are zero-padded (extra pixel on the bottom/right), dims where
        it is larger are center-cropped.

        Fix: the previous version fed the raw (possibly negative) differences
        to ``F.pad`` — negative pads crop in PyTorch — and then cropped the
        same dimension again in the ``diff < 0`` branch, so a mixed
        pad-one-dim/crop-the-other call returned the wrong size. Clamping the
        pad amounts to >= 0 removes the double crop; all pure-pad, pure-crop,
        and already-matching cases behave exactly as before.
        """
        diff_y = ref.size(2) - x.size(2)
        diff_x = ref.size(3) - x.size(3)
        pad_y = max(diff_y, 0)
        pad_x = max(diff_x, 0)
        if pad_y or pad_x:
            # F.pad order for 4-D input: (left, right, top, bottom).
            x = F.pad(x, [pad_x // 2, pad_x - pad_x // 2, pad_y // 2, pad_y - pad_y // 2])
        if diff_y < 0:
            y0 = (-diff_y) // 2
            x = x[:, :, y0 : y0 + ref.size(2), :]
        if diff_x < 0:
            x0 = (-diff_x) // 2
            x = x[:, :, :, x0 : x0 + ref.size(3)]
        return x

    def forward(self, x: torch.Tensor, return_aux: bool = False):
        """Run the U-Net.

        Args:
            x: Input batch of shape (N, in_ch, H, W).
            return_aux: When True and the model was built with
                ``use_aux_spatial_head=True``, also return auxiliary logits.

        Returns:
            Logits of shape (N, 1, H, W), or ``(logits, aux_logits)`` when
            the auxiliary head is present and requested.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        # Decoder: upsample, align spatial dims to the skip tensor, concat, conv.
        d4 = self.dec4(torch.cat([self._match_hw(self.up4(b), e4), e4], dim=1))
        d3 = self.dec3(torch.cat([self._match_hw(self.up3(d4), e3), e3], dim=1))
        d2 = self.dec2(torch.cat([self._match_hw(self.up2(d3), e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._match_hw(self.up1(d2), e1), e1], dim=1))
        features = self.drop(d1)
        logits = self.head(features)
        if return_aux and self.aux_head is not None:
            return logits, self.aux_head(features)
        return logits