ojaffe commited on
Commit
2fd1c51
·
verified ·
1 Parent(s): 3cac80b

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_pole_position.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8e0affcef8e533a29037751e27948a3eb0f2fda2792ce2b3dfc876cadb09e281
3
  size 2971526
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27d26875071b536cc75cac27a0840b50cd6c9a8e1956c94f1cd08feacc49621f
3
  size 2971526
model_pong.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ab8070ddcde00333d7b52c89a0da9a61eece1e67c46163cd011ce4cd3c422f0c
3
- size 2436712
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6c8b9235347bea94e7e5f5f0f225d4c1dbd13a749d5e28920c75c91902ecb11
3
+ size 2435368
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """Direct 8-frame prediction for all games with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -11,12 +11,6 @@ CONTEXT_FRAMES = 8
11
  PRED_FRAMES = 8
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- GAME_CONFIGS = {
15
- "pong": {"enc_channels": (32, 64, 128), "bottleneck": 128},
16
- "sonic": {"enc_channels": (48, 96, 192), "bottleneck": 256},
17
- "pole_position": {"enc_channels": (32, 64, 128), "bottleneck": 192},
18
- }
19
-
20
 
21
  def detect_game(context_frames: np.ndarray) -> str:
22
  first_8 = context_frames[:CONTEXT_FRAMES]
@@ -32,55 +26,64 @@ def detect_game(context_frames: np.ndarray) -> str:
32
  return "sonic"
33
 
34
 
35
- class ModelCache:
36
- def __init__(self, models):
37
- self.models = models
38
- self.cached_predictions = None
39
  self.cache_step = 0
40
 
41
  def reset_cache(self):
42
- self.cached_predictions = None
43
  self.cache_step = 0
44
 
45
 
46
  def load_model(model_dir: str):
47
- models = {}
48
- for game, cfg in GAME_CONFIGS.items():
49
- model = UNet(in_channels=24, out_channels=24,
50
- enc_channels=cfg["enc_channels"],
51
- bottleneck_channels=cfg["bottleneck"],
52
- upsample_mode="bilinear").to(DEVICE)
53
- state_dict = torch.load(os.path.join(model_dir, f"model_{game}.pt"),
54
- map_location=DEVICE, weights_only=True)
55
- state_dict = {k: v.float() for k, v in state_dict.items()}
56
- model.load_state_dict(state_dict)
57
- model.eval()
58
- models[game] = model
59
- return ModelCache(models)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  def _predict_8frames(model, context_tensor, last_tensor):
63
  output = model(context_tensor) # (1, 24, 64, 64)
64
  residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
65
  last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
66
- return torch.clamp(last_expanded + residuals, 0, 1) # (1, 8, 3, 64, 64)
67
-
68
-
69
- def predict_next_frame(cache, context_frames: np.ndarray) -> np.ndarray:
70
- n = len(context_frames)
71
 
72
- # If cache exists and context grew (AR rollout), return next cached frame
73
- if cache.cached_predictions is not None and n > CONTEXT_FRAMES and cache.cache_step < PRED_FRAMES:
74
- result = cache.cached_predictions[cache.cache_step]
75
- cache.cache_step += 1
76
- if cache.cache_step >= PRED_FRAMES:
77
- cache.reset_cache()
78
- return result
79
 
80
- # New window: predict all 8 frames
81
- cache.reset_cache()
82
  game = detect_game(context_frames)
83
- model = cache.models[game]
 
84
 
85
  if n < CONTEXT_FRAMES:
86
  padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
@@ -95,32 +98,50 @@ def predict_next_frame(cache, context_frames: np.ndarray) -> np.ndarray:
95
  last_frame = frames_norm[-1]
96
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
97
 
98
- with torch.no_grad():
99
- context_tensor = torch.from_numpy(context).to(DEVICE)
100
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
 
 
 
 
101
 
102
- predicted_orig = _predict_8frames(model, context_tensor, last_tensor)
 
 
 
103
 
104
- if game == "pong":
105
- # Pong: no TTA (asymmetric)
106
- predicted = predicted_orig
107
- else:
108
- # TTA: horizontal flip (dim=3 is width for (B, T, C, H, W) reshaped from (B, 24, H, W))
109
- # But we work on (1, 24, H, W) context - flip along dim 3 (width)
 
 
 
 
 
 
 
 
 
 
 
 
110
  context_flipped = torch.flip(context_tensor, dims=[3])
111
  last_flipped = torch.flip(last_tensor, dims=[3])
112
  predicted_flipped = _predict_8frames(model, context_flipped, last_flipped)
113
- # Flip back: predicted_flipped is (1, 8, 3, H, W), flip width dim=4
114
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
115
  predicted = (predicted_orig + predicted_flipped) / 2.0
116
 
117
- predicted_np = predicted[0].cpu().numpy() # (8, 3, 64, 64)
118
- cache.cached_predictions = []
119
- for i in range(PRED_FRAMES):
120
- frame = np.transpose(predicted_np[i], (1, 2, 0))
121
- frame = (frame * 255).clip(0, 255).astype(np.uint8)
122
- cache.cached_predictions.append(frame)
123
 
124
- result = cache.cached_predictions[cache.cache_step]
125
- cache.cache_step += 1
126
- return result
 
1
+ """Hybrid v5: Best per-game models. AR for Pong, direct 8-frame for Sonic/PP with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
 
11
  PRED_FRAMES = 8
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
 
 
 
 
 
14
 
