ojaffe committed on
Commit
5a6434e
·
verified ·
1 Parent(s): 87bfad6

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. loss_history.json +51 -51
  2. model.pt +2 -2
  3. multiscale_flow_model.py +102 -0
  4. predict.py +12 -12
  5. train.log +64 -63
loss_history.json CHANGED
@@ -2,232 +2,232 @@
2
  {
3
  "epoch": 1,
4
  "phase": "P1",
5
- "loss": 0.093708
6
  },
7
  {
8
  "epoch": 2,
9
  "phase": "P1",
10
- "loss": 0.075409
11
  },
12
  {
13
  "epoch": 3,
14
  "phase": "P1",
15
- "loss": 0.070398
16
  },
17
  {
18
  "epoch": 4,
19
  "phase": "P1",
20
- "loss": 0.066922
21
  },
22
  {
23
  "epoch": 5,
24
  "phase": "P1",
25
- "loss": 0.064051
26
  },
27
  {
28
  "epoch": 6,
29
  "phase": "P1",
30
- "loss": 0.061594
31
  },
32
  {
33
  "epoch": 7,
34
  "phase": "P1",
35
- "loss": 0.058991
36
  },
37
  {
38
  "epoch": 8,
39
  "phase": "P1",
40
- "loss": 0.056665
41
  },
42
  {
43
  "epoch": 9,
44
  "phase": "P1",
45
- "loss": 0.054221
46
  },
47
  {
48
  "epoch": 10,
49
  "phase": "P1",
50
- "loss": 0.052157
51
  },
52
  {
53
  "epoch": 11,
54
  "phase": "P1",
55
- "loss": 0.050054
56
  },
57
  {
58
  "epoch": 12,
59
  "phase": "P1",
60
- "loss": 0.048416
61
  },
62
  {
63
  "epoch": 13,
64
  "phase": "P1",
65
- "loss": 0.047013
66
  },
67
  {
68
  "epoch": 14,
69
  "phase": "P1",
70
- "loss": 0.046003
71
  },
72
  {
73
  "epoch": 15,
74
  "phase": "P1",
75
- "loss": 0.0454
76
  },
77
  {
78
  "epoch": 16,
79
  "phase": "P2",
80
- "loss": 0.071297
81
  },
82
  {
83
  "epoch": 17,
84
  "phase": "P2",
85
- "loss": 0.069845
86
  },
87
  {
88
  "epoch": 18,
89
  "phase": "P2",
90
- "loss": 0.067838
91
  },
92
  {
93
  "epoch": 19,
94
  "phase": "P2",
95
- "loss": 0.102993
96
  },
97
  {
98
  "epoch": 20,
99
  "phase": "P2",
100
- "loss": 0.098403,
101
- "val_ssim": 0.8174
102
  },
103
  {
104
  "epoch": 21,
105
  "phase": "P2",
106
- "loss": 0.095552
107
  },
108
  {
109
  "epoch": 22,
110
  "phase": "P2",
111
- "loss": 0.142291
112
  },
113
  {
114
  "epoch": 23,
115
  "phase": "P2",
116
- "loss": 0.137962
117
  },
118
  {
119
  "epoch": 24,
120
  "phase": "P2",
121
- "loss": 0.133837
122
  },
123
  {
124
  "epoch": 25,
125
  "phase": "P2",
126
- "loss": 0.129812,
127
- "val_ssim": 0.854
128
  },
129
  {
130
  "epoch": 26,
131
  "phase": "P2",
132
- "loss": 0.126053
133
  },
134
  {
135
  "epoch": 27,
136
  "phase": "P2",
137
- "loss": 0.122985
138
  },
139
  {
140
  "epoch": 28,
141
  "phase": "P2",
142
- "loss": 0.120476
143
  },
144
  {
145
  "epoch": 29,
146
  "phase": "P2",
147
- "loss": 0.117592
148
  },
149
  {
150
  "epoch": 30,
151
  "phase": "P2",
152
- "loss": 0.115456,
153
- "val_ssim": 0.8644
154
  },
155
  {
156
  "epoch": 31,
157
  "phase": "P2",
158
- "loss": 0.113231
159
  },
160
  {
161
  "epoch": 32,
162
  "phase": "P2",
163
- "loss": 0.111175
164
  },
