ojaffe commited on
Commit
2e7cf8e
·
verified ·
1 Parent(s): c50675d

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_pong.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d6c8b9235347bea94e7e5f5f0f225d4c1dbd13a749d5e28920c75c91902ecb11
3
- size 2435368
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1a440d1801503eb7e00e8a6ce30b8f43058816440d98506b3e2c8629ca2eeff
3
+ size 2436712
model_sonic.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
3
- size 6180566
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9035098568ea4789c5dda58d685af07b4b5a0cdf300848f79ed6d96ad901da34
3
+ size 6182614
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """Hybrid v3: AR for Pong/Sonic, direct 8-frame for PP."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -11,6 +11,12 @@ CONTEXT_FRAMES = 8
11
  PRED_FRAMES = 8
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
 
 
 
 
 
14
 
15
  def detect_game(context_frames: np.ndarray) -> str:
16
  first_8 = context_frames[:CONTEXT_FRAMES]
@@ -26,58 +32,56 @@ def detect_game(context_frames: np.ndarray) -> str:
26
  return "sonic"
27
 
28
 
29
- class HybridModels:
30
- def __init__(self):
31
- self.models = {}
32
- self.pp_cache = None
33
- self.pp_cache_step = 0
34
 
35
- def reset_pp_cache(self):
36
- self.pp_cache = None
37
- self.pp_cache_step = 0
38
 
39
 
40
  def load_model(model_dir: str):
41
- hybrid = HybridModels()
42
-
43
- # Pong: standard AR model (3 outputs)
44
- pong = UNet(in_channels=24, out_channels=3,
45
- enc_channels=(32, 64, 128), bottleneck_channels=128,
46
- upsample_mode="bilinear").to(DEVICE)
47
- sd = torch.load(os.path.join(model_dir, "model_pong.pt"),
48
- map_location=DEVICE, weights_only=True)
49
- pong.load_state_dict({k: v.float() for k, v in sd.items()})
50
- pong.eval()
51
- hybrid.models["pong"] = pong
52
-
53
- # Sonic: standard AR model (3 outputs)
54
- sonic = UNet(in_channels=24, out_channels=3,
55
- enc_channels=(48, 96, 192), bottleneck_channels=256,
56
- upsample_mode="bilinear").to(DEVICE)
57
- sd = torch.load(os.path.join(model_dir, "model_sonic.pt"),
58
- map_location=DEVICE, weights_only=True)
59
- sonic.load_state_dict({k: v.float() for k, v in sd.items()})
60
- sonic.eval()
61
- hybrid.models["sonic"] = sonic
62
-
63
- # PP: direct 8-frame model (24 outputs)
64
- pp = UNet(in_channels=24, out_channels=24,
65
- enc_channels=(32, 64, 128), bottleneck_channels=192,
66
- upsample_mode="bilinear").to(DEVICE)
67
- sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
68
- map_location=DEVICE, weights_only=True)
69
- pp.load_state_dict({k: v.float() for k, v in sd.items()})
70
- pp.eval()
71
- hybrid.models["pole_position"] = pp
72
-
73
- return hybrid
74
-
75
-
76
- def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
77
- game = detect_game(context_frames)
78
- model = hybrid.models[game]
79
  n = len(context_frames)
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if n < CONTEXT_FRAMES:
82
  padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
83
  frames = np.concatenate([padding, context_frames], axis=0)
@@ -91,57 +95,32 @@ def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
91
  last_frame = frames_norm[-1]
92
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
93
 
94
- if game == "pole_position":
95
- # Direct 8-frame: use cache
96
- if hybrid.pp_cache is not None and n > CONTEXT_FRAMES and hybrid.pp_cache_step < PRED_FRAMES:
97
- result = hybrid.pp_cache[hybrid.pp_cache_step]
98
- hybrid.pp_cache_step += 1
99
- if hybrid.pp_cache_step >= PRED_FRAMES:
100
- hybrid.reset_pp_cache()
101
- return result
102
-
103
- # First call: predict all 8 frames
104
- hybrid.reset_pp_cache()
105
- with torch.no_grad():
106
- context_tensor = torch.from_numpy(context).to(DEVICE)
107
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
108
- output = model(context_tensor) # (1, 24, 64, 64)
109
- residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
110
- last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
111
- predicted = torch.clamp(last_expanded + residuals, 0, 1)
112
-
113
- predicted_np = predicted[0].cpu().numpy() # (8, 3, 64, 64)
114
- hybrid.pp_cache = []
115
- for i in range(PRED_FRAMES):
116
- frame = np.transpose(predicted_np[i], (1, 2, 0))
117
- frame = (frame * 255).clip(0, 255).astype(np.uint8)
118
- hybrid.pp_cache.append(frame)
119
-
120
- result = hybrid.pp_cache[hybrid.pp_cache_step]
121
- hybrid.pp_cache_step += 1
122
- return result
123
-
124
- else:
125
- # AR prediction for Pong and Sonic
126
- with torch.no_grad():
127
- context_tensor = torch.from_numpy(context).to(DEVICE)
128
- last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
129
-
130
- residual_orig = model(context_tensor)
131
- predicted_orig = torch.clamp(last_tensor + residual_orig, 0, 1)
132
-
133
- if game == "pong":
134
- predicted = predicted_orig
135
- else:
136
- # TTA for Sonic
137
- context_flipped = torch.flip(context_tensor, dims=[3])
138
- last_flipped = torch.flip(last_tensor, dims=[3])
139
- residual_flipped = model(context_flipped)
140
- predicted_flipped = torch.clamp(last_flipped + residual_flipped, 0, 1)
141
- predicted_flipped = torch.flip(predicted_flipped, dims=[3])
142
- predicted = (predicted_orig + predicted_flipped) / 2.0
143
-
144
- predicted_np = predicted[0].cpu().numpy()
145
- predicted_np = np.transpose(predicted_np, (1, 2, 0))
146
- predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
147
- return predicted_np
 
1
+ """Direct 8-frame prediction for all games with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
 
