ojaffe committed on
Commit
87bfad6
·
verified ·
1 Parent(s): 9633935

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. flowmask_model.py +79 -0
  2. flownet_model.py +150 -0
  3. loss_history.json +233 -0
  4. model.pt +3 -0
  5. predict.py +53 -259
  6. train.log +63 -31
flowmask_model.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow-Warp-Mask U-Net: predicts flow, occlusion mask, and generated frame."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class ResConvBlock(nn.Module):
    """Two 3x3 conv layers with GroupNorm + SiLU and an additive skip path."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch)
        # 1x1 projection aligns channel counts on the skip path when needed.
        if in_ch == out_ch:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        skip = self.proj(x)
        out = F.silu(self.gn1(self.conv1(x)))
        out = F.silu(self.gn2(self.conv2(out)))
        return out + skip
21
+
22
+
23
class FlowWarpMaskUNet(nn.Module):
    """U-Net that predicts optical flow, an occlusion mask, and a generated frame.

    Args:
        in_channels: channels of the stacked input (12 = 4 RGB frames).
        channels: encoder widths, one per resolution level.

    Forward returns a tuple ``(flow, mask, gen_frame)`` where flow has 2
    channels (dx, dy), mask has 1 channel in (0, 1) via sigmoid, and
    gen_frame has 3 RGB channels.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192)):
        # Fix: the default was a mutable list, a classic Python pitfall;
        # a tuple default is safe and backward-compatible (lists still accepted).
        super().__init__()
        channels = list(channels)

        # Encoder: one residual block followed by a 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck doubles the deepest width.
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: transpose-conv upsample, concat skip, residual block.
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # Output heads.
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)   # (dx, dy)
        self.mask_head = nn.Conv2d(dec_channels[-1], 1, 1)   # occlusion mask (sigmoid in forward)
        self.gen_head = nn.Conv2d(dec_channels[-1], 3, 1)    # RGB fill for occluded areas

        # Zero-init flow and mask heads so the model starts as an identity
        # warp with a uniform 0.5 mask -- stabilizes early training.
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.mask_head.weight)
        nn.init.zeros_(self.mask_head.bias)

    def forward(self, x):
        """Return (flow, mask, gen_frame) for x of shape (B, in_channels, H, W)."""
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)
        mask = torch.sigmoid(self.mask_head(x))
        gen_frame = self.gen_head(x)

        return flow, mask, gen_frame
flownet_model.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow-Warp U-Net: predicts optical flow + residual, warps last frame."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class ResConvBlock(nn.Module):
8
+ def __init__(self, in_ch, out_ch):
9
+ super().__init__()
10
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
11
+ self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch)
12
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
13
+ self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch)
14
+ self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
15
+
16
+ def forward(self, x):
17
+ residual = self.proj(x)
18
+ x = F.silu(self.gn1(self.conv1(x)))
19
+ x = F.silu(self.gn2(self.conv2(x)))
20
+ return x + residual
21
+
22
+
23
class FlowWarpUNet(nn.Module):
    """U-Net that predicts optical flow plus an RGB residual correction.

    Args:
        in_channels: channels of the stacked input (12 = 4 RGB frames).
        channels: encoder widths, one per resolution level.
    """

    def __init__(self, in_channels=12, channels=(48, 96, 192, 384)):
        # Fix: the default was a mutable list, a classic Python pitfall;
        # a tuple default is safe and backward-compatible (lists still accepted).
        super().__init__()
        channels = list(channels)

        # Encoder: one residual block followed by a 2x max-pool per level.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(ResConvBlock(prev_ch, ch))
            self.pools.append(nn.MaxPool2d(2))
            prev_ch = ch

        # Bottleneck doubles the deepest width.
        self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)

        # Decoder: transpose-conv upsample, concat skip, residual block.
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        dec_channels = list(reversed(channels))
        prev_ch = channels[-1] * 2
        for ch in dec_channels:
            self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
            self.decoders.append(ResConvBlock(ch * 2, ch))
            prev_ch = ch

        # Flow head (2 channels: dx, dy).
        self.flow_head = nn.Conv2d(dec_channels[-1], 2, 1)
        # Residual head (3 channels: RGB residual).
        self.residual_head = nn.Conv2d(dec_channels[-1], 3, 1)

        # Zero-init both heads so the model starts as an identity warp
        # with no residual -- stabilizes early training.
        nn.init.zeros_(self.flow_head.weight)
        nn.init.zeros_(self.flow_head.bias)
        nn.init.zeros_(self.residual_head.weight)
        nn.init.zeros_(self.residual_head.bias)

    def forward(self, x):
        """
        Args:
            x: (B, 12, 64, 64) - 4 frames stacked
        Returns:
            flow: (B, 2, 64, 64) - optical flow (dx, dy) in pixels
            residual: (B, 3, 64, 64) - residual correction
        """
        skips = []
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)

        flow = self.flow_head(x)          # (B, 2, 64, 64)
        residual = self.residual_head(x)  # (B, 3, 64, 64)

        return flow, residual
