Upload folder using huggingface_hub
Browse files- flowmask_model.py +79 -0
- flownet_model.py +150 -0
- loss_history.json +233 -0
- model.pt +3 -0
- predict.py +53 -259
- train.log +63 -31
flowmask_model.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow-Warp-Mask U-Net: predicts flow, occlusion mask, and generated frame."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ResConvBlock(nn.Module):
    """Two 3x3 conv + GroupNorm + SiLU layers with a residual shortcut.

    The shortcut is a 1x1 projection when the channel count changes and
    the identity otherwise, so the final addition is always shape-safe.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        num_groups = min(8, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(num_groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups, out_ch)
        if in_ch != out_ch:
            self.proj = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.proj = nn.Identity()

    def forward(self, x):
        shortcut = self.proj(x)
        out = F.silu(self.gn1(self.conv1(x)))
        out = F.silu(self.gn2(self.conv2(out)))
        return out + shortcut
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FlowWarpMaskUNet(nn.Module):
    """U-Net predicting optical flow, an occlusion mask, and a generated frame.

    Input is a channel-stacked context (default 4 RGB frames = 12 channels).
    Three 1x1 output heads share the finest decoder features:

      * flow:      (B, 2, H, W) pixel-space flow (dx, dy)
      * mask:      (B, 1, H, W) occlusion mask in [0, 1] (sigmoid)
      * gen_frame: (B, 3, H, W) hallucinated frame for occluded regions

    The flow and mask heads are zero-initialized so training starts from an
    identity warp with a uniform 0.5 mask.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192)):
        # NOTE: `channels` default was a mutable list; a tuple avoids the
        # shared-mutable-default pitfall while accepting the same values.
        super().__init__()
        # Encoder: ResConvBlock followed by a 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck doubles the deepest channel count.
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: transpose-conv upsample, concat skip, ResConvBlock.
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # Flow head (2 channels: dx, dy)
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        # Mask head (1 channel: occlusion mask, sigmoid applied)
        self.mask_head = nn.Conv2d(dec_channels[-1], 1, 1)
        # Generation head (3 channels: full frame for occluded areas)
        self.gen_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Initialize flow and mask heads near-zero for stable start:
        # identity warp and mask == sigmoid(0) == 0.5 at step 0.
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.mask_head.weight)
        nn.init.zeros_(self.mask_head.bias)

    def forward(self, x):
        """Run the U-Net.

        Spatial size must be divisible by 2**len(channels) for the
        pool/upsample pairs to round-trip exactly.
        """
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)
        mask = torch.sigmoid(self.mask_head(x))
        gen_frame = self.gen_head(x)

        return flow, mask, gen_frame
