ojaffe committed on
Commit
99c8044
·
verified ·
1 Parent(s): 58e2e4e

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. pole_model.pt +3 -0
  2. pong_model.pt +3 -0
  3. predict.py +62 -22
  4. sonic_model.pt +3 -0
  5. train.log +131 -24
pole_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef5ce878dda321107ca4ec68285d326086badfcd5216f4d0654cb99c8fd3c4b0
3
+ size 2298466
pong_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49a4ad30877111b9f170b9aac88391c3ed9c9b0be784b6d2ad3882edf004b94c
3
+ size 4062370
predict.py CHANGED
@@ -1,30 +1,63 @@
1
- """Prediction interface for Multi-Scale Flow-Warp-Mask U-Net v10 with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
- from multiscale_flow_model import MultiScaleFlowUNet
9
  from flownet_model import differentiable_warp
10
 
11
  CONTEXT_LEN = 4
12
- CHANNELS = [56, 112, 224]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def load_model(model_dir: str):
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- model = MultiScaleFlowUNet(in_channels=12, channels=CHANNELS)
18
- model_path = os.path.join(model_dir, "model.pt")
19
- state_dict = torch.load(model_path, map_location=device, weights_only=True)
20
- state_dict = {k: v.float() for k, v in state_dict.items()}
21
- model.load_state_dict(state_dict)
22
- model.to(device)
23
- model.eval()
24
- return {"model": model, "device": device}
25
-
26
-
27
- def _prepare_input(context_frames):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  if len(context_frames) >= CONTEXT_LEN:
29
  frames = context_frames[-CONTEXT_LEN:]
30
  else:
@@ -33,20 +66,19 @@ def _prepare_input(context_frames):
33
  frames = np.concatenate([padding, context_frames], axis=0)
34
 
35
  frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
36
- frames_t = frames_t.permute(0, 3, 1, 2)
37
  return frames_t
38
 
39
 
40
  def _run_model(model, frames_t, device):
41
- last_frame = frames_t[-1].unsqueeze(0)
42
- inp = frames_t.reshape(1, -1, 64, 64)
 
43
 
44
  inp = inp.to(device)
45
  last_frame = last_frame.to(device)
46
 
47
- flows, mask, gen_frame = model(inp)
48
- # Use finest flow (last element)
49
- flow = flows[-1]
50
  warped = differentiable_warp(last_frame, flow)
51
  pred = mask * warped + (1 - mask) * gen_frame
52
  pred = torch.clamp(pred, 0, 1)
@@ -54,10 +86,13 @@ def _run_model(model, frames_t, device):
54
 
55
 
56
  def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
57
- model = model_dict["model"]
58
  device = model_dict["device"]
59
 
60
- frames_t = _prepare_input(context_frames)
 
 
 
61
 
62
  with torch.no_grad():
63
  # Original prediction
@@ -73,4 +108,9 @@ def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
73
 
74
  pred = pred[0].cpu().permute(1, 2, 0).numpy()
75
  pred = (pred * 255).clip(0, 255).astype(np.uint8)
 
 
 
 
 
76
  return pred
 
1
+ """Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA."""
2
  import sys
3
  import os
4
  import numpy as np
5
  import torch
6
 
7
  sys.path.insert(0, "/home/coder/code")
8
+ from flowmask_model import FlowWarpMaskUNet
9
  from flownet_model import differentiable_warp
10
 
11
  CONTEXT_LEN = 4
12
+ GAME_CONFIGS = {
13
+ "pong": {"channels": [32, 64, 128], "file": "pong_model.pt"},
14
+ "sonic": {"channels": [40, 80, 160], "file": "sonic_model.pt"},
15
+ "pole_position": {"channels": [24, 48, 96], "file": "pole_model.pt"},
16
+ }
17
+
18
+
19
+ def detect_game(context_frames):
20
+ mean_val = context_frames.mean()
21
+ if mean_val < 10:
22
+ return "pong"
23
+ elif mean_val < 80:
24
+ return "sonic"
25
+ else:
26
+ return "pole_position"
27
 
28
 
29
  def load_model(model_dir: str):
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ models = {}
32
+ for game, cfg in GAME_CONFIGS.items():
33
+ model = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"])
34
+ model_path = os.path.join(model_dir, cfg["file"])
35
+ state_dict = torch.load(model_path, map_location=device, weights_only=True)
36
+ state_dict = {k: v.float() for k, v in state_dict.items()}
37
+ model.load_state_dict(state_dict)
38
+ model.to(device)
39
+ model.eval()
40
+ models[game] = model
41
+ return {"models": models, "device": device}
42
+
43
+
44
+ def _make_motion_input(frames):
45
+ """Create motion encoding: last frame (3ch) + 3 pairwise diffs (9ch) = 12ch.
46
+
47
+ Args:
48
+ frames: (4, 3, H, W) tensor in [0,1]
49
+ Returns:
50
+ (12, H, W) tensor
51
+ """
52
+ last = frames[-1] # (3, H, W)
53
+ diff1 = frames[-1] - frames[-2] # most recent motion
54
+ diff2 = frames[-2] - frames[-3] # previous motion
55
+ diff3 = frames[-3] - frames[-4] # older motion
56
+ return torch.cat([last, diff1, diff2, diff3], dim=0) # (12, H, W)
57
+
58
+
59
+ def _prepare_context(context_frames):
60
+ """Prepare 4-frame context from numpy frames."""
61
  if len(context_frames) >= CONTEXT_LEN:
62
  frames = context_frames[-CONTEXT_LEN:]
63
  else:
 
66
  frames = np.concatenate([padding, context_frames], axis=0)
67
 
68
  frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
69
+ frames_t = frames_t.permute(0, 3, 1, 2) # (4, 3, 64, 64)
70
  return frames_t
71
 
72
 
73
  def _run_model(model, frames_t, device):
74
+ """Run model with motion encoding input."""
75
+ last_frame = frames_t[-1].unsqueeze(0) # (1, 3, 64, 64)
76
+ inp = _make_motion_input(frames_t).unsqueeze(0) # (1, 12, 64, 64)
77
 
78
  inp = inp.to(device)
79
  last_frame = last_frame.to(device)
80
 
81
+ flow, mask, gen_frame = model(inp)
 
 
82
  warped = differentiable_warp(last_frame, flow)
83
  pred = mask * warped + (1 - mask) * gen_frame
84
  pred = torch.clamp(pred, 0, 1)
 
86
 
87
 
88
  def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
89
+ models = model_dict["models"]
90
  device = model_dict["device"]
91
 
92
+ game = detect_game(context_frames)
93
+ model = models[game]
94
+
95
+ frames_t = _prepare_context(context_frames)
96
 
97
  with torch.no_grad():
98
  # Original prediction
 
108
 
109
  pred = pred[0].cpu().permute(1, 2, 0).numpy()
110
  pred = (pred * 255).clip(0, 255).astype(np.uint8)
111
+
112
+ # Post-processing for Pong: clamp dark pixels to pure black
113
+ if game == "pong":
114
+ pred[pred < 5] = 0
115
+
116
  return pred
sonic_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98a4cfecb0f8c6b864ef908feead0169925e037e437ba3bddc04380290c48bdb
3
+ size 6326456
train.log CHANGED
@@ -1,24 +1,131 @@
1
- [12:02:26] Device: cuda
2
- [12:02:27] Loaded v10 weights from /home/coder/experiments/2026-04-14-080000-multiscale-flow-v10
3
- [12:02:27] Model parameters: 6,169,586, channels=[56, 112, 224]
4
- [12:02:27] Fine-tune: 10 epochs of 8-step AR with pure SSIM loss
5
- [12:02:31] 43855 sequences
6
- [12:15:01] Epoch 1/10 | loss=0.09345 lr=0.0000098
7
- [12:27:32] Epoch 2/10 | loss=0.09245 lr=0.0000091
8
- [12:29:02] Val SSIM=0.8883 | {'pong': 0.8811, 'sonic': 0.8354, 'pole_position': 0.9485}
9
- [12:29:02] New best! SSIM=0.8883
10
- [12:41:38] Epoch 3/10 | loss=0.09166 lr=0.0000081
11
- [12:54:16] Epoch 4/10 | loss=0.09095 lr=0.0000069
12
- [12:55:36] Val SSIM=0.8887 | {'pong': 0.8824, 'sonic': 0.8352, 'pole_position': 0.9486}
13
- [12:55:36] New best! SSIM=0.8887
14
- [13:08:12] Epoch 5/10 | loss=0.09033 lr=0.0000055
15
- [13:20:56] Epoch 6/10 | loss=0.08982 lr=0.0000041
16
- [13:22:21] Val SSIM=0.8885 | {'pong': 0.8824, 'sonic': 0.8347, 'pole_position': 0.9483}
17
- [13:35:14] Epoch 7/10 | loss=0.08941 lr=0.0000029
18
- [13:48:14] Epoch 8/10 | loss=0.08911 lr=0.0000019
19
- [13:49:31] Val SSIM=0.8883 | {'pong': 0.8822, 'sonic': 0.8344, 'pole_position': 0.9484}
20
- [14:02:17] Epoch 9/10 | loss=0.08888 lr=0.0000012
21
- [14:14:47] Epoch 10/10 | loss=0.08874 lr=0.0000010
22
- [14:16:07] Val SSIM=0.8881 | {'pong': 0.8815, 'sonic': 0.8343, 'pole_position': 0.9485}
23
- [14:16:07] Experiment dir: 12.4 MB
24
- [14:16:07] Training complete. Best val SSIM: 0.8887
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [14:35:51] Device: cuda
2
+ [14:35:51]
3
+ === Training pong ([32, 64, 128]) ===
4
+ [14:35:51] 2,018,278 parameters
5
+ [14:35:51] Phase 1: 10 epochs single-step
6
+ [14:35:51] 8568 sequences
7
+ [14:36:00] P1 pong Epoch 1/10 | loss=0.14558
8
+ [14:36:08] P1 pong Epoch 2/10 | loss=0.10721
9
+ [14:36:17] P1 pong Epoch 3/10 | loss=0.09795
10
+ [14:36:25] P1 pong Epoch 4/10 | loss=0.08996
11
+ [14:36:33] P1 pong Epoch 5/10 | loss=0.08384
12
+ [14:36:41] P1 pong Epoch 6/10 | loss=0.07755
13
+ [14:36:49] P1 pong Epoch 7/10 | loss=0.06995
14
+ [14:36:57] P1 pong Epoch 8/10 | loss=0.06272
15
+ [14:37:05] P1 pong Epoch 9/10 | loss=0.05640
16
+ [14:37:13] P1 pong Epoch 10/10 | loss=0.05177
17
+ [14:37:13] Phase 2: 25 epochs graduated AR
18
+ [14:37:37] P2 pong Epoch 1/25 (steps=2) | loss=0.09787 lr=0.000500
19
+ [14:37:59] P2 pong Epoch 2/25 (steps=2) | loss=0.08854 lr=0.000500
20
+ [14:38:21] P2 pong Epoch 3/25 (steps=2) | loss=0.08343 lr=0.000500
21
+ [14:39:15] P2 pong Epoch 4/25 (steps=4) | loss=0.13928 lr=0.000500
22
+ [14:40:08] P2 pong Epoch 5/25 (steps=4) | loss=0.12631 lr=0.000500
23
+ [14:41:04] P2 pong Epoch 6/25 (steps=4) | loss=0.11644 lr=0.000500
24
+ [14:43:21] P2 pong Epoch 7/25 (steps=8) | loss=0.18012 lr=0.000500
25
+ [14:45:38] P2 pong Epoch 8/25 (steps=8) | loss=0.17484 lr=0.000500
26
+ [14:47:57] P2 pong Epoch 9/25 (steps=8) | loss=0.16717 lr=0.000500
27
+ [14:50:15] P2 pong Epoch 10/25 (steps=8) | loss=0.15650 lr=0.000500
28
+ [14:52:31] P2 pong Epoch 11/25 (steps=8) | loss=0.14624 lr=0.000500
29
+ [14:54:46] P2 pong Epoch 12/25 (steps=8) | loss=0.13932 lr=0.000500
30
+ [14:57:01] P2 pong Epoch 13/25 (steps=8) | loss=0.12899 lr=0.000493
31
+ [14:59:17] P2 pong Epoch 14/25 (steps=8) | loss=0.11960 lr=0.000471
32
+ [15:01:35] P2 pong Epoch 15/25 (steps=8) | loss=0.10872 lr=0.000437
33
+ [15:03:52] P2 pong Epoch 16/25 (steps=8) | loss=0.09965 lr=0.000392
34
+ [15:06:07] P2 pong Epoch 17/25 (steps=8) | loss=0.08785 lr=0.000339
35
+ [15:08:27] P2 pong Epoch 18/25 (steps=8) | loss=0.07890 lr=0.000280
36
+ [15:10:44] P2 pong Epoch 19/25 (steps=8) | loss=0.06718 lr=0.000220
37
+ [15:13:01] P2 pong Epoch 20/25 (steps=8) | loss=0.06123 lr=0.000161
38
+ [15:15:20] P2 pong Epoch 21/25 (steps=8) | loss=0.05374 lr=0.000108
39
+ [15:17:40] P2 pong Epoch 22/25 (steps=8) | loss=0.04863 lr=0.000063
40
+ [15:19:57] P2 pong Epoch 23/25 (steps=8) | loss=0.04435 lr=0.000029
41
+ [15:22:13] P2 pong Epoch 24/25 (steps=8) | loss=0.04174 lr=0.000010
42
+ [15:24:31] P2 pong Epoch 25/25 (steps=8) | loss=0.04022 lr=0.000010
43
+ [15:24:31] pong training complete.
44
+ [15:24:31]
45
+ === Training sonic ([40, 80, 160]) ===
46
+ [15:24:31] 3,150,686 parameters
47
+ [15:24:31] Phase 1: 10 epochs single-step
48
+ [15:24:34] 32256 sequences
49
+ [15:25:03] P1 sonic Epoch 1/10 | loss=0.08400
50
+ [15:25:34] P1 sonic Epoch 2/10 | loss=0.06966
51
+ [15:26:03] P1 sonic Epoch 3/10 | loss=0.06589
52
+ [15:26:34] P1 sonic Epoch 4/10 | loss=0.06327
53
+ [15:27:03] P1 sonic Epoch 5/10 | loss=0.06111
54
+ [15:27:33] P1 sonic Epoch 6/10 | loss=0.05881
55
+ [15:28:03] P1 sonic Epoch 7/10 | loss=0.05682
56
+ [15:28:33] P1 sonic Epoch 8/10 | loss=0.05514
57
+ [15:29:02] P1 sonic Epoch 9/10 | loss=0.05358
58
+ [15:29:32] P1 sonic Epoch 10/10 | loss=0.05256
59
+ [15:29:32] Phase 2: 25 epochs graduated AR
60
+ [15:30:57] P2 sonic Epoch 1/25 (steps=2) | loss=0.07446 lr=0.000500
61
+ [15:32:15] P2 sonic Epoch 2/25 (steps=2) | loss=0.07291 lr=0.000500
62
+ [15:33:41] P2 sonic Epoch 3/25 (steps=2) | loss=0.07128 lr=0.000500
63
+ [15:37:15] P2 sonic Epoch 4/25 (steps=4) | loss=0.10220 lr=0.000500
64
+ [15:40:50] P2 sonic Epoch 5/25 (steps=4) | loss=0.09976 lr=0.000500
65
+ [15:44:24] P2 sonic Epoch 6/25 (steps=4) | loss=0.09779 lr=0.000500
66
+ [15:53:05] P2 sonic Epoch 7/25 (steps=8) | loss=0.14037 lr=0.000500
67
+ [16:01:41] P2 sonic Epoch 8/25 (steps=8) | loss=0.13753 lr=0.000500
68
+ [16:10:26] P2 sonic Epoch 9/25 (steps=8) | loss=0.13476 lr=0.000500
69
+ [16:19:08] P2 sonic Epoch 10/25 (steps=8) | loss=0.13232 lr=0.000500
70
+ [16:28:05] P2 sonic Epoch 11/25 (steps=8) | loss=0.13010 lr=0.000500
71
+ [16:37:18] P2 sonic Epoch 12/25 (steps=8) | loss=0.12790 lr=0.000500
72
+ [16:46:19] P2 sonic Epoch 13/25 (steps=8) | loss=0.12592 lr=0.000493
73
+ [16:55:21] P2 sonic Epoch 14/25 (steps=8) | loss=0.12408 lr=0.000471
74
+ [17:04:34] P2 sonic Epoch 15/25 (steps=8) | loss=0.12210 lr=0.000437
75
+ [17:13:54] P2 sonic Epoch 16/25 (steps=8) | loss=0.11900 lr=0.000392
76
+ [17:23:04] P2 sonic Epoch 17/25 (steps=8) | loss=0.11596 lr=0.000339
77
+ [17:32:08] P2 sonic Epoch 18/25 (steps=8) | loss=0.11287 lr=0.000280
78
+ [17:41:13] P2 sonic Epoch 19/25 (steps=8) | loss=0.10939 lr=0.000220
79
+ [17:50:18] P2 sonic Epoch 20/25 (steps=8) | loss=0.10548 lr=0.000161
80
+ [17:59:23] P2 sonic Epoch 21/25 (steps=8) | loss=0.10183 lr=0.000108
81
+ [18:08:26] P2 sonic Epoch 22/25 (steps=8) | loss=0.09841 lr=0.000063
82
+ [18:17:35] P2 sonic Epoch 23/25 (steps=8) | loss=0.09526 lr=0.000029
83
+ [18:26:41] P2 sonic Epoch 24/25 (steps=8) | loss=0.09337 lr=0.000010
84
+ [18:35:42] P2 sonic Epoch 25/25 (steps=8) | loss=0.09193 lr=0.000010
85
+ [18:35:42] sonic training complete.
86
+ [18:35:42]
87
+ === Training pole_position ([24, 48, 96]) ===
88
+ [18:35:42] 1,137,006 parameters
89
+ [18:35:42] Phase 1: 10 epochs single-step
90
+ [18:35:42] 4284 sequences
91
+ [18:35:46] P1 pole_position Epoch 1/10 | loss=0.05831
92
+ [18:35:50] P1 pole_position Epoch 2/10 | loss=0.03691
93
+ [18:35:54] P1 pole_position Epoch 3/10 | loss=0.03064
94
+ [18:35:57] P1 pole_position Epoch 4/10 | loss=0.02707
95
+ [18:36:00] P1 pole_position Epoch 5/10 | loss=0.02428
96
+ [18:36:04] P1 pole_position Epoch 6/10 | loss=0.02271
97
+ [18:36:07] P1 pole_position Epoch 7/10 | loss=0.02128
98
+ [18:36:11] P1 pole_position Epoch 8/10 | loss=0.02013
99
+ [18:36:15] P1 pole_position Epoch 9/10 | loss=0.01936
100
+ [18:36:19] P1 pole_position Epoch 10/10 | loss=0.01879
101
+ [18:36:19] Phase 2: 25 epochs graduated AR
102
+ [18:36:31] P2 pole_position Epoch 1/25 (steps=2) | loss=0.02742 lr=0.000500
103
+ [18:36:42] P2 pole_position Epoch 2/25 (steps=2) | loss=0.02621 lr=0.000500
104
+ [18:36:54] P2 pole_position Epoch 3/25 (steps=2) | loss=0.02502 lr=0.000500
105
+ [18:37:22] P2 pole_position Epoch 4/25 (steps=4) | loss=0.03779 lr=0.000500
106
+ [18:37:51] P2 pole_position Epoch 5/25 (steps=4) | loss=0.03543 lr=0.000500
107
+ [18:38:19] P2 pole_position Epoch 6/25 (steps=4) | loss=0.03421 lr=0.000500
108
+ [18:39:31] P2 pole_position Epoch 7/25 (steps=8) | loss=0.05263 lr=0.000500
109
+ [18:40:42] P2 pole_position Epoch 8/25 (steps=8) | loss=0.05159 lr=0.000500
110
+ [18:41:53] P2 pole_position Epoch 9/25 (steps=8) | loss=0.04987 lr=0.000500
111
+ [18:43:05] P2 pole_position Epoch 10/25 (steps=8) | loss=0.04848 lr=0.000500
112
+ [18:44:17] P2 pole_position Epoch 11/25 (steps=8) | loss=0.04744 lr=0.000500
113
+ [18:45:30] P2 pole_position Epoch 12/25 (steps=8) | loss=0.04603 lr=0.000500
114
+ [18:46:42] P2 pole_position Epoch 13/25 (steps=8) | loss=0.04495 lr=0.000493
115
+ [18:47:54] P2 pole_position Epoch 14/25 (steps=8) | loss=0.04383 lr=0.000471
116
+ [18:49:05] P2 pole_position Epoch 15/25 (steps=8) | loss=0.04233 lr=0.000437
117
+ [18:50:18] P2 pole_position Epoch 16/25 (steps=8) | loss=0.04089 lr=0.000392
118
+ [18:51:30] P2 pole_position Epoch 17/25 (steps=8) | loss=0.03911 lr=0.000339
119
+ [18:52:43] P2 pole_position Epoch 18/25 (steps=8) | loss=0.03667 lr=0.000280
120
+ [18:53:55] P2 pole_position Epoch 19/25 (steps=8) | loss=0.03494 lr=0.000220
121
+ [18:55:06] P2 pole_position Epoch 20/25 (steps=8) | loss=0.03271 lr=0.000161
122
+ [18:56:18] P2 pole_position Epoch 21/25 (steps=8) | loss=0.03049 lr=0.000108
123
+ [18:57:31] P2 pole_position Epoch 22/25 (steps=8) | loss=0.02831 lr=0.000063
124
+ [18:58:44] P2 pole_position Epoch 23/25 (steps=8) | loss=0.02653 lr=0.000029
125
+ [18:59:58] P2 pole_position Epoch 24/25 (steps=8) | loss=0.02527 lr=0.000010
126
+ [19:01:11] P2 pole_position Epoch 25/25 (steps=8) | loss=0.02460 lr=0.000010
127
+ [19:01:11] pole_position training complete.
128
+ [19:01:11] Evaluating...
129
+ [19:02:25] Val SSIM=0.8626 | {'pong': 0.862, 'sonic': 0.7822, 'pole_position': 0.9435}
130
+ [19:02:25] Experiment dir: 12.7 MB
131
+ [19:02:25] Training complete.