85
+
86
+
87
def differentiable_warp(img, flow):
    """Backward-warp `img` by a per-pixel flow field using bilinear sampling.

    Args:
        img: (B, C, H, W) image to warp.
        flow: (B, 2, H, W) displacement in pixel units; channel 0 is dx,
            channel 1 is dy.
    Returns:
        (B, C, H, W) warped image; out-of-range samples use border padding.
    """
    B, C, H, W = img.shape

    # Base pixel-coordinate grid shared by the whole batch.
    ys = torch.arange(H, device=img.device, dtype=img.dtype)
    xs = torch.arange(W, device=img.device, dtype=img.dtype)
    base_y, base_x = torch.meshgrid(ys, xs, indexing='ij')

    # Displaced sampling coordinates, broadcast over the batch.
    sample_x = base_x.unsqueeze(0).expand(B, -1, -1) + flow[:, 0]  # (B, H, W)
    sample_y = base_y.unsqueeze(0).expand(B, -1, -1) + flow[:, 1]

    # grid_sample expects coordinates normalized to [-1, 1].
    sample_x = 2.0 * sample_x / (W - 1) - 1.0
    sample_y = 2.0 * sample_y / (H - 1) - 1.0
    grid = torch.stack([sample_x, sample_y], dim=-1)  # (B, H, W, 2)

    return F.grid_sample(img, grid, mode='bilinear', padding_mode='border', align_corners=True)
120
+
121
+
122
def flow_smoothness_loss(flow):
    """Total-variation-style penalty: mean absolute spatial gradient of the flow."""
    horiz = (flow[:, :, :, 1:] - flow[:, :, :, :-1]).abs().mean()
    vert = (flow[:, :, 1:, :] - flow[:, :, :-1, :]).abs().mean()
    return (horiz + vert) / 2