165
  {
166
  "epoch": 33,
167
  "phase": "P2",
168
- "loss": 0.108953
169
  },
170
  {
171
  "epoch": 34,
172
  "phase": "P2",
173
- "loss": 0.106131
174
  },
175
  {
176
  "epoch": 35,
177
  "phase": "P2",
178
- "loss": 0.103505,
179
- "val_ssim": 0.8744
180
  },
181
  {
182
  "epoch": 36,
183
  "phase": "P2",
184
- "loss": 0.100435
185
  },
186
  {
187
  "epoch": 37,
188
  "phase": "P2",
189
- "loss": 0.097286
190
  },
191
  {
192
  "epoch": 38,
193
  "phase": "P2",
194
- "loss": 0.094014
195
  },
196
  {
197
  "epoch": 39,
198
  "phase": "P2",
199
- "loss": 0.090802
200
  },
201
  {
202
  "epoch": 40,
203
  "phase": "P2",
204
- "loss": 0.087507,
205
- "val_ssim": 0.8852
206
  },
207
  {
208
  "epoch": 41,
209
  "phase": "P2",
210
- "loss": 0.084485
211
  },
212
  {
213
  "epoch": 42,
214
  "phase": "P2",
215
- "loss": 0.081661
216
  },
217
  {
218
  "epoch": 43,
219
  "phase": "P2",
220
- "loss": 0.079401
221
  },
222
  {
223
  "epoch": 44,
224
  "phase": "P2",
225
- "loss": 0.077772
226
  },
227
  {
228
  "epoch": 45,
229
  "phase": "P2",
230
- "loss": 0.076937,
231
- "val_ssim": 0.885
232
  }
233
  ]
 
2
  {
3
  "epoch": 1,
4
  "phase": "P1",
5
+ "loss": 0.152055
6
  },
7
  {
8
  "epoch": 2,
9
  "phase": "P1",
10
+ "loss": 0.126681
11
  },
12
  {
13
  "epoch": 3,
14
  "phase": "P1",
15
+ "loss": 0.119891
16
  },
17
  {
18
  "epoch": 4,
19
  "phase": "P1",
20
+ "loss": 0.114801
21
  },
22
  {
23
  "epoch": 5,
24
  "phase": "P1",
25
+ "loss": 0.110611
26
  },
27
  {
28
  "epoch": 6,
29
  "phase": "P1",
30
+ "loss": 0.107016
31
  },
32
  {
33
  "epoch": 7,
34
  "phase": "P1",
35
+ "loss": 0.103401
36
  },
37
  {
38
  "epoch": 8,
39
  "phase": "P1",
40
+ "loss": 0.100012
41
  },
42
  {
43
  "epoch": 9,
44
  "phase": "P1",
45
+ "loss": 0.096366
46
  },
47
  {
48
  "epoch": 10,
49
  "phase": "P1",
50
+ "loss": 0.09296
51
  },
52
  {
53
  "epoch": 11,
54
  "phase": "P1",
55
+ "loss": 0.089986
56
  },
57
  {
58
  "epoch": 12,
59
  "phase": "P1",
60
+ "loss": 0.087143
61
  },
62
  {
63
  "epoch": 13,
64
  "phase": "P1",
65
+ "loss": 0.08477
66
  },
67
  {
68
  "epoch": 14,
69
  "phase": "P1",
70
+ "loss": 0.083114
71
  },
72
  {
73
  "epoch": 15,
74
  "phase": "P1",
75
+ "loss": 0.082026
76
  },
77
  {
78
  "epoch": 16,
79
  "phase": "P2",
80
+ "loss": 0.122125
81
  },
82
  {
83
  "epoch": 17,
84
  "phase": "P2",
85
+ "loss": 0.118517
86
  },
87
  {
88
  "epoch": 18,
89
  "phase": "P2",
90
+ "loss": 0.115646
91
  },
92
  {
93
  "epoch": 19,
94
  "phase": "P2",
95
+ "loss": 0.170965
96
  },
97
  {
98
  "epoch": 20,
99
  "phase": "P2",
100
+ "loss": 0.163493,
101
+ "val_ssim": 0.8267
102
  },