11
  PRED_FRAMES = 8
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
# Per-game UNet hyperparameters: encoder channel widths and bottleneck size.
# Keys must match the checkpoint filenames on disk (model_<game>.pt).
GAME_CONFIGS = {
    "pong": {
        "enc_channels": (32, 64, 128),
        "bottleneck": 128,
    },
    "sonic": {
        "enc_channels": (48, 96, 192),
        "bottleneck": 256,
    },
    "pole_position": {
        "enc_channels": (32, 64, 128),
        "bottleneck": 192,
    },
}
19
+
20
 
21
  def detect_game(context_frames: np.ndarray) -> str:
22
  first_8 = context_frames[:CONTEXT_FRAMES]
 
32
  return "sonic"
33
 
34
 
35
class ModelCache:
    """Bundle the per-game models with an 8-frame prediction cache.

    ``cached_predictions`` is either ``None`` (no window cached) or a list
    of already-decoded frames; ``cache_step`` indexes the next frame to
    serve from that list.
    """

    def __init__(self, models):
        # models: mapping of game name -> loaded network.
        self.models = models
        self.reset_cache()

    def reset_cache(self):
        """Drop any cached prediction window and rewind the step counter."""
        self.cached_predictions = None
        self.cache_step = 0
44
 
45
 
46
def load_model(model_dir: str):
    """Build and load one direct-8-frame UNet per game.

    Expects a checkpoint named ``model_<game>.pt`` in *model_dir* for every
    game listed in GAME_CONFIGS. Returns a ModelCache wrapping the loaded,
    eval-mode models.
    """
    def _build(game, cfg):
        # All games share the same I/O contract: 8 context frames in
        # (24 channels) -> 8 predicted frames out (24 channels); only the
        # encoder/bottleneck widths differ per game.
        net = UNet(in_channels=24, out_channels=24,
                   enc_channels=cfg["enc_channels"],
                   bottleneck_channels=cfg["bottleneck"],
                   upsample_mode="bilinear").to(DEVICE)
        weights = torch.load(os.path.join(model_dir, f"model_{game}.pt"),
                             map_location=DEVICE, weights_only=True)
        # Cast every tensor to float32 regardless of the dtype it was
        # checkpointed with.
        net.load_state_dict({name: t.float() for name, t in weights.items()})
        net.eval()
        return net

    return ModelCache({game: _build(game, cfg)
                       for game, cfg in GAME_CONFIGS.items()})
60
+
61
+
62
+ def _predict_8frames(model, context_tensor, last_tensor):
63
+ output = model(context_tensor) # (1, 24, 64, 64)
64
+ residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
65
+ last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
66
+ return torch.clamp(last_expanded + residuals, 0, 1) # (1, 8, 3, 64, 64)
67
+
68
+
69
+ def predict_next_frame(cache, context_frames: np.ndarray) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  n = len(context_frames)
71
 
72
+ # If cache exists and context grew (AR rollout), return next cached frame
73
+ if cache.cached_predictions is not None and n > CONTEXT_FRAMES and cache.cache_step < PRED_FRAMES:
74
+ result = cache.cached_predictions[cache.cache_step]
75
+ cache.cache_step += 1
76
+ if cache.cache_step >= PRED_FRAMES:
77
+ cache.reset_cache()
78
+ return result
79
+
80
+ # New window: predict all 8 frames
81
+ cache.reset_cache()
82
+ game = detect_game(context_frames)
83
+ model = cache.models[game]
84
+
85
  if n < CONTEXT_FRAMES:
86
  padding = np.stack([context_frames[0]] * (CONTEXT_FRAMES - n), axis=0)
87
  frames = np.concatenate([padding, context_frames], axis=0)
 
95
  last_frame = frames_norm[-1]
96
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
97
 
98
+ with torch.no_grad():
99
+ context_tensor = torch.from_numpy(context).to(DEVICE)
100
+ last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
101
+
102
+ predicted_orig = _predict_8frames(model, context_tensor, last_tensor)
103
+
104
+ if game == "pong":
105
+ # Pong: no TTA (asymmetric)
106
+ predicted = predicted_orig
107
+ else:
108
+ # TTA: horizontal flip (dim=3 is width for (B, T, C, H, W) reshaped from (B, 24, H, W))
109
+ # But we work on (1, 24, H, W) context - flip along dim 3 (width)
110
+ context_flipped = torch.flip(context_tensor, dims=[3])
111
+ last_flipped = torch.flip(last_tensor, dims=[3])
112
+ predicted_flipped = _predict_8frames(model, context_flipped, last_flipped)
113
+ # Flip back: predicted_flipped is (1, 8, 3, H, W), flip width dim=4
114
+ predicted_flipped = torch.flip(predicted_flipped, dims=[4])
115
+ predicted = (predicted_orig + predicted_flipped) / 2.0
116
+
117
+ predicted_np = predicted[0].cpu().numpy() # (8, 3, 64, 64)
118
+ cache.cached_predictions = []
119
+ for i in range(PRED_FRAMES):
120
+ frame = np.transpose(predicted_np[i], (1, 2, 0))
121
+ frame = (frame * 255).clip(0, 255).astype(np.uint8)
122
+ cache.cached_predictions.append(frame)
123
+
124
+ result = cache.cached_predictions[cache.cache_step]
125
+ cache.cache_step += 1
126
+ return result