|
flownet_model.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow-Warp U-Net: predicts optical flow + residual, warps last frame."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ResConvBlock(nn.Module):
    """Residual block: (conv3x3 -> GroupNorm -> SiLU) twice, plus shortcut.

    A 1x1 convolution projects the shortcut when in_ch != out_ch;
    otherwise the input passes through unchanged before the addition.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch)
        needs_projection = in_ch != out_ch
        self.proj = nn.Conv2d(in_ch, out_ch, 1) if needs_projection else nn.Identity()

    def forward(self, x):
        identity = self.proj(x)
        h = F.silu(self.gn1(self.conv1(x)))
        h = F.silu(self.gn2(self.conv2(h)))
        return h + identity
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FlowWarpUNet(nn.Module):
    """U-Net predicting optical flow plus an additive RGB residual.

    Input is a channel-stacked context (default 4 RGB frames = 12 channels).
    Two 1x1 output heads share the finest decoder features:

      * flow:     (B, 2, H, W) pixel-space flow (dx, dy)
      * residual: (B, 3, H, W) additive RGB correction

    Both heads are zero-initialized so training starts from an identity
    warp with no correction.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192, 384)):
        # NOTE: `channels` default was a mutable list; a tuple avoids the
        # shared-mutable-default pitfall while accepting the same values.
        super().__init__()
        # Encoder: ResConvBlock followed by a 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck doubles the deepest channel count.
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: transpose-conv upsample, concat skip, ResConvBlock.
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # Flow head (2 channels: dx, dy)
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        # Residual head (3 channels: RGB residual)
        self.residual_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Initialize both heads near-zero for a stable start
        # (identity warp, zero correction).
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.residual_head.weight)
        nn.init.zeros_(self.residual_head.bias)

    def forward(self, x):
        """
        Args:
            x: (B, 12, 64, 64) - 4 frames stacked (any spatial size
               divisible by 2**len(channels) works)
        Returns:
            flow: (B, 2, H, W) - optical flow (dx, dy) in pixels
            residual: (B, 3, H, W) - residual correction
        """
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)          # (B, 2, H, W)
        residual = self.residual_head(x)  # (B, 3, H, W)

        return flow, residual
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def differentiable_warp(img, flow):
    """Backward-warp `img` by `flow` with bilinear sampling.

    Each output pixel (y, x) is sampled from (y + dy, x + dx) in `img`;
    coordinates outside the image are clamped (border replication).

    Args:
        img: (B, C, H, W) - image to warp
        flow: (B, 2, H, W) - flow field (dx, dy) in pixel coordinates
    Returns:
        warped: (B, C, H, W)
    """
    B, _, H, W = img.shape

    # Identity sampling grid in pixel coordinates.
    ys, xs = torch.meshgrid(
        torch.arange(H, device=img.device, dtype=img.dtype),
        torch.arange(W, device=img.device, dtype=img.dtype),
        indexing='ij',
    )

    # Displace by the flow (broadcasts over the batch dimension).
    sample_x = xs.unsqueeze(0).expand(B, -1, -1) + flow[:, 0]
    sample_y = ys.unsqueeze(0).expand(B, -1, -1) + flow[:, 1]

    # Map pixel coordinates to the [-1, 1] range grid_sample expects.
    norm_x = 2.0 * sample_x / (W - 1) - 1.0
    norm_y = 2.0 * sample_y / (H - 1) - 1.0
    grid = torch.stack((norm_x, norm_y), dim=-1)  # (B, H, W, 2)

    return F.grid_sample(img, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def flow_smoothness_loss(flow):
    """Total-variation smoothness penalty on a (B, 2, H, W) flow field.

    Returns the average of the mean absolute horizontal and vertical
    finite differences of the flow.
    """
    horiz = torch.diff(flow, dim=3).abs().mean()
    vert = torch.diff(flow, dim=2).abs().mean()
    return (horiz + vert) / 2
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class GlobalSSIMLoss(nn.Module):
    """Global (single-window) SSIM loss: 1 - SSIM over the whole image.

    Statistics are computed per (batch, channel) over all pixels rather
    than in local windows. C1/C2 are the standard stabilizers for inputs
    in a [0, 1] dynamic range.
    """

    def __init__(self):
        super().__init__()
        self.C1 = (0.01) ** 2
        self.C2 = (0.03) ** 2

    def forward(self, pred, target):
        B, C, H, W = pred.shape
        pred_flat = pred.view(B, C, -1)
        target_flat = target.view(B, C, -1)

        mu_pred = pred_flat.mean(dim=2)
        mu_target = target_flat.mean(dim=2)
        # Use the biased (population, /N) variance to match the biased
        # covariance below. The original mixed unbiased var (/N-1) with a
        # /N covariance, which made SSIM(x, x) < 1 — identical images did
        # not reach zero loss.
        sigma_pred_sq = pred_flat.var(dim=2, unbiased=False)
        sigma_target_sq = target_flat.var(dim=2, unbiased=False)
        sigma_cross = ((pred_flat - mu_pred.unsqueeze(2)) *
                       (target_flat - mu_target.unsqueeze(2))).mean(dim=2)

        numerator = (2 * mu_pred * mu_target + self.C1) * (2 * sigma_cross + self.C2)
        denominator = (mu_pred ** 2 + mu_target ** 2 + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2)
        ssim = numerator / denominator
        return 1 - ssim.mean()
|
loss_history.json
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"epoch": 1,
|
| 4 |
+
"phase": "P1",
|
| 5 |
+
"loss": 0.093708
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"epoch": 2,
|
| 9 |
+
"phase": "P1",
|
| 10 |
+
"loss": 0.075409
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"epoch": 3,
|
| 14 |
+
"phase": "P1",
|
| 15 |
+
"loss": 0.070398
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"epoch": 4,
|
| 19 |
+
"phase": "P1",
|
| 20 |
+
"loss": 0.066922
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"epoch": 5,
|
| 24 |
+
"phase": "P1",
|
| 25 |
+
"loss": 0.064051
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"epoch": 6,
|
| 29 |
+
"phase": "P1",
|
| 30 |
+
"loss": 0.061594
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"epoch": 7,
|
| 34 |
+
"phase": "P1",
|
| 35 |
+
"loss": 0.058991
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"epoch": 8,
|
| 39 |
+
"phase": "P1",
|
| 40 |
+
"loss": 0.056665
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"epoch": 9,
|
| 44 |
+
"phase": "P1",
|
| 45 |
+
"loss": 0.054221
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 10,
|
| 49 |
+
"phase": "P1",
|
| 50 |
+
"loss": 0.052157
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"epoch": 11,
|
| 54 |
+
"phase": "P1",
|
| 55 |
+
"loss": 0.050054
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"epoch": 12,
|
| 59 |
+
"phase": "P1",
|
| 60 |
+
"loss": 0.048416
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"epoch": 13,
|
| 64 |
+
"phase": "P1",
|
| 65 |
+
"loss": 0.047013
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"epoch": 14,
|
| 69 |
+
"phase": "P1",
|
| 70 |
+
"loss": 0.046003
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"epoch": 15,
|
| 74 |
+
"phase": "P1",
|
| 75 |
+
"loss": 0.0454
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"epoch": 16,
|
| 79 |
+
"phase": "P2",
|
| 80 |
+
"loss": 0.071297
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 17,
|
| 84 |
+
"phase": "P2",
|
| 85 |
+
"loss": 0.069845
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"epoch": 18,
|
| 89 |
+
"phase": "P2",
|
| 90 |
+
"loss": 0.067838
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"epoch": 19,
|
| 94 |
+
"phase": "P2",
|
| 95 |
+
"loss": 0.102993
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"epoch": 20,
|
| 99 |
+
"phase": "P2",
|
| 100 |
+
"loss": 0.098403,
|
| 101 |
+
"val_ssim": 0.8174
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 21,
|
| 105 |
+
"phase": "P2",
|
| 106 |
+
"loss": 0.095552
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"epoch": 22,
|
| 110 |
+
"phase": "P2",
|
| 111 |
+
"loss": 0.142291
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"epoch": 23,
|
| 115 |
+
"phase": "P2",
|
| 116 |
+
"loss": 0.137962
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"epoch": 24,
|
| 120 |
+
"phase": "P2",
|
| 121 |
+
"loss": 0.133837
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"epoch": 25,
|
| 125 |
+
"phase": "P2",
|
| 126 |
+
"loss": 0.129812,
|
| 127 |
+
"val_ssim": 0.854
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"epoch": 26,
|
| 131 |
+
"phase": "P2",
|
| 132 |
+
"loss": 0.126053
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"epoch": 27,
|
| 136 |
+
"phase": "P2",
|
| 137 |
+
"loss": 0.122985
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"epoch": 28,
|
| 141 |
+
"phase": "P2",
|
| 142 |
+
"loss": 0.120476
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"epoch": 29,
|
| 146 |
+
"phase": "P2",
|
| 147 |
+
"loss": 0.117592
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"epoch": 30,
|
| 151 |
+
"phase": "P2",
|
| 152 |
+
"loss": 0.115456,
|
| 153 |
+
"val_ssim": 0.8644
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"epoch": 31,
|
| 157 |
+
"phase": "P2",
|
| 158 |
+
"loss": 0.113231
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"epoch": 32,
|
| 162 |
+
"phase": "P2",
|
| 163 |
+
"loss": 0.111175
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"epoch": 33,
|
| 167 |
+
"phase": "P2",
|
| 168 |
+
"loss": 0.108953
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"epoch": 34,
|
| 172 |
+
"phase": "P2",
|
| 173 |
+
"loss": 0.106131
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"epoch": 35,
|
| 177 |
+
"phase": "P2",
|
| 178 |
+
"loss": 0.103505,
|
| 179 |
+
"val_ssim": 0.8744
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"epoch": 36,
|
| 183 |
+
"phase": "P2",
|
| 184 |
+
"loss": 0.100435
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"epoch": 37,
|
| 188 |
+
"phase": "P2",
|
| 189 |
+
"loss": 0.097286
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"epoch": 38,
|
| 193 |
+
"phase": "P2",
|
| 194 |
+
"loss": 0.094014
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"epoch": 39,
|
| 198 |
+
"phase": "P2",
|
| 199 |
+
"loss": 0.090802
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"epoch": 40,
|
| 203 |
+
"phase": "P2",
|
| 204 |
+
"loss": 0.087507,
|
| 205 |
+
"val_ssim": 0.8852
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"epoch": 41,
|
| 209 |
+
"phase": "P2",
|
| 210 |
+
"loss": 0.084485
|
| 211 |
+
},
|
| 212 |
+
{
|
| 213 |
+
"epoch": 42,
|
| 214 |
+
"phase": "P2",
|
| 215 |
+
"loss": 0.081661
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"epoch": 43,
|
| 219 |
+
"phase": "P2",
|
| 220 |
+
"loss": 0.079401
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"epoch": 44,
|
| 224 |
+
"phase": "P2",
|
| 225 |
+
"loss": 0.077772
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"epoch": 45,
|
| 229 |
+
"phase": "P2",
|
| 230 |
+
"loss": 0.076937,
|
| 231 |
+
"val_ssim": 0.885
|
| 232 |
+
}
|
| 233 |
+
]
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4774304dae39b918b34dd4ededabc4a793ac7efdeb772a746587fad584ccfe83
|
| 3 |
+
size 9089268
|
predict.py
CHANGED
|
@@ -1,282 +1,76 @@
|
|
| 1 |
-
"""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
| 7 |
sys.path.insert(0, "/home/coder/code")
|
| 8 |
-
from
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def detect_game(context_frames: np.ndarray) -> str:
|
| 16 |
-
first_8 = context_frames[:CONTEXT_FRAMES]
|
| 17 |
-
mean_val = first_8.mean()
|
| 18 |
-
std_val = first_8.std()
|
| 19 |
-
b_mean = first_8[:, :, :, 2].mean()
|
| 20 |
-
r_mean = first_8[:, :, :, 0].mean()
|
| 21 |
-
if mean_val > 100 and std_val < 80 and b_mean > r_mean * 1.5:
|
| 22 |
-
return "pole_position"
|
| 23 |
-
elif mean_val < 5 and 10 < std_val < 20:
|
| 24 |
-
return "pong"
|
| 25 |
-
else:
|
| 26 |
-
return "sonic"
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def load_int8_state_dict(path, device):
|
| 30 |
-
"""Load int8 quantized state dict and dequantize to float32."""
|
| 31 |
-
quantized = torch.load(path, map_location='cpu', weights_only=False)
|
| 32 |
-
sd = {}
|
| 33 |
-
for k, v in quantized.items():
|
| 34 |
-
if 'int8' in v:
|
| 35 |
-
sd[k] = (v['int8'].float() * v['scale']).to(device)
|
| 36 |
-
else:
|
| 37 |
-
sd[k] = v['float'].to(device)
|
| 38 |
-
return sd
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class EnsembleModels:
|
| 42 |
-
def __init__(self):
|
| 43 |
-
self.models = {}
|
| 44 |
-
self.sonic_ar = None
|
| 45 |
-
self.sonic_direct = None
|
| 46 |
-
self.pong_direct = None
|
| 47 |
-
self.direct_cache = None
|
| 48 |
-
self.cache_step = 0
|
| 49 |
-
|
| 50 |
-
def reset_cache(self):
|
| 51 |
-
self.direct_cache = None
|
| 52 |
-
self.cache_step = 0
|
| 53 |
|
| 54 |
|
| 55 |
def load_model(model_dir: str):
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
upsample_mode="bilinear").to(DEVICE)
|
| 72 |
-
sd = load_int8_state_dict(os.path.join(model_dir, "model_pong_direct.pt"), DEVICE)
|
| 73 |
-
pong_direct.load_state_dict(sd)
|
| 74 |
-
pong_direct.eval()
|
| 75 |
-
ens.pong_direct = pong_direct
|
| 76 |
-
|
| 77 |
-
# Sonic AR (fp16, 3 outputs) - kept in fp16 for AR chain quality
|
| 78 |
-
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 79 |
-
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 80 |
-
upsample_mode="bilinear").to(DEVICE)
|
| 81 |
-
sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
|
| 82 |
-
map_location=DEVICE, weights_only=True)
|
| 83 |
-
sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 84 |
-
sonic_ar.eval()
|
| 85 |
-
ens.sonic_ar = sonic_ar
|
| 86 |
-
|
| 87 |
-
# Sonic direct (int8 quantized, 24 outputs)
|
| 88 |
-
sonic_direct = UNet(in_channels=24, out_channels=24,
|
| 89 |
-
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 90 |
-
upsample_mode="bilinear").to(DEVICE)
|
| 91 |
-
sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_direct.pt"), DEVICE)
|
| 92 |
-
sonic_direct.load_state_dict(sd)
|
| 93 |
-
sonic_direct.eval()
|
| 94 |
-
ens.sonic_direct = sonic_direct
|
| 95 |
-
|
| 96 |
-
# PP full direct (fp16, 24 outputs)
|
| 97 |
-
pp = UNet(in_channels=24, out_channels=24,
|
| 98 |
-
enc_channels=(32, 64, 128), bottleneck_channels=192,
|
| 99 |
-
upsample_mode="bilinear").to(DEVICE)
|
| 100 |
-
sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
|
| 101 |
-
map_location=DEVICE, weights_only=True)
|
| 102 |
-
pp.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 103 |
-
pp.eval()
|
| 104 |
-
ens.models["pole_position"] = pp
|
| 105 |
-
|
| 106 |
-
return ens
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def _predict_8frames_direct(model, context_tensor, last_tensor, residual_scale=1.0):
|
| 110 |
-
output = model(context_tensor)
|
| 111 |
-
residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
|
| 112 |
-
last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
|
| 113 |
-
return torch.clamp(last_expanded + residual_scale * residuals, 0, 1)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
|
| 117 |
-
residual = model(context_tensor)
|
| 118 |
-
return torch.clamp(last_tensor + residual_scale * residual, 0, 1)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
| 122 |
-
game = detect_game(context_frames)
|
| 123 |
-
n = len(context_frames)
|
| 124 |
-
|
| 125 |
-
if n < CONTEXT_FRAMES:
|
| 126 |
-
padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
|
| 127 |
-
frames = np.concatenate([padding, context_frames], axis=0)
|
| 128 |
else:
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
frames_t = np.transpose(frames_norm, (0, 3, 1, 2))
|
| 133 |
-
context = frames_t.reshape(1, -1, 64, 64)
|
| 134 |
-
|
| 135 |
-
last_frame = frames_norm[-1]
|
| 136 |
-
last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
|
| 137 |
-
|
| 138 |
-
if game == "pong":
|
| 139 |
-
# Pong: AR+direct ensemble, float32 caching, no TTA
|
| 140 |
-
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 141 |
-
result = ens.direct_cache[ens.cache_step]
|
| 142 |
-
ens.cache_step += 1
|
| 143 |
-
if ens.cache_step >= PRED_FRAMES:
|
| 144 |
-
ens.reset_cache()
|
| 145 |
-
return result
|
| 146 |
-
|
| 147 |
-
ens.reset_cache()
|
| 148 |
-
with torch.no_grad():
|
| 149 |
-
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 150 |
-
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 151 |
-
|
| 152 |
-
direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
|
| 153 |
-
|
| 154 |
-
ar_preds = []
|
| 155 |
-
ctx = context_tensor.clone()
|
| 156 |
-
last_t = last_tensor.clone()
|
| 157 |
-
for step in range(PRED_FRAMES):
|
| 158 |
-
predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=1.02)
|
| 159 |
-
ar_preds.append(predicted)
|
| 160 |
-
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 161 |
-
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
| 162 |
-
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 163 |
-
last_t = predicted
|
| 164 |
-
|
| 165 |
-
ar_pred = torch.stack(ar_preds, dim=1)
|
| 166 |
-
|
| 167 |
-
predicted = torch.zeros_like(direct_pred)
|
| 168 |
-
for step in range(PRED_FRAMES):
|
| 169 |
-
ar_weight = 0.85 - (step / (PRED_FRAMES - 1)) * 0.3
|
| 170 |
-
direct_weight = 1.0 - ar_weight
|
| 171 |
-
predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
|
| 172 |
-
|
| 173 |
-
predicted_np = predicted[0].cpu().numpy()
|
| 174 |
-
ens.direct_cache = []
|
| 175 |
-
for i in range(PRED_FRAMES):
|
| 176 |
-
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 177 |
-
frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
|
| 178 |
-
ens.direct_cache.append(frame)
|
| 179 |
-
|
| 180 |
-
result = ens.direct_cache[ens.cache_step]
|
| 181 |
-
ens.cache_step += 1
|
| 182 |
-
return result
|
| 183 |
-
|
| 184 |
-
elif game == "sonic":
|
| 185 |
-
# Sonic: AR(fp16)+direct(int8) with step blending and TTA
|
| 186 |
-
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 187 |
-
result = ens.direct_cache[ens.cache_step]
|
| 188 |
-
ens.cache_step += 1
|
| 189 |
-
if ens.cache_step >= PRED_FRAMES:
|
| 190 |
-
ens.reset_cache()
|
| 191 |
-
return result
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 197 |
|
| 198 |
-
direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
|
| 199 |
-
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 200 |
-
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 201 |
-
direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
|
| 202 |
-
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 203 |
-
direct_pred = (direct_orig + direct_flipped) / 2.0
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
ar_preds_run = []
|
| 210 |
-
ctx = context_tensor.clone()
|
| 211 |
-
ctx_flip = context_flipped.clone()
|
| 212 |
-
last_t = last_tensor.clone()
|
| 213 |
-
last_f = last_flipped.clone()
|
| 214 |
-
sonic_scales = [1.04, 1.04, 1.04, 1.08, 1.08, 1.08, 1.12, 1.12]
|
| 215 |
-
for step in range(PRED_FRAMES):
|
| 216 |
-
ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
|
| 217 |
-
ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
|
| 218 |
-
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=sonic_scales[step])
|
| 219 |
-
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scales[step])
|
| 220 |
-
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 221 |
-
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 222 |
-
ar_preds_run.append(ar_frame)
|
| 223 |
-
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 224 |
-
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 225 |
-
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 226 |
-
last_t = ar_orig
|
| 227 |
-
ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 228 |
-
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 229 |
-
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 230 |
-
last_f = ar_flip
|
| 231 |
-
all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
|
| 232 |
|
| 233 |
-
|
|
|
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
|
| 241 |
-
predicted_np = predicted[0].cpu().numpy()
|
| 242 |
-
ens.direct_cache = []
|
| 243 |
-
for i in range(PRED_FRAMES):
|
| 244 |
-
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 245 |
-
frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
|
| 246 |
-
ens.direct_cache.append(frame)
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
|
| 252 |
-
|
| 253 |
-
# PP: direct with TTA and caching
|
| 254 |
-
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 255 |
-
result = ens.direct_cache[ens.cache_step]
|
| 256 |
-
ens.cache_step += 1
|
| 257 |
-
if ens.cache_step >= PRED_FRAMES:
|
| 258 |
-
ens.reset_cache()
|
| 259 |
-
return result
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
predicted_flipped = torch.flip(predicted_flipped, dims=[4])
|
| 271 |
-
predicted = (predicted_orig + predicted_flipped) / 2.0
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
for i in range(PRED_FRAMES):
|
| 276 |
-
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 277 |
-
frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
|
| 278 |
-
ens.direct_cache.append(frame)
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
| 1 |
+
"""Prediction interface for Flow-Warp-Mask U-Net v9 with TTA."""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
| 7 |
sys.path.insert(0, "/home/coder/code")
|
| 8 |
+
from flowmask_model import FlowWarpMaskUNet
|
| 9 |
+
from flownet_model import differentiable_warp
|
| 10 |
|
| 11 |
+
CONTEXT_LEN = 4
|
| 12 |
+
CHANNELS = [48, 96, 192]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def load_model(model_dir: str):
    """Build a FlowWarpMaskUNet and load weights from `model_dir`/model.pt.

    Returns a dict holding the eval-mode model and its device, as consumed
    by `predict_next_frame`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = FlowWarpMaskUNet(in_channels=12, channels=CHANNELS)
    checkpoint = torch.load(os.path.join(model_dir, "model.pt"),
                            map_location=device, weights_only=True)
    # Weights may be stored in reduced precision; promote to float32.
    net.load_state_dict({name: tensor.float() for name, tensor in checkpoint.items()})
    net.to(device)
    net.eval()
    return {"model": net, "device": device}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _prepare_input(context_frames):
    """Convert numpy context frames into a (CONTEXT_LEN, 3, H, W) float tensor.

    Keeps the most recent CONTEXT_LEN frames; when fewer are available the
    earliest frame is repeated at the front. Pixels are rescaled from
    uint8 [0, 255] to float32 [0, 1].
    """
    available = len(context_frames)
    if available >= CONTEXT_LEN:
        frames = context_frames[-CONTEXT_LEN:]
    else:
        pad = np.stack([context_frames[0]] * (CONTEXT_LEN - available), axis=0)
        frames = np.concatenate([pad, context_frames], axis=0)

    tensor = torch.from_numpy(frames.astype(np.float32) / 255.0)
    return tensor.permute(0, 3, 1, 2)  # (CONTEXT_LEN, 3, H, W)