103
  {
104
  "epoch": 21,
105
  "phase": "P2",
106
+ "loss": 0.159067
107
  },
108
  {
109
  "epoch": 22,
110
  "phase": "P2",
111
+ "loss": 0.237583
112
  },
113
  {
114
  "epoch": 23,
115
  "phase": "P2",
116
+ "loss": 0.229664
117
  },
118
  {
119
  "epoch": 24,
120
  "phase": "P2",
121
+ "loss": 0.221985
122
  },
123
  {
124
  "epoch": 25,
125
  "phase": "P2",
126
+ "loss": 0.215313,
127
+ "val_ssim": 0.8505
128
  },
129
  {
130
  "epoch": 26,
131
  "phase": "P2",
132
+ "loss": 0.208722
133
  },
134
  {
135
  "epoch": 27,
136
  "phase": "P2",
137
+ "loss": 0.203962
138
  },
139
  {
140
  "epoch": 28,
141
  "phase": "P2",
142
+ "loss": 0.198393
143
  },
144
  {
145
  "epoch": 29,
146
  "phase": "P2",
147
+ "loss": 0.194795
148
  },
149
  {
150
  "epoch": 30,
151
  "phase": "P2",
152
+ "loss": 0.191285,
153
+ "val_ssim": 0.8759
154
  },
155
  {
156
  "epoch": 31,
157
  "phase": "P2",
158
+ "loss": 0.187651
159
  },
160
  {
161
  "epoch": 32,
162
  "phase": "P2",
163
+ "loss": 0.184686
164
  },
165
  {
166
  "epoch": 33,
167
  "phase": "P2",
168
+ "loss": 0.180715
169
  },
170
  {
171
  "epoch": 34,
172
  "phase": "P2",
173
+ "loss": 0.176762
174
  },
175
  {
176
  "epoch": 35,
177
  "phase": "P2",
178
+ "loss": 0.172307,
179
+ "val_ssim": 0.8774
180
  },
181
  {
182
  "epoch": 36,
183
  "phase": "P2",
184
+ "loss": 0.167519
185
  },
186
  {
187
  "epoch": 37,
188
  "phase": "P2",
189
+ "loss": 0.162766
190
  },
191
  {
192
  "epoch": 38,
193
  "phase": "P2",
194
+ "loss": 0.157198
195
  },
196
  {
197
  "epoch": 39,
198
  "phase": "P2",
199
+ "loss": 0.152165
200
  },
201
  {
202
  "epoch": 40,
203
  "phase": "P2",
204
+ "loss": 0.147043,
205
+ "val_ssim": 0.886
206
  },
207
  {
208
  "epoch": 41,
209
  "phase": "P2",
210
+ "loss": 0.141957
211
  },
212
  {
213
  "epoch": 42,
214
  "phase": "P2",
215
+ "loss": 0.137481
216
  },
217
  {
218
  "epoch": 43,
219
  "phase": "P2",
220
+ "loss": 0.133861
221
  },
222
  {
223
  "epoch": 44,
224
  "phase": "P2",
225
+ "loss": 0.131363
226
  },
227
  {
228
  "epoch": 45,
229
  "phase": "P2",
230
+ "loss": 0.129965,
231
+ "val_ssim": 0.888
232
  }
233
  ]