127
+
128
+
129
class GlobalSSIMLoss(nn.Module):
    """1 - SSIM computed globally per (image, channel) instead of local windows.

    The stability constants C1/C2 follow the standard SSIM formulation and
    assume inputs roughly in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, pred, target):
        B, C, _, _ = pred.shape
        p = pred.reshape(B, C, -1)
        t = target.reshape(B, C, -1)

        mu_p = p.mean(dim=2)
        mu_t = t.mean(dim=2)
        # NOTE(review): var() is unbiased (ddof=1) while the cross term below
        # uses a biased mean; kept exactly as in the original formulation.
        var_p = p.var(dim=2)
        var_t = t.var(dim=2)
        cov = ((p - mu_p.unsqueeze(2)) * (t - mu_t.unsqueeze(2))).mean(dim=2)

        num = (2 * mu_p * mu_t + self.C1) * (2 * cov + self.C2)
        den = (mu_p ** 2 + mu_t ** 2 + self.C1) * (var_p + var_t + self.C2)
        return 1 - (num / den).mean()
loss_history.json ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "epoch": 1,
4
+ "phase": "P1",
5
+ "loss": 0.093708
6
+ },
7
+ {
8
+ "epoch": 2,
9
+ "phase": "P1",
10
+ "loss": 0.075409
11
+ },
12
+ {
13
+ "epoch": 3,
14
+ "phase": "P1",
15
+ "loss": 0.070398
16
+ },
17
+ {
18
+ "epoch": 4,
19
+ "phase": "P1",
20
+ "loss": 0.066922
21
+ },
22
+ {
23
+ "epoch": 5,
24
+ "phase": "P1",
25
+ "loss": 0.064051
26
+ },
27
+ {
28
+ "epoch": 6,
29
+ "phase": "P1",
30
+ "loss": 0.061594
31
+ },
32
+ {
33
+ "epoch": 7,
34
+ "phase": "P1",
35
+ "loss": 0.058991
36
+ },
37
+ {
38
+ "epoch": 8,
39
+ "phase": "P1",
40
+ "loss": 0.056665
41
+ },
42
+ {
43
+ "epoch": 9,
44
+ "phase": "P1",
45
+ "loss": 0.054221
46
+ },
47
+ {
48
+ "epoch": 10,
49
+ "phase": "P1",
50
+ "loss": 0.052157
51
+ },
52
+ {
53
+ "epoch": 11,
54
+ "phase": "P1",
55
+ "loss": 0.050054
56
+ },
57
+ {
58
+ "epoch": 12,
59
+ "phase": "P1",
60
+ "loss": 0.048416
61
+ },
62
+ {
63
+ "epoch": 13,
64
+ "phase": "P1",
65
+ "loss": 0.047013
66
+ },
67
+ {
68
+ "epoch": 14,
69
+ "phase": "P1",
70
+ "loss": 0.046003
71
+ },
72
+ {
73
+ "epoch": 15,
74
+ "phase": "P1",
75
+ "loss": 0.0454
76
+ },
77
+ {
78
+ "epoch": 16,
79
+ "phase": "P2",
80
+ "loss": 0.071297
81
+ },
82
+ {
83
+ "epoch": 17,
84
+ "phase": "P2",
85
+ "loss": 0.069845
86
+ },
87
+ {
88
+ "epoch": 18,
89
+ "phase": "P2",
90
+ "loss": 0.067838
91
+ },
92
+ {
93
+ "epoch": 19,
94
+ "phase": "P2",
95
+ "loss": 0.102993
96
+ },
97
+ {
98
+ "epoch": 20,
99
+ "phase": "P2",
100
+ "loss": 0.098403,
101
+ "val_ssim": 0.8174
102
+ },
103
+ {
104
+ "epoch": 21,
105
+ "phase": "P2",
106
+ "loss": 0.095552
107
+ },
108
+ {
109
+ "epoch": 22,
110
+ "phase": "P2",
111
+ "loss": 0.142291
112
+ },
113
+ {
114
+ "epoch": 23,
115
+ "phase": "P2",
116
+ "loss": 0.137962
117
+ },
118
+ {
119
+ "epoch": 24,
120
+ "phase": "P2",
121
+ "loss": 0.133837
122
+ },
123
+ {
124
+ "epoch": 25,
125
+ "phase": "P2",
126
+ "loss": 0.129812,
127
+ "val_ssim": 0.854
128
+ },
129
+ {
130
+ "epoch": 26,
131
+ "phase": "P2",
132
+ "loss": 0.126053
133
+ },
134
+ {
135
+ "epoch": 27,
136
+ "phase": "P2",
137
+ "loss": 0.122985
138
+ },
139
+ {
140
+ "epoch": 28,
141
+ "phase": "P2",
142
+ "loss": 0.120476
143
+ },
144
+ {
145
+ "epoch": 29,
146
+ "phase": "P2",
147
+ "loss": 0.117592
148
+ },
149
+ {
150
+ "epoch": 30,
151
+ "phase": "P2",
152
+ "loss": 0.115456,
153
+ "val_ssim": 0.8644
154
+ },
155
+ {
156
+ "epoch": 31,
157
+ "phase": "P2",
158
+ "loss": 0.113231
159
+ },
160
+ {
161
+ "epoch": 32,
162
+ "phase": "P2",
163
+ "loss": 0.111175
164
+ },
165
+ {
166
+ "epoch": 33,
167
+ "phase": "P2",
168
+ "loss": 0.108953
169
+ },
170
+ {
171
+ "epoch": 34,
172
+ "phase": "P2",
173
+ "loss": 0.106131
174
+ },
175
+ {
176
+ "epoch": 35,
177
+ "phase": "P2",
178
+ "loss": 0.103505,
179
+ "val_ssim": 0.8744
180
+ },
181
+ {
182
+ "epoch": 36,
183
+ "phase": "P2",
184
+ "loss": 0.100435
185
+ },
186
+ {
187
+ "epoch": 37,
188
+ "phase": "P2",
189
+ "loss": 0.097286
190
+ },
191
+ {
192
+ "epoch": 38,
193
+ "phase": "P2",
194
+ "loss": 0.094014
195
+ },
196
+ {
197
+ "epoch": 39,
198
+ "phase": "P2",
199
+ "loss": 0.090802
200
+ },
201
+ {
202
+ "epoch": 40,
203
+ "phase": "P2",
204
+ "loss": 0.087507,
205
+ "val_ssim": 0.8852
206
+ },
207
+ {
208
+ "epoch": 41,
209
+ "phase": "P2",
210
+ "loss": 0.084485
211
+ },
212
+ {
213
+ "epoch": 42,
214
+ "phase": "P2",
215
+ "loss": 0.081661
216
+ },
217
+ {
218
+ "epoch": 43,
219
+ "phase": "P2",
220
+ "loss": 0.079401
221
+ },
222
+ {
223
+ "epoch": 44,
224
+ "phase": "P2",
225
+ "loss": 0.077772
226
+ },
227
+ {
228
+ "epoch": 45,
229
+ "phase": "P2",
230
+ "loss": 0.076937,
231
+ "val_ssim": 0.885
232
+ }
233
+ ]
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4774304dae39b918b34dd4ededabc4a793ac7efdeb772a746587fad584ccfe83
3
+ size 9089268
predict.py CHANGED
@@ -1,282 +1,76 @@
1
- """Full PP swap: Pong direct int8, full PP model, Sonic AR fp16 + direct int8."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
- from unet_model import UNet
 
9
 
