| """Compact U-Net architecture for the released WildFIRE-FM checkpoints.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def make_norm(norm_type: str, num_channels: int, norm_groups: int) -> nn.Module:
    """Construct a 2D normalization layer for ``num_channels`` feature maps.

    Supported ``norm_type`` values: ``"batch"``, ``"group"``, ``"instance"``,
    and ``"none"``/``"identity"`` (no-op). For group norm, the group count is
    the largest value that is <= ``norm_groups`` (clamped to
    ``[1, num_channels]``) and evenly divides ``num_channels``, so the layer
    is always valid.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported names.
    """
    if norm_type == "batch":
        return nn.BatchNorm2d(num_channels)
    if norm_type == "group":
        # Largest divisor of num_channels not exceeding the requested count.
        upper = max(1, min(int(norm_groups), num_channels))
        chosen = next(g for g in range(upper, 0, -1) if num_channels % g == 0)
        return nn.GroupNorm(chosen, num_channels)
    if norm_type == "instance":
        return nn.InstanceNorm2d(num_channels, affine=True)
    if norm_type in {"none", "identity"}:
        return nn.Identity()
    raise ValueError(f"Unsupported norm_type: {norm_type}")
|
|
|
|
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> norm -> ReLU stages at a fixed output width.

    Convolutions use ``bias=False`` because the following normalization
    layer supplies the affine shift. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_ch: int, out_ch: int, norm_type: str, norm_groups: int):
        super().__init__()
        layers: list[nn.Module] = []
        channels = in_ch
        for _ in range(2):
            layers.append(nn.Conv2d(channels, out_ch, 3, padding=1, bias=False))
            layers.append(make_norm(norm_type, out_ch, norm_groups))
            layers.append(nn.ReLU(inplace=True))
            channels = out_ch  # second stage maps out_ch -> out_ch
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages; returns a tensor with ``out_ch`` channels."""
        return self.net(x)
|
|
|
|
| class UNetSmallFlex(nn.Module): |
| def __init__( |
| self, |
| in_ch: int, |
| base: int = 32, |
| dropout: float = 0.1, |
| norm_type: str = "group", |
| norm_groups: int = 8, |
| prior_prob: float | None = None, |
| use_aux_spatial_head: bool = False, |
| aux_prior_prob: float | None = None, |
| ): |
| super().__init__() |
| self.enc1 = ConvBlock(in_ch, base, norm_type, norm_groups) |
| self.enc2 = ConvBlock(base, base * 2, norm_type, norm_groups) |
| self.enc3 = ConvBlock(base * 2, base * 4, norm_type, norm_groups) |
| self.enc4 = ConvBlock(base * 4, base * 8, norm_type, norm_groups) |
| self.pool = nn.MaxPool2d(2) |
| self.bottleneck = ConvBlock(base * 8, base * 16, norm_type, norm_groups) |
| self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2) |
| self.dec4 = ConvBlock(base * 16, base * 8, norm_type, norm_groups) |
| self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2) |
| self.dec3 = ConvBlock(base * 8, base * 4, norm_type, norm_groups) |
| self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2) |
| self.dec2 = ConvBlock(base * 4, base * 2, norm_type, norm_groups) |
| self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2) |
| self.dec1 = ConvBlock(base * 2, base, norm_type, norm_groups) |
| self.drop = nn.Dropout2d(p=dropout) |
| self.head = nn.Conv2d(base, 1, kernel_size=1) |
| self.use_aux_spatial_head = bool(use_aux_spatial_head) |
| self.aux_head = nn.Conv2d(base, 1, kernel_size=1) if self.use_aux_spatial_head else None |
| if prior_prob is not None: |
| prior_prob = float(min(max(prior_prob, 1e-6), 1.0 - 1e-6)) |
| nn.init.constant_(self.head.bias, math.log(prior_prob / (1.0 - prior_prob))) |
| if self.aux_head is not None and aux_prior_prob is not None: |
| aux_prior_prob = float(min(max(aux_prior_prob, 1e-6), 1.0 - 1e-6)) |
| nn.init.constant_(self.aux_head.bias, math.log(aux_prior_prob / (1.0 - aux_prior_prob))) |
|
|
| @staticmethod |
| def _match_hw(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: |
| diff_y = ref.size(2) - x.size(2) |
| diff_x = ref.size(3) - x.size(3) |
| if diff_y > 0 or diff_x > 0: |
| x = F.pad(x, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]) |
| if diff_y < 0: |
| y0 = (-diff_y) // 2 |
| x = x[:, :, y0 : y0 + ref.size(2), :] |
| if diff_x < 0: |
| x0 = (-diff_x) // 2 |
| x = x[:, :, :, x0 : x0 + ref.size(3)] |
| return x |
|
|
| def forward(self, x: torch.Tensor, return_aux: bool = False): |
| e1 = self.enc1(x) |
| e2 = self.enc2(self.pool(e1)) |
| e3 = self.enc3(self.pool(e2)) |
| e4 = self.enc4(self.pool(e3)) |
| b = self.bottleneck(self.pool(e4)) |
| d4 = self.dec4(torch.cat([self._match_hw(self.up4(b), e4), e4], dim=1)) |
| d3 = self.dec3(torch.cat([self._match_hw(self.up3(d4), e3), e3], dim=1)) |
| d2 = self.dec2(torch.cat([self._match_hw(self.up2(d3), e2), e2], dim=1)) |
| d1 = self.dec1(torch.cat([self._match_hw(self.up1(d2), e1), e1], dim=1)) |
| features = self.drop(d1) |
| logits = self.head(features) |
| if return_aux and self.aux_head is not None: |
| return logits, self.aux_head(features) |
| return logits |
|
|