model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4774304dae39b918b34dd4ededabc4a793ac7efdeb772a746587fad584ccfe83
3
- size 9089268
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e930868e7c620774f7f12cd9c2f056032024e50b17bd1405824daa5df80ecb6b
3
+ size 12361376
multiscale_flow_model.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-Scale Flow-Warp-Mask U-Net: predicts flow at multiple resolutions."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class ResConvBlock(nn.Module):
8
+ def __init__(self, in_ch, out_ch):
9
+ super().__init__()
10
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
11
+ self.gn1 = nn.GroupNorm(min(8, out_ch), out_ch)
12
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
13
+ self.gn2 = nn.GroupNorm(min(8, out_ch), out_ch)
14
+ self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
15
+
16
+ def forward(self, x):
17
+ residual = self.proj(x)
18
+ x = F.silu(self.gn1(self.conv1(x)))
19
+ x = F.silu(self.gn2(self.conv2(x)))
20
+ return x + residual
21
+
22
+
23
+ class MultiScaleFlowUNet(nn.Module):
24
+ def __init__(self, in_channels=12, channels=[64, 128, 256]):
25
+ super().__init__()
26
+ # Encoder
27
+ self.encoders = nn.ModuleList()
28
+ self.pools = nn.ModuleList()
29
+ prev_ch = in_channels
30
+ for ch in channels:
31
+ self.encoders.append(ResConvBlock(prev_ch, ch))
32
+ self.pools.append(nn.MaxPool2d(2))
33
+ prev_ch = ch
34
+
35
+ # Bottleneck
36
+ self.bottleneck = ResConvBlock(channels[-1], channels[-1] * 2)
37
+
38
+ # Decoder
39
+ self.upconvs = nn.ModuleList()
40
+ self.decoders = nn.ModuleList()
41
+ dec_channels = list(reversed(channels))
42
+ prev_ch = channels[-1] * 2
43
+ for ch in dec_channels:
44
+ self.upconvs.append(nn.ConvTranspose2d(prev_ch, ch, 2, stride=2))
45
+ self.decoders.append(ResConvBlock(ch * 2, ch))
46
+ prev_ch = ch
47
+
48
+ # Multi-scale flow heads at each decoder level
49
+ # dec_channels = [256, 128, 64] for the default channels (coarsest to finest)
50
+ # Level 0 (coarsest, 16x16 for a 64x64 input; the bottleneck is 8x8): flow refinement
51
+ # Level 1 (32x32): flow refinement
52
+ # Level 2 (finest, 64x64): flow refinement + mask + gen_frame
53
+ self.flow_heads = nn.ModuleList()
54
+ for ch in dec_channels:
55
+ head = nn.Conv2d(ch, 2, 1)
56
+ nn.init.zeros_(head.weight)
57
+ nn.init.zeros_(head.bias)
58
+ self.flow_heads.append(head)
59
+
60
+ # Mask and generation heads only at finest level (level 2, 64x64)
61
+ self.mask_head = nn.Conv2d(dec_channels[-1], 1, 1)
62
+ nn.init.zeros_(self.mask_head.weight)
63
+ nn.init.zeros_(self.mask_head.bias)
64
+
65
+ self.gen_head = nn.Conv2d(dec_channels[-1], 3, 1)
66
+
67
+ def forward(self, x):
68
+ skips = []
69
+ for enc, pool in zip(self.encoders, self.pools):
70
+ x = enc(x)
71
+ skips.append(x)
72
+ x = pool(x)
73
+
74
+ x = self.bottleneck(x)
75
+
76
+ flows = [] # flow at each level, from coarsest to finest
77
+ for i, (upconv, dec, skip) in enumerate(zip(self.upconvs, self.decoders, reversed(skips))):
78
+ x = upconv(x)
79
+ x = torch.cat([x, skip], dim=1)
80
+ x = dec(x)
81
+
82
+ # Predict flow refinement at this level
83
+ flow_refine = self.flow_heads[i](x)
84
+
85
+ if i == 0:
86
+ # Coarsest level: just the flow refinement
87
+ flow = flow_refine
88
+ else:
89
+ # Upsample previous flow and add refinement
90
+ prev_flow_up = F.interpolate(flows[-1], scale_factor=2, mode='bilinear', align_corners=True)
91
+ # Scale flow values by 2 since coordinates double
92
+ prev_flow_up = prev_flow_up * 2
93
+ flow = prev_flow_up + flow_refine
94
+
95
+ flows.append(flow)
96
+
97
+ # Final level outputs
98
+ mask = torch.sigmoid(self.mask_head(x))
99
+ gen_frame = self.gen_head(x)
100
+
101
+ # flows[-1] is the finest (64x64) flow
102
+ return flows, mask, gen_frame
predict.py CHANGED
@@ -1,20 +1,20 @@
1
- """Prediction interface for Flow-Warp-Mask U-Net v9 with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
- from flowmask_model import FlowWarpMaskUNet
9
  from flownet_model import differentiable_warp
10
 
11
  CONTEXT_LEN = 4
12
- CHANNELS = [48, 96, 192]
13
 
14
 
15
  def load_model(model_dir: str):
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- model = FlowWarpMaskUNet(in_channels=12, channels=CHANNELS)
18
  model_path = os.path.join(model_dir, "model.pt")
19
  state_dict = torch.load(model_path, map_location=device, weights_only=True)
20
  state_dict = {k: v.float() for k, v in state_dict.items()}
@@ -25,7 +25,6 @@ def load_model(model_dir: str):
25
 
26
 
27
  def _prepare_input(context_frames):
28
- """Prepare 4-frame context tensor from numpy frames."""
29
  if len(context_frames) >= CONTEXT_LEN:
30
  frames = context_frames[-CONTEXT_LEN:]
31
  else:
@@ -34,19 +33,20 @@ def _prepare_input(context_frames):
34
  frames = np.concatenate([padding, context_frames], axis=0)
35
 
36
  frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
37
- frames_t = frames_t.permute(0, 3, 1, 2) # (4, 3, 64, 64)
38
  return frames_t
39
 
40
 
41
  def _run_model(model, frames_t, device):
42
- """Run model on prepared frames, return prediction tensor."""
43
- last_frame = frames_t[-1].unsqueeze(0) # (1, 3, 64, 64)
44
- inp = frames_t.reshape(1, -1, 64, 64) # (1, 12, 64, 64)
45
 
46
  inp = inp.to(device)
47
  last_frame = last_frame.to(device)
48
 
49
- flow, mask, gen_frame = model(inp)
 
 
50
  warped = differentiable_warp(last_frame, flow)
51
  pred = mask * warped + (1 - mask) * gen_frame
52
  pred = torch.clamp(pred, 0, 1)
@@ -64,9 +64,9 @@ def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
64
  pred1 = _run_model(model, frames_t, device)
65
 
66
  # TTA: horizontally flipped prediction
67
- frames_flipped = frames_t.flip(-1) # flip W dimension
68
  pred2_flipped = _run_model(model, frames_flipped, device)
69
- pred2 = pred2_flipped.flip(-1) # flip back
70
 
71
  # Average
72
  pred = (pred1 + pred2) / 2.0
 
1
+ """Prediction interface for Multi-Scale Flow-Warp-Mask U-Net v10 with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
+ from multiscale_flow_model import MultiScaleFlowUNet
9
  from flownet_model import differentiable_warp
