| """Flow-Warp U-Net: predicts optical flow + residual, warps last frame.""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class ResConvBlock(nn.Module): |
| def __init__(self, in_ch, out_ch): |
| super().__init__() |
| self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) |
| self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch) |
| self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) |
| self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch) |
| self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() |
|
|
| def forward(self, x): |
| residual = self.proj(x) |
| x = F.silu(self.gn1(self.conv1(x))) |
| x = F.silu(self.gn2(self.conv2(x))) |
| return x + residual |
|
|
|
|
class FlowWarpUNet(nn.Module):
    """U-Net that predicts per-pixel optical flow plus an RGB residual.

    Both output heads are zero-initialized, so an untrained model predicts
    zero flow and zero residual — i.e. warping the last frame degenerates to
    the identity, which stabilizes early training.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192, 384)):
        """
        Args:
            in_channels: channels of the stacked-frame input
                (presumably 4 RGB frames -> 12; see forward()).
            channels: encoder width per resolution level. A tuple default is
                used deliberately — a mutable list default would be shared
                across instances.
        """
        super().__init__()

        # Encoder: one residual block + 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: transpose-conv upsampling, concat with the encoder skip,
        # then a residual block. Decoder input width is 2*ch (upsample + skip).
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # 1x1 output heads at full resolution.
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        self.residual_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Zero init: the network starts out predicting "no motion, no change".
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.residual_head.weight)
        nn.init.zeros_(self.residual_head.bias)

    def forward(self, x):
        """
        Args:
            x: (B, 12, 64, 64) - 4 frames stacked
        Returns:
            flow: (B, 2, 64, 64) - optical flow (dx, dy) in pixels
            residual: (B, 3, 64, 64) - residual correction
        """
        # Encode, remembering pre-pool features for the skip connections.
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        # Decode coarse-to-fine, consuming skips in reverse (deepest first).
        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)
        residual = self.residual_head(x)

        return flow, residual
|
|
|
|
def differentiable_warp(img, flow):
    """
    Warp image by flow using bilinear sampling.

    Args:
        img: (B, C, H, W) - image to warp
        flow: (B, 2, H, W) - flow field (dx, dy) in pixel coordinates
    Returns:
        warped: (B, C, H, W)
    """
    B, C, H, W = img.shape

    # Base pixel-coordinate grid (row-major: ys varies over H, xs over W).
    ys, xs = torch.meshgrid(
        torch.arange(H, device=img.device, dtype=img.dtype),
        torch.arange(W, device=img.device, dtype=img.dtype),
        indexing='ij',
    )

    # Displaced sampling locations in pixel space; the leading unsqueeze
    # broadcasts the shared grid across the batch.
    sample_x = xs.unsqueeze(0).expand(B, -1, -1) + flow[:, 0]
    sample_y = ys.unsqueeze(0).expand(B, -1, -1) + flow[:, 1]

    # Map pixel coordinates into [-1, 1], matching align_corners=True below.
    norm_x = 2.0 * sample_x / (W - 1) - 1.0
    norm_y = 2.0 * sample_y / (H - 1) - 1.0

    # grid_sample expects (x, y) ordering in the last dimension.
    grid = torch.stack((norm_x, norm_y), dim=-1)

    return F.grid_sample(
        img, grid, mode='bilinear', padding_mode='border', align_corners=True
    )
|
|
|
|
def flow_smoothness_loss(flow):
    """Total-variation style penalty: mean absolute spatial gradient of flow."""
    # torch.diff computes adjacent differences along the given axis,
    # identical to the manual slice subtraction.
    grad_x = torch.diff(flow, dim=3)
    grad_y = torch.diff(flow, dim=2)
    return 0.5 * (grad_x.abs().mean() + grad_y.abs().mean())
|
|
|
|
class GlobalSSIMLoss(nn.Module):
    """Global (whole-image, per-channel) SSIM loss: returns 1 - mean SSIM.

    Statistics are computed over all pixels of each (batch, channel) plane
    rather than over local windows.
    """

    def __init__(self):
        super().__init__()
        # Standard SSIM stabilizers (assume inputs in [0, 1], L = 1).
        self.C1 = (0.01) ** 2
        self.C2 = (0.03) ** 2

    def forward(self, pred, target):
        """
        Args:
            pred, target: (B, C, H, W) images, expected in [0, 1].
        Returns:
            Scalar loss: 1 - mean per-(B, C) SSIM; 0 when pred == target.
        """
        B, C, H, W = pred.shape
        pred_flat = pred.reshape(B, C, -1)
        target_flat = target.reshape(B, C, -1)

        mu_pred = pred_flat.mean(dim=2)
        mu_target = target_flat.mean(dim=2)
        # Use population (biased) variance so it shares the 1/N normalizer
        # with the covariance below. Mixing unbiased variance with a biased
        # covariance makes ssim(x, x) != 1 and the loss nonzero at identity.
        sigma_pred_sq = pred_flat.var(dim=2, unbiased=False)
        sigma_target_sq = target_flat.var(dim=2, unbiased=False)
        sigma_cross = ((pred_flat - mu_pred.unsqueeze(2)) *
                       (target_flat - mu_target.unsqueeze(2))).mean(dim=2)

        numerator = (2 * mu_pred * mu_target + self.C1) * (2 * sigma_cross + self.C2)
        denominator = (mu_pred ** 2 + mu_target ** 2 + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2)
        ssim = numerator / denominator
        return 1 - ssim.mean()
|
|