10
- CONTEXT_FRAMES = 8
11
- PRED_FRAMES = 8
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
-
15
- def detect_game(context_frames: np.ndarray) -> str:
16
- first_8 = context_frames[:CONTEXT_FRAMES]
17
- mean_val = first_8.mean()
18
- std_val = first_8.std()
19
- b_mean = first_8[:, :, :, 2].mean()
20
- r_mean = first_8[:, :, :, 0].mean()
21
- if mean_val > 100 and std_val < 80 and b_mean > r_mean * 1.5:
22
- return "pole_position"
23
- elif mean_val < 5 and 10 < std_val < 20:
24
- return "pong"
25
- else:
26
- return "sonic"
27
-
28
-
29
- def load_int8_state_dict(path, device):
30
- """Load int8 quantized state dict and dequantize to float32."""
31
- quantized = torch.load(path, map_location='cpu', weights_only=False)
32
- sd = {}
33
- for k, v in quantized.items():
34
- if 'int8' in v:
35
- sd[k] = (v['int8'].float() * v['scale']).to(device)
36
- else:
37
- sd[k] = v['float'].to(device)
38
- return sd
39
-
40
-
41
- class EnsembleModels:
42
- def __init__(self):
43
- self.models = {}
44
- self.sonic_ar = None
45
- self.sonic_direct = None
46
- self.pong_direct = None
47
- self.direct_cache = None
48
- self.cache_step = 0
49
-
50
- def reset_cache(self):
51
- self.direct_cache = None
52
- self.cache_step = 0
53
 
54
 
55
  def load_model(model_dir: str):
56
- ens = EnsembleModels()
57
-
58
- # Pong AR (fp16, 3 outputs)
59
- pong = UNet(in_channels=24, out_channels=3,
60
- enc_channels=(32, 64, 128), bottleneck_channels=128,
61
- upsample_mode="bilinear").to(DEVICE)
62
- sd = torch.load(os.path.join(model_dir, "model_pong.pt"),
63
- map_location=DEVICE, weights_only=True)
64
- pong.load_state_dict({k: v.float() for k, v in sd.items()})
65
- pong.eval()
66
- ens.models["pong"] = pong
67
-
68
- # Pong direct (int8 quantized, 24 outputs)
69
- pong_direct = UNet(in_channels=24, out_channels=24,
70
- enc_channels=(32, 64, 128), bottleneck_channels=128,
71
- upsample_mode="bilinear").to(DEVICE)
72
- sd = load_int8_state_dict(os.path.join(model_dir, "model_pong_direct.pt"), DEVICE)
73
- pong_direct.load_state_dict(sd)
74
- pong_direct.eval()
75
- ens.pong_direct = pong_direct
76
-
77
- # Sonic AR (fp16, 3 outputs) - kept in fp16 for AR chain quality
78
- sonic_ar = UNet(in_channels=24, out_channels=3,
79
- enc_channels=(48, 96, 192), bottleneck_channels=256,
80
- upsample_mode="bilinear").to(DEVICE)
81
- sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
82
- map_location=DEVICE, weights_only=True)
83
- sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
84
- sonic_ar.eval()
85
- ens.sonic_ar = sonic_ar
86
-
87
- # Sonic direct (int8 quantized, 24 outputs)
88
- sonic_direct = UNet(in_channels=24, out_channels=24,
89
- enc_channels=(48, 96, 192), bottleneck_channels=256,
90
- upsample_mode="bilinear").to(DEVICE)
91
- sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_direct.pt"), DEVICE)
92
- sonic_direct.load_state_dict(sd)
93
- sonic_direct.eval()
94
- ens.sonic_direct = sonic_direct
95
-
96
- # PP full direct (fp16, 24 outputs)
97
- pp = UNet(in_channels=24, out_channels=24,
98
- enc_channels=(32, 64, 128), bottleneck_channels=192,
99
- upsample_mode="bilinear").to(DEVICE)
100
- sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
101
- map_location=DEVICE, weights_only=True)
102
- pp.load_state_dict({k: v.float() for k, v in sd.items()})
103
- pp.eval()
104
- ens.models["pole_position"] = pp
105
-
106
- return ens
107
-
108
-
109
- def _predict_8frames_direct(model, context_tensor, last_tensor, residual_scale=1.0):
110
- output = model(context_tensor)
111
- residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
112
- last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
113
- return torch.clamp(last_expanded + residual_scale * residuals, 0, 1)
114
-
115
-
116
- def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
117
- residual = model(context_tensor)
118
- return torch.clamp(last_tensor + residual_scale * residual, 0, 1)
119
-
120
-
121
- def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
122
- game = detect_game(context_frames)
123
- n = len(context_frames)
124
-
125
- if n < CONTEXT_FRAMES:
126
- padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
127
- frames = np.concatenate([padding, context_frames], axis=0)
128
  else:
129
- frames = context_frames[-CONTEXT_FRAMES:]
130
-
131
- frames_norm = frames.astype(np.float32) / 255.0
132
- frames_t = np.transpose(frames_norm, (0, 3, 1, 2))
133
- context = frames_t.reshape(1, -1, 64, 64)
134
-
135
- last_frame = frames_norm[-1]
136
- last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
137
-
138
- if game == "pong":
139
- # Pong: AR+direct ensemble, float32 caching, no TTA
140
- if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
141
- result = ens.direct_cache[ens.cache_step]
142
- ens.cache_step += 1
143
- if ens.cache_step >= PRED_FRAMES:
144
- ens.reset_cache()
145
- return result
146
-
147
- ens.reset_cache()
148
- with torch.no_grad():
149
- context_tensor = torch.from_numpy(context).to(DEVICE)
150
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
151
-
152
- direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
153
-
154
- ar_preds = []
155
- ctx = context_tensor.clone()
156
- last_t = last_tensor.clone()
157
- for step in range(PRED_FRAMES):
158
- predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=1.02)
159
- ar_preds.append(predicted)
160
- ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
161
- ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
162
- ctx = ctx_frames.reshape(1, -1, 64, 64)
163
- last_t = predicted
164
-
165
- ar_pred = torch.stack(ar_preds, dim=1)
166
-
167
- predicted = torch.zeros_like(direct_pred)
168
- for step in range(PRED_FRAMES):
169
- ar_weight = 0.85 - (step / (PRED_FRAMES - 1)) * 0.3
170
- direct_weight = 1.0 - ar_weight
171
- predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
172
-
173
- predicted_np = predicted[0].cpu().numpy()
174
- ens.direct_cache = []
175
- for i in range(PRED_FRAMES):
176
- frame = np.transpose(predicted_np[i], (1, 2, 0))
177
- frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
178
- ens.direct_cache.append(frame)
179
-
180
- result = ens.direct_cache[ens.cache_step]
181
- ens.cache_step += 1
182
- return result
183
-
184
- elif game == "sonic":
185
- # Sonic: AR(fp16)+direct(int8) with step blending and TTA
186
- if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
187
- result = ens.direct_cache[ens.cache_step]
188
- ens.cache_step += 1
189
- if ens.cache_step >= PRED_FRAMES:
190
- ens.reset_cache()
191
- return result
192
 
193
- ens.reset_cache()
194
- with torch.no_grad():
195
- context_tensor = torch.from_numpy(context).to(DEVICE)
196
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
197
 
198
- direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
199
- context_flipped = torch.flip(context_tensor, dims=[3])
200
- last_flipped = torch.flip(last_tensor, dims=[3])
201
- direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
202
- direct_flipped = torch.flip(direct_flipped, dims=[4])
203
- direct_pred = (direct_orig + direct_flipped) / 2.0
204
 
205
- # Multi-run AR with noise diversity (fixed seed for reproducibility)
206
- all_ar_runs = []
207
- torch.manual_seed(2)
208
- for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
209
- ar_preds_run = []
210
- ctx = context_tensor.clone()
211
- ctx_flip = context_flipped.clone()
212
- last_t = last_tensor.clone()
213
- last_f = last_flipped.clone()
214
- sonic_scales = [1.04, 1.04, 1.04, 1.08, 1.08, 1.08, 1.12, 1.12]
215
- for step in range(PRED_FRAMES):
216
- ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
217
- ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
218
- ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=sonic_scales[step])
219
- ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scales[step])
220
- ar_flip_back = torch.flip(ar_flip, dims=[3])
221
- ar_frame = (ar_orig + ar_flip_back) / 2.0
222
- ar_preds_run.append(ar_frame)
223
- ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
224
- ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
225
- ctx = ctx_frames.reshape(1, -1, 64, 64)
226
- last_t = ar_orig
227
- ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
228
- ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
229
- ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
230
- last_f = ar_flip
231
- all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
232
 
233
- ar_pred = sum(all_ar_runs) / len(all_ar_runs)
 
234
 
235
- predicted = torch.zeros_like(direct_pred)
236
- for step in range(PRED_FRAMES):
237
- ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
238
- direct_weight = 1.0 - ar_weight
239
- predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
240
 
241
- predicted_np = predicted[0].cpu().numpy()
242
- ens.direct_cache = []
243
- for i in range(PRED_FRAMES):
244
- frame = np.transpose(predicted_np[i], (1, 2, 0))
245
- frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
246
- ens.direct_cache.append(frame)
247
 
248
- result = ens.direct_cache[ens.cache_step]
249
- ens.cache_step += 1
250
- return result
251
 
252
- else:
253
- # PP: direct with TTA and caching
254
- if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
255
- result = ens.direct_cache[ens.cache_step]
256
- ens.cache_step += 1
257
- if ens.cache_step >= PRED_FRAMES:
258
- ens.reset_cache()
259
- return result
260
 
261
- ens.reset_cache()
262
- with torch.no_grad():
263
- context_tensor = torch.from_numpy(context).to(DEVICE)
264
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
265
 
266
- predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale=0.97)
267
- context_flipped = torch.flip(context_tensor, dims=[3])
268
- last_flipped = torch.flip(last_tensor, dims=[3])
269
- predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale=0.97)
270
- predicted_flipped = torch.flip(predicted_flipped, dims=[4])
271
- predicted = (predicted_orig + predicted_flipped) / 2.0
272
 