10
 
11
  CONTEXT_LEN = 4
12
+ CHANNELS = [56, 112, 224]
13
 
14
 
15
  def load_model(model_dir: str):
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = MultiScaleFlowUNet(in_channels=12, channels=CHANNELS)
18
  model_path = os.path.join(model_dir, "model.pt")
19
  state_dict = torch.load(model_path, map_location=device, weights_only=True)
20
  state_dict = {k: v.float() for k, v in state_dict.items()}
 
25
 
26
 
27
  def _prepare_input(context_frames):
 
28
  if len(context_frames) >= CONTEXT_LEN:
29
  frames = context_frames[-CONTEXT_LEN:]
30
  else:
 
33
  frames = np.concatenate([padding, context_frames], axis=0)
34
 
35
  frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
36
+ frames_t = frames_t.permute(0, 3, 1, 2)
37
  return frames_t
38
 
39
 
40
  def _run_model(model, frames_t, device):
41
+ last_frame = frames_t[-1].unsqueeze(0)
42
+ inp = frames_t.reshape(1, -1, 64, 64)
 
43
 
44
  inp = inp.to(device)
45
  last_frame = last_frame.to(device)
46
 
47
+ flows, mask, gen_frame = model(inp)
48
+ # Use finest flow (last element)
49
+ flow = flows[-1]
50
  warped = differentiable_warp(last_frame, flow)
51
  pred = mask * warped + (1 - mask) * gen_frame
52
  pred = torch.clamp(pred, 0, 1)
 
64
  pred1 = _run_model(model, frames_t, device)
65
 
66
  # TTA: horizontally flipped prediction
67
+ frames_flipped = frames_t.flip(-1)
68
  pred2_flipped = _run_model(model, frames_flipped, device)
69
+ pred2 = pred2_flipped.flip(-1)
70
 
71
  # Average
72
  pred = (pred1 + pred2) / 2.0
