"""Flow-Warp-Mask U-Net: predicts flow, occlusion mask, and generated frame.""" import torch import torch.nn as nn import torch.nn.functional as F class ResConvBlock(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch) self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch) self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() def forward(self, x): residual = self.proj(x) x = F.silu(self.gn1(self.conv1(x))) x = F.silu(self.gn2(self.conv2(x))) return x + residual class FlowWarpMaskUNet(nn.Module): def __init__(self, in_channels=12, channels=[48, 96, 192]): super().__init__() # Encoder self.encoders = nn.ModuleList() self.pools = nn.ModuleList() prev_ch = in_channels for ch in channels: self.encoders.append(ResConvBlock(prev_ch, ch)) self.pools.append(nn.MaxPool2d(2)) prev_ch = ch # Bottleneck self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2) # Decoder self.upconvs = nn.ModuleList() self.decoders = nn.ModuleList() dec_channels = list(reversed(channels)) prev_ch = channels[-1] * 2 for ch in dec_channels: self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2)) self.decoders.append(ResConvBlock(ch * 2, ch)) prev_ch = ch # Flow head (2 channels: dx, dy) self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1) # Mask head (1 channel: occlusion mask, sigmoid applied) self.mask_head = nn.Conv2d(dec_channels[-1], 1, 1) # Generation head (3 channels: full frame for occluded areas) self.gen_head = nn.Conv2d(dec_channels[-1], 3, 1) # Initialize flow and mask heads near-zero for stable start nn.init.zeros_(self.flow_head.weight) nn.init.zeros_(self.flow_head.bias) nn.init.zeros_(self.mask_head.weight) nn.init.zeros_(self.mask_head.bias) def forward(self, x): skips = [] for enc, pool in zip(self.encoders, self.pools): x = enc(x) skips.append(x) x = pool(x) x = self.bottleneck(x) for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)): x = upconv(x) x = torch.cat([x, skip], dim=1) x = dec(x) flow = self.flow_head(x) mask = torch.sigmoid(self.mask_head(x)) gen_frame = self.gen_head(x) return flow, mask, gen_frame