|
|
|
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
def _run_model(model, frames_t, device):
    """Run the flow/mask model on prepared frames.

    Args:
        model: FlowWarpMaskUNet (or compatible) returning (flow, mask, gen).
        frames_t: (ctx_len, 3, H, W) float tensor of context frames.
        device: torch device to run on.
    Returns:
        (1, 3, H, W) prediction clamped to [0, 1].
    """
    # Derive the spatial size from the input instead of hard-coding 64x64,
    # so the helper works for any resolution the model accepts.
    ctx_len, ch, height, width = frames_t.shape
    last_frame = frames_t[-1].unsqueeze(0).to(device)  # (1, 3, H, W)
    inp = frames_t.reshape(1, ctx_len * ch, height, width).to(device)

    flow, mask, gen_frame = model(inp)
    warped = differentiable_warp(last_frame, flow)
    # mask ~ 1 where the warped frame is trusted, ~ 0 where the
    # generated (in-painted) frame takes over.
    pred = mask * warped + (1 - mask) * gen_frame
    return torch.clamp(pred, 0, 1)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame with horizontal-flip test-time augmentation.

    Args:
        model_dict: dict holding the loaded ``"model"`` and its ``"device"``.
        context_frames: array of past frames, consumed by `_prepare_input`.

    Returns:
        uint8 HxWxC image: the average of the direct prediction and the
        un-mirrored prediction on the horizontally flipped context.
    """
    model = model_dict["model"]
    device = model_dict["device"]

    frames_t = _prepare_input(context_frames)

    with torch.no_grad():
        # Direct prediction on the original clip.
        direct = _run_model(model, frames_t, device)

        # TTA: predict on the mirrored clip, then mirror the result back.
        mirrored = frames_t.flip(-1)  # flip along width
        tta = _run_model(model, mirrored, device).flip(-1)

        # Average the two views.
        pred = (direct + tta) / 2.0

    out = pred[0].cpu().permute(1, 2, 0).numpy()
    out = (out * 255).clip(0, 255).astype(np.uint8)
    return out
|
train.log
CHANGED
|
@@ -1,31 +1,63 @@
|
|
| 1 |
-
[
|
| 2 |
-
[
|
| 3 |
-
[
|
| 4 |
-
[
|
| 5 |
-
[
|
| 6 |
-
[
|
| 7 |
-
[
|
| 8 |
-
[
|
| 9 |
-
[
|
| 10 |
-
[
|
| 11 |
-
[
|
| 12 |
-
[
|
| 13 |
-
[
|
| 14 |
-
[
|
| 15 |
-
[
|
| 16 |
-
[
|
| 17 |
-
[
|
| 18 |
-
[
|
| 19 |
-
[
|
| 20 |
-
[
|
| 21 |
-
[
|
| 22 |
-
[
|
| 23 |
-
[
|
| 24 |
-
[
|
| 25 |
-
[
|
| 26 |
-
[
|
| 27 |
-
[
|
| 28 |
-
[
|
| 29 |
-
[
|
| 30 |
-
[
|
| 31 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[23:37:08] Device: cuda
|
| 2 |
+
[23:37:08] Model parameters: 4,534,230, channels=[48, 96, 192]
|
| 3 |
+
[23:37:08] Phase 1: Single-step (15 epochs)
|
| 4 |
+
[23:37:12] 45108 sequences
|
| 5 |
+
[23:37:54] P1 Epoch 1/15 | loss=0.09371
|
| 6 |
+
[23:38:34] P1 Epoch 2/15 | loss=0.07541
|
| 7 |
+
[23:39:15] P1 Epoch 3/15 | loss=0.07040
|
| 8 |
+
[23:39:56] P1 Epoch 4/15 | loss=0.06692
|
| 9 |
+
[23:40:36] P1 Epoch 5/15 | loss=0.06405
|
| 10 |
+
[23:41:17] P1 Epoch 6/15 | loss=0.06159
|
| 11 |
+
[23:41:58] P1 Epoch 7/15 | loss=0.05899
|
| 12 |
+
[23:42:40] P1 Epoch 8/15 | loss=0.05667
|
| 13 |
+
[23:43:21] P1 Epoch 9/15 | loss=0.05422
|
| 14 |
+
[23:44:01] P1 Epoch 10/15 | loss=0.05216
|
| 15 |
+
[23:44:43] P1 Epoch 11/15 | loss=0.05005
|
| 16 |
+
[23:45:23] P1 Epoch 12/15 | loss=0.04842
|
| 17 |
+
[23:46:03] P1 Epoch 13/15 | loss=0.04701
|
| 18 |
+
[23:46:45] P1 Epoch 14/15 | loss=0.04600
|
| 19 |
+
[23:47:24] P1 Epoch 15/15 | loss=0.04540
|
| 20 |
+
[23:47:24] Phase 2: Graduated AR (30 epochs)
|
| 21 |
+
[23:49:24] P2 Epoch 1/30 (steps=2) | loss=0.07130 lr=0.000500
|
| 22 |
+
[23:51:23] P2 Epoch 2/30 (steps=2) | loss=0.06985 lr=0.000500
|
| 23 |
+
[23:53:18] P2 Epoch 3/30 (steps=2) | loss=0.06784 lr=0.000500
|
| 24 |
+
[23:58:06] P2 Epoch 4/30 (steps=4) | loss=0.10299 lr=0.000500
|
| 25 |
+
[00:02:59] P2 Epoch 5/30 (steps=4) | loss=0.09840 lr=0.000500
|
| 26 |
+
[00:04:11] Val SSIM=0.8174 | {'pong': 0.7108, 'sonic': 0.8111, 'pole_position': 0.9302}
|
| 27 |
+
[00:04:11] New best! SSIM=0.8174
|
| 28 |
+
[00:09:08] P2 Epoch 6/30 (steps=4) | loss=0.09555 lr=0.000500
|
| 29 |
+
[00:21:04] P2 Epoch 7/30 (steps=8) | loss=0.14229 lr=0.000500
|
| 30 |
+
[00:32:46] P2 Epoch 8/30 (steps=8) | loss=0.13796 lr=0.000500
|
| 31 |
+
[00:44:48] P2 Epoch 9/30 (steps=8) | loss=0.13384 lr=0.000500
|
| 32 |
+
[00:57:15] P2 Epoch 10/30 (steps=8) | loss=0.12981 lr=0.000500
|
| 33 |
+
[00:58:37] Val SSIM=0.8540 | {'pong': 0.8022, 'sonic': 0.8237, 'pole_position': 0.936}
|
| 34 |
+
[00:58:37] New best! SSIM=0.8540
|
| 35 |
+
[01:11:08] P2 Epoch 11/30 (steps=8) | loss=0.12605 lr=0.000500
|
| 36 |
+
[01:23:41] P2 Epoch 12/30 (steps=8) | loss=0.12299 lr=0.000500
|
| 37 |
+
[01:36:24] P2 Epoch 13/30 (steps=8) | loss=0.12048 lr=0.000500
|
| 38 |
+
[01:48:54] P2 Epoch 14/30 (steps=8) | loss=0.11759 lr=0.000500
|
| 39 |
+
[02:01:33] P2 Epoch 15/30 (steps=8) | loss=0.11546 lr=0.000500
|
| 40 |
+
[02:02:55] Val SSIM=0.8644 | {'pong': 0.829, 'sonic': 0.8264, 'pole_position': 0.9378}
|
| 41 |
+
[02:02:55] New best! SSIM=0.8644
|
| 42 |
+
[02:15:31] P2 Epoch 16/30 (steps=8) | loss=0.11323 lr=0.000495
|
| 43 |
+
[02:28:01] P2 Epoch 17/30 (steps=8) | loss=0.11117 lr=0.000478
|
| 44 |
+
[02:40:14] P2 Epoch 18/30 (steps=8) | loss=0.10895 lr=0.000452
|
| 45 |
+
[02:52:32] P2 Epoch 19/30 (steps=8) | loss=0.10613 lr=0.000417
|
| 46 |
+
[03:05:05] P2 Epoch 20/30 (steps=8) | loss=0.10350 lr=0.000375
|
| 47 |
+
[03:06:28] Val SSIM=0.8744 | {'pong': 0.8512, 'sonic': 0.8308, 'pole_position': 0.9413}
|
| 48 |
+
[03:06:28] New best! SSIM=0.8744
|
| 49 |
+
[03:19:19] P2 Epoch 21/30 (steps=8) | loss=0.10044 lr=0.000327
|
| 50 |
+
[03:31:46] P2 Epoch 22/30 (steps=8) | loss=0.09729 lr=0.000276
|
| 51 |
+
[03:44:25] P2 Epoch 23/30 (steps=8) | loss=0.09401 lr=0.000224
|
| 52 |
+
[03:57:08] P2 Epoch 24/30 (steps=8) | loss=0.09080 lr=0.000173
|
| 53 |
+
[04:09:49] P2 Epoch 25/30 (steps=8) | loss=0.08751 lr=0.000125
|
| 54 |
+
[04:11:04] Val SSIM=0.8852 | {'pong': 0.8764, 'sonic': 0.8329, 'pole_position': 0.9462}
|
| 55 |
+
[04:11:04] New best! SSIM=0.8852
|
| 56 |
+
[04:23:43] P2 Epoch 26/30 (steps=8) | loss=0.08449 lr=0.000083
|
| 57 |
+
[04:36:13] P2 Epoch 27/30 (steps=8) | loss=0.08166 lr=0.000048
|
| 58 |
+
[04:48:48] P2 Epoch 28/30 (steps=8) | loss=0.07940 lr=0.000022
|
| 59 |
+
[05:01:33] P2 Epoch 29/30 (steps=8) | loss=0.07777 lr=0.000010
|
| 60 |
+
[05:14:14] P2 Epoch 30/30 (steps=8) | loss=0.07694 lr=0.000010
|
| 61 |
+
[05:15:35] Val SSIM=0.8850 | {'pong': 0.8783, 'sonic': 0.8292, 'pole_position': 0.9474}
|
| 62 |
+
[05:15:35] Experiment dir: 9.1 MB
|
| 63 |
+
[05:15:35] Training complete. Best val SSIM: 0.8852
|