world-model / flownet_model.py
ojaffe's picture
Upload folder using huggingface_hub
87bfad6 verified
"""Flow-Warp U-Net: predicts optical flow + residual, warps last frame."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResConvBlock(nn.Module):
    """Two 3x3 conv + GroupNorm + SiLU layers with a residual shortcut.

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise. Spatial size is preserved (padding=1).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    # Upper bound on the number of GroupNorm groups.
    _MAX_GROUPS = 8

    def __init__(self, in_ch, out_ch):
        super().__init__()
        groups = self._num_groups(out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(groups, out_ch)
        self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    @classmethod
    def _num_groups(cls, channels):
        """Return the largest group count <= _MAX_GROUPS that divides `channels`.

        GroupNorm requires the channel count to be divisible by the number of
        groups. For widths below the cap or multiples of it this returns the
        same value as min(_MAX_GROUPS, channels); for other widths (e.g. 12)
        it falls back to the nearest smaller divisor instead of failing.
        """
        for groups in range(min(cls._MAX_GROUPS, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    def forward(self, x):
        """Apply the block.

        Args:
            x: (B, in_ch, H, W) feature map.
        Returns:
            (B, out_ch, H, W) feature map.
        """
        residual = self.proj(x)
        x = F.silu(self.gn1(self.conv1(x)))
        x = F.silu(self.gn2(self.conv2(x)))
        return x + residual
class FlowWarpUNet(nn.Module):
    """U-Net that predicts a 2-channel flow field and a 3-channel residual.

    The encoder is a stack of ResConvBlock + 2x max-pool stages, followed by
    a bottleneck that doubles the channel count, and a mirrored decoder of
    transposed-conv upsampling + ResConvBlock stages with skip connections.

    Args:
        in_channels: number of input channels (default 12; the forward
            docstring describes this as 4 stacked frames).
        channels: encoder widths, one entry per resolution level. Any
            sequence is accepted; it is not modified.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192, 384)):
        super().__init__()
        # Copy into a list so tuples and lists behave the same and the
        # caller's sequence is never shared.
        channels = list(channels)

        # Encoder
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            # Decoder input is the upsampled map concatenated with the skip.
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # Flow head (2 channels: dx, dy)
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        # Residual head (3 channels: RGB residual)
        self.residual_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Both heads start at exactly zero, so a freshly constructed model
        # outputs zero flow and zero residual.
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.residual_head.weight)
        nn.init.zeros_(self.residual_head.bias)

    def forward(self, x):
        """
        Args:
            x: (B, in_channels, H, W) - stacked input frames
               (e.g. (B, 12, 64, 64) for 4 frames)
        Returns:
            flow: (B, 2, H, W) - optical flow (dx, dy) in pixels
            residual: (B, 3, H, W) - residual correction
        """
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            # Max-pooling floors odd sizes, so the upsampled map can be
            # smaller than its skip. Zero-pad to match so that inputs not
            # divisible by 2**len(channels) still work; this is a no-op
            # when the sizes already agree.
            pad_h = skip.shape[-2] - x.shape[-2]
            pad_w = skip.shape[-1] - x.shape[-1]
            if pad_h or pad_w:
                x = F.pad(
                    x,
                    (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
                )
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)  # (B, 2, H, W)
        residual = self.residual_head(x)  # (B, 3, H, W)
        return flow, residual
def differentiable_warp(img, flow):
    """
    Warp image by flow using bilinear sampling.

    Each output pixel (x, y) is sampled from the input at
    (x + flow_dx, y + flow_dy); samples outside the image use the nearest
    border pixel.

    Args:
        img: (B, C, H, W) - image to warp
        flow: (B, 2, H, W) - flow field (dx, dy) in pixel coordinates
    Returns:
        warped: (B, C, H, W)
    """
    B, _, H, W = img.shape

    # Base grid of integer pixel coordinates.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, device=img.device, dtype=img.dtype),
        torch.arange(W, device=img.device, dtype=img.dtype),
        indexing='ij',
    )
    grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)  # (B, H, W)
    grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

    # Displace the grid by the flow.
    new_x = grid_x + flow[:, 0]  # (B, H, W)
    new_y = grid_y + flow[:, 1]

    # Normalize to [-1, 1] for grid_sample (align_corners=True convention).
    # A single-pixel axis has W - 1 == 0; clamp the divisor to 1 so the
    # coordinates, and the gradient flowing back to `flow`, stay finite
    # instead of becoming NaN/inf.
    new_x = 2.0 * new_x / max(W - 1, 1) - 1.0
    new_y = 2.0 * new_y / max(H - 1, 1) - 1.0

    grid = torch.stack([new_x, new_y], dim=-1)  # (B, H, W, 2)
    warped = F.grid_sample(
        img, grid, mode='bilinear', padding_mode='border', align_corners=True
    )
    return warped
def flow_smoothness_loss(flow):
    """Penalize spatial gradients of flow field.

    Computes the mean absolute forward difference along the width and
    height axes and averages the two terms.

    Args:
        flow: (B, 2, H, W) - flow field
    Returns:
        Scalar tensor. An axis of length 1 has no neighbouring pairs, so
        its term contributes zero rather than NaN.
    """
    dx = flow[:, :, :, 1:] - flow[:, :, :, :-1]
    dy = flow[:, :, 1:, :] - flow[:, :, :-1, :]

    def mean_abs(diff):
        # sum / max(count, 1) equals the mean for non-empty tensors and is
        # 0 (still attached to the autograd graph) for empty ones, where
        # .mean() would return NaN.
        return diff.abs().sum() / max(diff.numel(), 1)

    return (mean_abs(dx) + mean_abs(dy)) / 2
class GlobalSSIMLoss(nn.Module):
    """SSIM-style loss computed from whole-image statistics.

    Instead of sliding local windows, the mean, variance and covariance are
    taken over all spatial positions of each (batch, channel) plane. The
    loss is 1 - SSIM averaged over batch and channels, so identical inputs
    give a loss of 0.
    """

    def __init__(self):
        super().__init__()
        # Stabilizing constants for the luminance and contrast terms.
        # NOTE(review): 0.01**2 / 0.03**2 are the usual SSIM constants for a
        # data range of 1 - presumably inputs are in [0, 1]; confirm.
        self.C1 = (0.01) ** 2
        self.C2 = (0.03) ** 2

    def forward(self, pred, target):
        """
        Args:
            pred: (B, C, ...) - predicted image
            target: same shape as pred - reference image
        Returns:
            Scalar tensor, 1 - mean SSIM.
        """
        B, C = pred.shape[:2]
        # reshape (not view) so non-contiguous inputs are accepted.
        pred_flat = pred.reshape(B, C, -1)
        target_flat = target.reshape(B, C, -1)

        mu_pred = pred_flat.mean(dim=2)
        mu_target = target_flat.mean(dim=2)

        pred_centered = pred_flat - mu_pred.unsqueeze(2)
        target_centered = target_flat - mu_target.unsqueeze(2)

        # Variances and covariance must use the same (population, 1/N)
        # normalization; mixing an unbiased variance with a 1/N covariance
        # makes SSIM(x, x) fall below 1.
        sigma_pred_sq = (pred_centered ** 2).mean(dim=2)
        sigma_target_sq = (target_centered ** 2).mean(dim=2)
        sigma_cross = (pred_centered * target_centered).mean(dim=2)

        numerator = (2 * mu_pred * mu_target + self.C1) * (2 * sigma_cross + self.C2)
        denominator = (mu_pred ** 2 + mu_target ** 2 + self.C1) * (
            sigma_pred_sq + sigma_target_sq + self.C2
        )
        ssim = numerator / denominator
        return 1 - ssim.mean()