15
  def detect_game(context_frames: np.ndarray) -> str:
16
  first_8 = context_frames[:CONTEXT_FRAMES]
 
26
  return "sonic"
27
 
28
 
29
class HybridModels:
    """Container for one model per game plus a rolling cache of direct 8-frame predictions."""

    def __init__(self):
        # game name -> loaded model (populated by load_model)
        self.models = {}
        self.reset_cache()

    def reset_cache(self):
        """Drop any cached direct predictions and rewind the read pointer."""
        self.direct_cache = None
        self.cache_step = 0
38
 
39
 
40
def load_model(model_dir: str):
    """Load the per-game checkpoints from *model_dir* into a HybridModels container.

    Pong uses an autoregressive single-frame model (3 output channels, no TTA);
    Sonic and Pole Position use direct 8-frame models (24 output channels).

    Args:
        model_dir: Directory containing ``model_<game>.pt`` state dicts.

    Returns:
        HybridModels with ``models`` keyed by game name, each in eval mode on DEVICE.
    """
    # Per-game architecture: (out_channels, enc_channels, bottleneck_channels).
    # NOTE(review): these must mirror the training configs of each checkpoint.
    configs = {
        "pong": (3, (32, 64, 128), 128),            # AR model from pergame-models
        "sonic": (24, (48, 96, 192), 256),          # direct 8-frame from direct-improved
        "pole_position": (24, (32, 64, 128), 192),  # direct 8-frame from direct-8frame
    }

    hybrid = HybridModels()
    for game, (out_ch, enc, bottleneck) in configs.items():
        model = UNet(in_channels=24, out_channels=out_ch,
                     enc_channels=enc, bottleneck_channels=bottleneck,
                     upsample_mode="bilinear").to(DEVICE)
        sd = torch.load(os.path.join(model_dir, f"model_{game}.pt"),
                        map_location=DEVICE, weights_only=True)
        # Checkpoints may be stored in reduced precision; cast to float32 before loading.
        model.load_state_dict({k: v.float() for k, v in sd.items()})
        model.eval()
        hybrid.models[game] = model
    return hybrid
74
 
75
 
76
def _predict_8frames(model, context_tensor, last_tensor):
    """Predict 8 future frames as residuals added onto the last context frame.

    Args:
        model: Network mapping the stacked context to (1, 24, 64, 64) residuals.
        context_tensor: Stacked context frames fed straight to the model.
        last_tensor: Last observed frame, (1, 3, 64, 64), values in [0, 1].

    Returns:
        Tensor of shape (1, 8, 3, 64, 64) clamped to [0, 1].
    """
    raw = model(context_tensor)  # (1, 24, 64, 64): 8 RGB residual frames stacked on channels
    residuals = raw.reshape(1, PRED_FRAMES, 3, 64, 64)
    base = last_tensor.unsqueeze(1).expand_as(residuals)
    return (base + residuals).clamp(0, 1)
 
 
 
 
81
 
 
 
 
 
 
 
 
82
 
83
+ def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
 
84
  game = detect_game(context_frames)
85
+ model = hybrid.models[game]
86
+ n = len(context_frames)
87
 
88
  if n < CONTEXT_FRAMES:
89
  padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
 
98
  last_frame = frames_norm[-1]
99
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
100
 
101
+ if game == "pong":
102
+ # AR prediction for Pong (no TTA, no caching)
103
+ with torch.no_grad():
104
+ context_tensor = torch.from_numpy(context).to(DEVICE)
105
+ last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
106
+ residual = model(context_tensor)
107
+ predicted = torch.clamp(last_tensor + residual, 0, 1)
108
 
109
+ predicted_np = predicted[0].cpu().numpy()
110
+ predicted_np = np.transpose(predicted_np, (1, 2, 0))
111
+ predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
112
+ return predicted_np
113
 
114
+ else:
115
+ # Direct 8-frame for Sonic and PP with caching
116
+ if hybrid.direct_cache is not None and n > CONTEXT_FRAMES and hybrid.cache_step < PRED_FRAMES:
117
+ result = hybrid.direct_cache[hybrid.cache_step]
118
+ hybrid.cache_step += 1
119
+ if hybrid.cache_step >= PRED_FRAMES:
120
+ hybrid.reset_cache()
121
+ return result
122
+
123
+ # New window: predict all 8 frames with TTA
124
+ hybrid.reset_cache()
125
+ with torch.no_grad():
126
+ context_tensor = torch.from_numpy(context).to(DEVICE)
127
+ last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
128
+
129
+ predicted_orig = _predict_8frames(model, context_tensor, last_tensor)
130
+
131
+ # TTA: horizontal flip
132
  context_flipped = torch.flip(context_tensor, dims=[3])
133
  last_flipped = torch.flip(last_tensor, dims=[3])
134
  predicted_flipped = _predict_8frames(model, context_flipped, last_flipped)
 
135
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
136
  predicted = (predicted_orig + predicted_flipped) / 2.0
137
 
138
+ predicted_np = predicted[0].cpu().numpy() # (8, 3, 64, 64)
139
+ hybrid.direct_cache = []
140
+ for i in range(PRED_FRAMES):
141
+ frame = np.transpose(predicted_np[i], (1, 2, 0))
142
+ frame = (frame * 255).clip(0, 255).astype(np.uint8)
143
+ hybrid.direct_cache.append(frame)
144
 
145
+ result = hybrid.direct_cache[hybrid.cache_step]
146
+ hybrid.cache_step += 1
147
+ return result