train.log CHANGED
@@ -1,63 +1,64 @@
1
- [23:37:08] Device: cuda
2
- [23:37:08] Model parameters: 4,534,230, channels=[48, 96, 192]
3
- [23:37:08] Phase 1: Single-step (15 epochs)
4
- [23:37:12] 45108 sequences
5
- [23:37:54] P1 Epoch 1/15 | loss=0.09371
6
- [23:38:34] P1 Epoch 2/15 | loss=0.07541
7
- [23:39:15] P1 Epoch 3/15 | loss=0.07040
8
- [23:39:56] P1 Epoch 4/15 | loss=0.06692
9
- [23:40:36] P1 Epoch 5/15 | loss=0.06405
10
- [23:41:17] P1 Epoch 6/15 | loss=0.06159
11
- [23:41:58] P1 Epoch 7/15 | loss=0.05899
12
- [23:42:40] P1 Epoch 8/15 | loss=0.05667
13
- [23:43:21] P1 Epoch 9/15 | loss=0.05422
14
- [23:44:01] P1 Epoch 10/15 | loss=0.05216
15
- [23:44:43] P1 Epoch 11/15 | loss=0.05005
16
- [23:45:23] P1 Epoch 12/15 | loss=0.04842
17
- [23:46:03] P1 Epoch 13/15 | loss=0.04701
18
- [23:46:45] P1 Epoch 14/15 | loss=0.04600
19
- [23:47:24] P1 Epoch 15/15 | loss=0.04540
20
- [23:47:24] Phase 2: Graduated AR (30 epochs)
21
- [23:49:24] P2 Epoch 1/30 (steps=2) | loss=0.07130 lr=0.000500
22
- [23:51:23] P2 Epoch 2/30 (steps=2) | loss=0.06985 lr=0.000500
23
- [23:53:18] P2 Epoch 3/30 (steps=2) | loss=0.06784 lr=0.000500
24
- [23:58:06] P2 Epoch 4/30 (steps=4) | loss=0.10299 lr=0.000500
25
- [00:02:59] P2 Epoch 5/30 (steps=4) | loss=0.09840 lr=0.000500
26
- [00:04:11] Val SSIM=0.8174 | {'pong': 0.7108, 'sonic': 0.8111, 'pole_position': 0.9302}
27
- [00:04:11] New best! SSIM=0.8174
28
- [00:09:08] P2 Epoch 6/30 (steps=4) | loss=0.09555 lr=0.000500
29
- [00:21:04] P2 Epoch 7/30 (steps=8) | loss=0.14229 lr=0.000500
30
- [00:32:46] P2 Epoch 8/30 (steps=8) | loss=0.13796 lr=0.000500
31
- [00:44:48] P2 Epoch 9/30 (steps=8) | loss=0.13384 lr=0.000500
32
- [00:57:15] P2 Epoch 10/30 (steps=8) | loss=0.12981 lr=0.000500
33
- [00:58:37] Val SSIM=0.8540 | {'pong': 0.8022, 'sonic': 0.8237, 'pole_position': 0.936}
34
- [00:58:37] New best! SSIM=0.8540
35
- [01:11:08] P2 Epoch 11/30 (steps=8) | loss=0.12605 lr=0.000500
36
- [01:23:41] P2 Epoch 12/30 (steps=8) | loss=0.12299 lr=0.000500
37
- [01:36:24] P2 Epoch 13/30 (steps=8) | loss=0.12048 lr=0.000500
38
- [01:48:54] P2 Epoch 14/30 (steps=8) | loss=0.11759 lr=0.000500
39
- [02:01:33] P2 Epoch 15/30 (steps=8) | loss=0.11546 lr=0.000500
40
- [02:02:55] Val SSIM=0.8644 | {'pong': 0.829, 'sonic': 0.8264, 'pole_position': 0.9378}
41
- [02:02:55] New best! SSIM=0.8644
42
- [02:15:31] P2 Epoch 16/30 (steps=8) | loss=0.11323 lr=0.000495
43
- [02:28:01] P2 Epoch 17/30 (steps=8) | loss=0.11117 lr=0.000478
44
- [02:40:14] P2 Epoch 18/30 (steps=8) | loss=0.10895 lr=0.000452
45
- [02:52:32] P2 Epoch 19/30 (steps=8) | loss=0.10613 lr=0.000417
46
- [03:05:05] P2 Epoch 20/30 (steps=8) | loss=0.10350 lr=0.000375
47
- [03:06:28] Val SSIM=0.8744 | {'pong': 0.8512, 'sonic': 0.8308, 'pole_position': 0.9413}
48
- [03:06:28] New best! SSIM=0.8744
49
- [03:19:19] P2 Epoch 21/30 (steps=8) | loss=0.10044 lr=0.000327
50
- [03:31:46] P2 Epoch 22/30 (steps=8) | loss=0.09729 lr=0.000276
51
- [03:44:25] P2 Epoch 23/30 (steps=8) | loss=0.09401 lr=0.000224
52
- [03:57:08] P2 Epoch 24/30 (steps=8) | loss=0.09080 lr=0.000173
53
- [04:09:49] P2 Epoch 25/30 (steps=8) | loss=0.08751 lr=0.000125
54
- [04:11:04] Val SSIM=0.8852 | {'pong': 0.8764, 'sonic': 0.8329, 'pole_position': 0.9462}
55
- [04:11:04] New best! SSIM=0.8852
56
- [04:23:43] P2 Epoch 26/30 (steps=8) | loss=0.08449 lr=0.000083
57
- [04:36:13] P2 Epoch 27/30 (steps=8) | loss=0.08166 lr=0.000048
58
- [04:48:48] P2 Epoch 28/30 (steps=8) | loss=0.07940 lr=0.000022
59
- [05:01:33] P2 Epoch 29/30 (steps=8) | loss=0.07777 lr=0.000010
60
- [05:14:14] P2 Epoch 30/30 (steps=8) | loss=0.07694 lr=0.000010
61
- [05:15:35] Val SSIM=0.8850 | {'pong': 0.8783, 'sonic': 0.8292, 'pole_position': 0.9474}
62
- [05:15:35] Experiment dir: 9.1 MB
63
- [05:15:35] Training complete. Best val SSIM: 0.8852
 
 
1
+ [05:19:55] Device: cuda
2
+ [05:19:55] Model parameters: 6,169,586, channels=[56, 112, 224]
3
+ [05:19:55] Phase 1: Single-step (15 epochs)
4
+ [05:19:59] 45108 sequences
5
+ [05:20:50] P1 Epoch 1/15 | loss=0.15205
6
+ [05:21:41] P1 Epoch 2/15 | loss=0.12668
7
+ [05:22:29] P1 Epoch 3/15 | loss=0.11989
8
+ [05:23:16] P1 Epoch 4/15 | loss=0.11480
9
+ [05:24:08] P1 Epoch 5/15 | loss=0.11061
10
+ [05:24:54] P1 Epoch 6/15 | loss=0.10702
11
+ [05:25:46] P1 Epoch 7/15 | loss=0.10340
12
+ [05:26:37] P1 Epoch 8/15 | loss=0.10001
13
+ [05:27:23] P1 Epoch 9/15 | loss=0.09637
14
+ [05:28:12] P1 Epoch 10/15 | loss=0.09296
15
+ [05:29:02] P1 Epoch 11/15 | loss=0.08999
16
+ [05:29:51] P1 Epoch 12/15 | loss=0.08714
17
+ [05:30:40] P1 Epoch 13/15 | loss=0.08477
18
+ [05:31:30] P1 Epoch 14/15 | loss=0.08311
19
+ [05:32:17] P1 Epoch 15/15 | loss=0.08203
20
+ [05:32:17] Phase 2: Graduated AR (30 epochs)
21
+ [05:34:32] P2 Epoch 1/30 (steps=2) | loss=0.12213 lr=0.000500
22
+ [05:36:49] P2 Epoch 2/30 (steps=2) | loss=0.11852 lr=0.000500
23
+ [05:38:58] P2 Epoch 3/30 (steps=2) | loss=0.11565 lr=0.000500
24
+ [05:44:14] P2 Epoch 4/30 (steps=4) | loss=0.17096 lr=0.000500
25
+ [05:49:31] P2 Epoch 5/30 (steps=4) | loss=0.16349 lr=0.000500
26
+ [05:50:57] Val SSIM=0.8267 | {'pong': 0.7258, 'sonic': 0.8199, 'pole_position': 0.9343}
27
+ [05:50:57] New best! SSIM=0.8267
28
+ [05:56:10] P2 Epoch 6/30 (steps=4) | loss=0.15907 lr=0.000500
29
+ [06:10:41] P2 Epoch 7/30 (steps=8) | loss=0.23758 lr=0.000500
30
+ [06:24:53] P2 Epoch 8/30 (steps=8) | loss=0.22966 lr=0.000500
31
+ [06:39:05] P2 Epoch 9/30 (steps=8) | loss=0.22198 lr=0.000500
32
+ [06:53:24] P2 Epoch 10/30 (steps=8) | loss=0.21531 lr=0.000500
33
+ [06:54:54] Val SSIM=0.8505 | {'pong': 0.7857, 'sonic': 0.8264, 'pole_position': 0.9393}
34
+ [06:54:54] New best! SSIM=0.8505
35
+ [07:09:06] P2 Epoch 11/30 (steps=8) | loss=0.20872 lr=0.000500
36
+ [07:23:28] P2 Epoch 12/30 (steps=8) | loss=0.20396 lr=0.000500
37
+ [07:37:46] P2 Epoch 13/30 (steps=8) | loss=0.19839 lr=0.000500
38
+ [07:52:00] P2 Epoch 14/30 (steps=8) | loss=0.19479 lr=0.000500
39
+ [08:06:23] P2 Epoch 15/30 (steps=8) | loss=0.19129 lr=0.000500
40
+ [08:07:46] Val SSIM=0.8759 | {'pong': 0.8609, 'sonic': 0.8246, 'pole_position': 0.9423}
41
+ [08:07:46] New best! SSIM=0.8759
42
+ [08:22:08] P2 Epoch 16/30 (steps=8) | loss=0.18765 lr=0.000495
43
+ [08:36:25] P2 Epoch 17/30 (steps=8) | loss=0.18469 lr=0.000478
44
+ [08:50:42] P2 Epoch 18/30 (steps=8) | loss=0.18071 lr=0.000452
45
+ [09:04:59] P2 Epoch 19/30 (steps=8) | loss=0.17676 lr=0.000417
46
+ [09:19:13] P2 Epoch 20/30 (steps=8) | loss=0.17231 lr=0.000375
47
+ [09:20:41] Val SSIM=0.8774 | {'pong': 0.8579, 'sonic': 0.8323, 'pole_position': 0.9419}
48
+ [09:20:41] New best! SSIM=0.8774
49
+ [09:35:11] P2 Epoch 21/30 (steps=8) | loss=0.16752 lr=0.000327
50
+ [09:49:35] P2 Epoch 22/30 (steps=8) | loss=0.16277 lr=0.000276
51
+ [10:03:57] P2 Epoch 23/30 (steps=8) | loss=0.15720 lr=0.000224
52
+ [10:18:08] P2 Epoch 24/30 (steps=8) | loss=0.15217 lr=0.000173
53
+ [10:32:53] P2 Epoch 25/30 (steps=8) | loss=0.14704 lr=0.000125
54
+ [10:34:17] Val SSIM=0.8860 | {'pong': 0.876, 'sonic': 0.8357, 'pole_position': 0.9463}
55
+ [10:34:17] New best! SSIM=0.8860
56
+ [10:49:35] P2 Epoch 26/30 (steps=8) | loss=0.14196 lr=0.000083
57
+ [11:04:55] P2 Epoch 27/30 (steps=8) | loss=0.13748 lr=0.000048
58
+ [11:20:12] P2 Epoch 28/30 (steps=8) | loss=0.13386 lr=0.000022
59
+ [11:35:30] P2 Epoch 29/30 (steps=8) | loss=0.13136 lr=0.000010
60
+ [11:49:54] P2 Epoch 30/30 (steps=8) | loss=0.12997 lr=0.000010
61
+ [11:51:09] Val SSIM=0.8880 | {'pong': 0.8813, 'sonic': 0.8349, 'pole_position': 0.9479}
62
+ [11:51:09] New best! SSIM=0.8880
63
+ [11:51:09] Experiment dir: 12.4 MB
64
+ [11:51:09] Training complete. Best val SSIM: 0.8880