273
- predicted_np = predicted[0].cpu().numpy()
274
- ens.direct_cache = []
275
- for i in range(PRED_FRAMES):
276
- frame = np.transpose(predicted_np[i], (1, 2, 0))
277
- frame = np.round(frame * 255 + 0.2).clip(0, 255).astype(np.uint8)
278
- ens.direct_cache.append(frame)
279
 
280
- result = ens.direct_cache[ens.cache_step]
281
- ens.cache_step += 1
282
- return result
 
1
+ """Prediction interface for Flow-Warp-Mask U-Net v9 with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
+ from flowmask_model import FlowWarpMaskUNet
9
+ from flownet_model import differentiable_warp
10
 
11
+ CONTEXT_LEN = 4
12
+ CHANNELS = [48, 96, 192]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
def load_model(model_dir: str):
    """Load the trained FlowWarpMaskUNet from `model_dir`.

    Returns a dict with keys "model" (eval-mode network on the chosen device)
    and "device" (cuda if available, else cpu).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = FlowWarpMaskUNet(in_channels=12, channels=CHANNELS)
    weights = torch.load(os.path.join(model_dir, "model.pt"),
                         map_location=device, weights_only=True)
    # Cast any fp16-saved tensors back to float32 before loading.
    net.load_state_dict({name: tensor.float() for name, tensor in weights.items()})
    net.to(device)
    net.eval()
    return {"model": net, "device": device}
25
+
26
+
27
def _prepare_input(context_frames):
    """Build a (CONTEXT_LEN, 3, 64, 64) float tensor in [0, 1] from uint8 frames.

    Uses the most recent CONTEXT_LEN frames; when fewer are available, the
    first frame is repeated at the front as padding.
    """
    n = len(context_frames)
    if n < CONTEXT_LEN:
        padding = np.stack([context_frames[0]] * (CONTEXT_LEN - n), axis=0)
        frames = np.concatenate([padding, context_frames], axis=0)
    else:
        frames = context_frames[-CONTEXT_LEN:]

    tensor = torch.from_numpy(frames.astype(np.float32) / 255.0)
    return tensor.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
 
39
 
 
 
 
 
 
 
40
 
41
def _run_model(model, frames_t, device):
    """Predict the next frame from a prepared (4, 3, 64, 64) context tensor.

    Warps the last context frame by the predicted flow, then composites it
    with the generated frame via the occlusion mask. Returns a (1, 3, 64, 64)
    tensor clamped to [0, 1].
    """
    last_frame = frames_t[-1].unsqueeze(0).to(device)   # (1, 3, 64, 64)
    batch = frames_t.reshape(1, -1, 64, 64).to(device)  # (1, 12, 64, 64)

    flow, mask, gen_frame = model(batch)
    warped = differentiable_warp(last_frame, flow)
    # Mask selects warped content; (1 - mask) falls back to the generated frame.
    combined = mask * warped + (1 - mask) * gen_frame
    return torch.clamp(combined, 0, 1)
54
 
 
 
 
 
 
 
55
 
