"""Flow-Warp U-Net: predicts optical flow + residual, warps last frame."""
import torch
import torch.nn as nn
import torch.nn.functional as F


def _num_groups(channels, max_groups=8):
    """Return the largest group count <= ``max_groups`` that divides ``channels``.

    Equals the previous ``min(8, channels)`` choice whenever that value divides
    ``channels`` evenly, and falls back to a smaller divisor otherwise so that
    ``nn.GroupNorm`` accepts widths such as 12 or 20.
    """
    groups = max(1, min(max_groups, channels))
    while channels % groups:
        groups -= 1
    return groups


class ResConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm -> SiLU) stages plus an additive skip path.

    The skip path is a 1x1 convolution when ``in_ch != out_ch`` and the
    identity otherwise. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(_num_groups(out_ch), out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(_num_groups(out_ch), out_ch)
        self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        residual = self.proj(x)
        x = F.silu(self.gn1(self.conv1(x)))
        x = F.silu(self.gn2(self.conv2(x)))
        return x + residual


class FlowWarpUNet(nn.Module):
    """U-Net with two 1x1 output heads: a 2-channel flow and a 3-channel residual.

    Args:
        in_channels: Number of input channels (default 12, described in
            ``forward`` as 4 stacked frames).
        channels: Encoder widths, one per resolution level (any sequence).
            The decoder mirrors them and the bottleneck uses
            ``2 * channels[-1]`` channels.

    Raises:
        ValueError: If ``channels`` is empty.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192, 384)):
        super().__init__()
        # Copy into a fresh list: avoids the shared-mutable-default pitfall
        # and accepts tuples/lists alike.
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one width")

        # Encoder: one ResConvBlock + 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck doubles the deepest width.
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: 2x transposed-conv upsample, concat skip, ResConvBlock.
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))  # ch from upconv + ch from skip
            prev_ch = ch

        # Flow head (2 channels: dx, dy)
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        # Residual head (3 channels: RGB residual)
        self.residual_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Zero-initialise both heads so the untrained network outputs exactly
        # zero flow and zero residual (a stable starting point).
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.residual_head.weight)
        nn.init.zeros_(self.residual_head.bias)

    def forward(self, x):
        """
        Args:
            x: (B, in_channels, H, W), e.g. (B, 12, 64, 64) for 4 stacked frames.

        Returns:
            flow: (B, 2, H, W) - optical flow (dx, dy) in pixels
            residual: (B, 3, H, W) - residual correction
        """
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # MaxPool2d floors odd sizes, so the 2x upsample can land one
                # pixel short of the skip; resample so the concat is valid.
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear',
                                  align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)
        residual = self.residual_head(x)
        return flow, residual


def differentiable_warp(img, flow):
    """
    Backward-warp ``img`` by ``flow`` using bilinear sampling.

    Output pixel (y, x) samples ``img`` at (y + flow[:, 1], x + flow[:, 0]).
    Samples outside the image are clamped to the border.

    Args:
        img: (B, C, H, W) - image to warp
        flow: (B, 2, H, W) - flow field (dx, dy) in pixel coordinates

    Returns:
        warped: (B, C, H, W)
    """
    B, C, H, W = img.shape

    # Base pixel grid in absolute coordinates.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, device=img.device, dtype=img.dtype),
        torch.arange(W, device=img.device, dtype=img.dtype),
        indexing='ij',
    )
    grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)  # (B, H, W)
    grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

    # Displace by the flow.
    new_x = grid_x + flow[:, 0]  # (B, H, W)
    new_y = grid_y + flow[:, 1]

    # Normalise to [-1, 1] for grid_sample (align_corners=True maps -1/+1 to
    # the first/last pixel centres). max(..., 1) avoids a 0/0 -> NaN when a
    # dimension has a single pixel.
    new_x = 2.0 * new_x / max(W - 1, 1) - 1.0
    new_y = 2.0 * new_y / max(H - 1, 1) - 1.0

    grid = torch.stack([new_x, new_y], dim=-1)  # (B, H, W, 2)
    return F.grid_sample(img, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)


def _mean_abs(t):
    """Mean absolute value of ``t``; 0 (graph-connected) when ``t`` is empty."""
    # ``mean`` of an empty tensor is NaN; ``sum`` of an empty tensor is 0.
    return t.abs().mean() if t.numel() else t.abs().sum()


def flow_smoothness_loss(flow):
    """Penalize spatial gradients of a flow field.

    Args:
        flow: (B, 2, H, W) flow field.

    Returns:
        Scalar tensor: average of the mean |horizontal difference| and the
        mean |vertical difference|. A direction with fewer than two pixels
        contributes 0 instead of NaN.
    """
    dx = flow[:, :, :, 1:] - flow[:, :, :, :-1]
    dy = flow[:, :, 1:, :] - flow[:, :, :-1, :]
    return (_mean_abs(dx) + _mean_abs(dy)) / 2


class GlobalSSIMLoss(nn.Module):
    """1 - SSIM, computed from global (whole-image) statistics.

    Means, variances and covariance are taken over all pixels of each
    (sample, channel) pair, a single global window rather than the usual
    sliding Gaussian window. The SSIM values are then averaged over batch
    and channels.

    Args:
        data_range: Dynamic range L of the inputs; the stabilisers are
            C1 = (0.01 * L)**2 and C2 = (0.03 * L)**2. The default 1.0
            reproduces the original constants and presumably assumes inputs
            in [0, 1] (confirm against callers).
    """

    def __init__(self, data_range=1.0):
        super().__init__()
        self.C1 = (0.01 * data_range) ** 2
        self.C2 = (0.03 * data_range) ** 2

    def forward(self, pred, target):
        B, C = pred.shape[:2]
        # reshape (not view) so non-contiguous inputs are accepted.
        pred_flat = pred.reshape(B, C, -1)
        target_flat = target.reshape(B, C, -1)

        mu_pred = pred_flat.mean(dim=2)
        mu_target = target_flat.mean(dim=2)

        pred_c = pred_flat - mu_pred.unsqueeze(2)
        target_c = target_flat - mu_target.unsqueeze(2)

        # Variances and covariance share the same (biased, 1/N) normalisation
        # so that SSIM(x, x) == 1 exactly. The original mixed unbiased
        # variance with biased covariance, which also gave NaN for 1-pixel inputs.
        sigma_pred_sq = pred_c.pow(2).mean(dim=2)
        sigma_target_sq = target_c.pow(2).mean(dim=2)
        sigma_cross = (pred_c * target_c).mean(dim=2)

        numerator = (2 * mu_pred * mu_target + self.C1) * (2 * sigma_cross + self.C2)
        denominator = ((mu_pred ** 2 + mu_target ** 2 + self.C1)
                       * (sigma_pred_sq + sigma_target_sq + self.C2))
        ssim = numerator / denominator
        return 1 - ssim.mean()