56
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next 64x64 RGB frame from recent context frames.

    Runs the model on the original context and a horizontally flipped copy
    (test-time augmentation) and averages the two predictions.

    Args:
        model_dict: {"model": ..., "device": ...} as returned by load_model.
        context_frames: (T, 64, 64, 3) uint8 array of past frames.
    Returns:
        (64, 64, 3) uint8 predicted frame.
    """
    model = model_dict["model"]
    device = model_dict["device"]

    frames_t = _prepare_input(context_frames)

    with torch.no_grad():
        # Original prediction
        pred1 = _run_model(model, frames_t, device)

        # TTA: flip the width axis, predict, then flip the prediction back.
        pred2 = _run_model(model, frames_t.flip(-1), device).flip(-1)

        # Average the two views.
        pred = (pred1 + pred2) / 2.0

    pred = pred[0].cpu().permute(1, 2, 0).numpy()
    # Fix: round to nearest instead of truncating on the uint8 cast; plain
    # astype() truncation biased every pixel darker by up to one level.
    return np.rint(pred * 255).clip(0, 255).astype(np.uint8)
train.log CHANGED
@@ -1,31 +1,63 @@
1
- [2026-04-12 07:14:27] Starting PP SSIM-only training for 2026-04-12-153000-pp-ssim-only
2
- [2026-04-12 07:14:27] Device: cuda
3
- [2026-04-12 07:14:27] PP: 1,465,848 params (2.8 MB fp16)
4
- [2026-04-12 07:14:28] PP train: 4097 seqs (len=16)
5
- [2026-04-12 07:14:28] PP val: 482 seqs (len=16)
6
- [2026-04-12 07:14:38] E1/100 | T:0.089821(S:0.9041) V:0.069040(S:0.9263) LR:3.00e-04
7
- [2026-04-12 07:14:46] E2/100 | T:0.074465(S:0.9199) V:0.061649(S:0.9337) LR:3.00e-04
8
- [2026-04-12 07:14:55] E3/100 | T:0.069567(S:0.9251) V:0.060543(S:0.9351) LR:2.99e-04
9
- [2026-04-12 07:15:04] E4/100 | T:0.066393(S:0.9284) V:0.058608(S:0.9372) LR:2.99e-04
10
- [2026-04-12 07:15:12] E5/100 | T:0.063999(S:0.9309) V:0.057957(S:0.9378) LR:2.98e-04
11
- [2026-04-12 07:15:21] E6/100 | T:0.061633(S:0.9334) V:0.056070(S:0.9396) LR:2.97e-04
12
- [2026-04-12 07:15:31] E7/100 | T:0.060136(S:0.9350) V:0.051418(S:0.9447) LR:2.96e-04
13
- [2026-04-12 07:15:50] E9/100 | T:0.057306(S:0.9380) V:0.050541(S:0.9454) LR:2.94e-04
14
- [2026-04-12 07:16:00] E10/100 | T:0.055858(S:0.9396) V:0.052572(S:0.9431) LR:2.93e-04
15
- [2026-04-12 07:16:09] E11/100 | T:0.055021(S:0.9404) V:0.049028(S:0.9470) LR:2.91e-04
16
- [2026-04-12 07:17:14] E18/100 | T:0.048082(S:0.9480) V:0.047964(S:0.9479) LR:2.77e-04
17
- [2026-04-12 07:17:30] E20/100 | T:0.047000(S:0.9492) V:0.049137(S:0.9465) LR:2.71e-04
18
- [2026-04-12 07:18:25] E26/100 | T:0.043209(S:0.9533) V:0.047514(S:0.9482) LR:2.53e-04
19
- [2026-04-12 07:18:45] E28/100 | T:0.042278(S:0.9543) V:0.046412(S:0.9495) LR:2.46e-04
20
- [2026-04-12 07:19:04] E30/100 | T:0.041417(S:0.9553) V:0.047650(S:0.9481) LR:2.38e-04
21
- [2026-04-12 07:20:37] E40/100 | T:0.037882(S:0.9592) V:0.048536(S:0.9469) LR:1.97e-04
22
- [2026-04-12 07:21:15] E44/100 | T:0.036738(S:0.9604) V:0.046170(S:0.9495) LR:1.79e-04
23
- [2026-04-12 07:22:10] E50/100 | T:0.035150(S:0.9621) V:0.048340(S:0.9468) LR:1.50e-04
24
- [2026-04-12 07:23:41] E60/100 | T:0.033106(S:0.9644) V:0.048292(S:0.9467) LR:1.04e-04
25
- [2026-04-12 07:25:13] E70/100 | T:0.031615(S:0.9660) V:0.047850(S:0.9472) LR:6.26e-05
26
- [2026-04-12 07:26:46] E80/100 | T:0.030609(S:0.9671) V:0.047661(S:0.9474) LR:2.96e-05
27
- [2026-04-12 07:28:21] E90/100 | T:0.030099(S:0.9676) V:0.048131(S:0.9468) LR:8.32e-06
28
- [2026-04-12 07:29:58] E100/100 | T:0.029986(S:0.9678) V:0.048119(S:0.9469) LR:1.00e-06
29
- [2026-04-12 07:29:58] Done. Best val loss: 0.046170
30
- [2026-04-12 07:29:58] Model size: 2.8 MB
31
- [2026-04-12 07:29:58] Training complete!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [23:37:08] Device: cuda
2
+ [23:37:08] Model parameters: 4,534,230, channels=[48, 96, 192]
3
+ [23:37:08] Phase 1: Single-step (15 epochs)
4
+ [23:37:12] 45108 sequences
5
+ [23:37:54] P1 Epoch 1/15 | loss=0.09371
6
+ [23:38:34] P1 Epoch 2/15 | loss=0.07541
7
+ [23:39:15] P1 Epoch 3/15 | loss=0.07040
8
+ [23:39:56] P1 Epoch 4/15 | loss=0.06692
9
+ [23:40:36] P1 Epoch 5/15 | loss=0.06405
10
+ [23:41:17] P1 Epoch 6/15 | loss=0.06159
11
+ [23:41:58] P1 Epoch 7/15 | loss=0.05899
12
+ [23:42:40] P1 Epoch 8/15 | loss=0.05667
13
+ [23:43:21] P1 Epoch 9/15 | loss=0.05422
14
+ [23:44:01] P1 Epoch 10/15 | loss=0.05216
15
+ [23:44:43] P1 Epoch 11/15 | loss=0.05005
16
+ [23:45:23] P1 Epoch 12/15 | loss=0.04842
17
+ [23:46:03] P1 Epoch 13/15 | loss=0.04701
18
+ [23:46:45] P1 Epoch 14/15 | loss=0.04600
19
+ [23:47:24] P1 Epoch 15/15 | loss=0.04540
20
+ [23:47:24] Phase 2: Graduated AR (30 epochs)
21
+ [23:49:24] P2 Epoch 1/30 (steps=2) | loss=0.07130 lr=0.000500
22
+ [23:51:23] P2 Epoch 2/30 (steps=2) | loss=0.06985 lr=0.000500
23
+ [23:53:18] P2 Epoch 3/30 (steps=2) | loss=0.06784 lr=0.000500
24
+ [23:58:06] P2 Epoch 4/30 (steps=4) | loss=0.10299 lr=0.000500
25
+ [00:02:59] P2 Epoch 5/30 (steps=4) | loss=0.09840 lr=0.000500
26
+ [00:04:11] Val SSIM=0.8174 | {'pong': 0.7108, 'sonic': 0.8111, 'pole_position': 0.9302}
27
+ [00:04:11] New best! SSIM=0.8174
28
+ [00:09:08] P2 Epoch 6/30 (steps=4) | loss=0.09555 lr=0.000500
29
+ [00:21:04] P2 Epoch 7/30 (steps=8) | loss=0.14229 lr=0.000500
30
+ [00:32:46] P2 Epoch 8/30 (steps=8) | loss=0.13796 lr=0.000500
31
+ [00:44:48] P2 Epoch 9/30 (steps=8) | loss=0.13384 lr=0.000500
32
+ [00:57:15] P2 Epoch 10/30 (steps=8) | loss=0.12981 lr=0.000500
33
+ [00:58:37] Val SSIM=0.8540 | {'pong': 0.8022, 'sonic': 0.8237, 'pole_position': 0.936}
34
+ [00:58:37] New best! SSIM=0.8540
35
+ [01:11:08] P2 Epoch 11/30 (steps=8) | loss=0.12605 lr=0.000500
36
+ [01:23:41] P2 Epoch 12/30 (steps=8) | loss=0.12299 lr=0.000500
37
+ [01:36:24] P2 Epoch 13/30 (steps=8) | loss=0.12048 lr=0.000500
38
+ [01:48:54] P2 Epoch 14/30 (steps=8) | loss=0.11759 lr=0.000500
39
+ [02:01:33] P2 Epoch 15/30 (steps=8) | loss=0.11546 lr=0.000500
40
+ [02:02:55] Val SSIM=0.8644 | {'pong': 0.829, 'sonic': 0.8264, 'pole_position': 0.9378}
41
+ [02:02:55] New best! SSIM=0.8644
42
+ [02:15:31] P2 Epoch 16/30 (steps=8) | loss=0.11323 lr=0.000495
43
+ [02:28:01] P2 Epoch 17/30 (steps=8) | loss=0.11117 lr=0.000478
44
+ [02:40:14] P2 Epoch 18/30 (steps=8) | loss=0.10895 lr=0.000452
45
+ [02:52:32] P2 Epoch 19/30 (steps=8) | loss=0.10613 lr=0.000417
46
+ [03:05:05] P2 Epoch 20/30 (steps=8) | loss=0.10350 lr=0.000375
47
+ [03:06:28] Val SSIM=0.8744 | {'pong': 0.8512, 'sonic': 0.8308, 'pole_position': 0.9413}
48
+ [03:06:28] New best! SSIM=0.8744
49
+ [03:19:19] P2 Epoch 21/30 (steps=8) | loss=0.10044 lr=0.000327
50
+ [03:31:46] P2 Epoch 22/30 (steps=8) | loss=0.09729 lr=0.000276
51
+ [03:44:25] P2 Epoch 23/30 (steps=8) | loss=0.09401 lr=0.000224
52
+ [03:57:08] P2 Epoch 24/30 (steps=8) | loss=0.09080 lr=0.000173
53
+ [04:09:49] P2 Epoch 25/30 (steps=8) | loss=0.08751 lr=0.000125
54
+ [04:11:04] Val SSIM=0.8852 | {'pong': 0.8764, 'sonic': 0.8329, 'pole_position': 0.9462}
55
+ [04:11:04] New best! SSIM=0.8852
56
+ [04:23:43] P2 Epoch 26/30 (steps=8) | loss=0.08449 lr=0.000083
57
+ [04:36:13] P2 Epoch 27/30 (steps=8) | loss=0.08166 lr=0.000048
58
+ [04:48:48] P2 Epoch 28/30 (steps=8) | loss=0.07940 lr=0.000022
59
+ [05:01:33] P2 Epoch 29/30 (steps=8) | loss=0.07777 lr=0.000010
60
+ [05:14:14] P2 Epoch 30/30 (steps=8) | loss=0.07694 lr=0.000010
61
+ [05:15:35] Val SSIM=0.8850 | {'pong': 0.8783, 'sonic': 0.8292, 'pole_position': 0.9474}
62
+ [05:15:35] Experiment dir: 9.1 MB
63
+ [05:15:35] Training complete. Best val SSIM: 0.8852