ojaffe committed on
Commit
f2eecc0
·
verified ·
1 Parent(s): 2fd1c51

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_sonic_ar.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
3
+ size 6180566
model_sonic_direct.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7e17327a6f03cb72a35bd3c48d481b4eebea5db6572ed2b3fa290b330bca304
3
+ size 6182614
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """Hybrid v5: Best per-game models. AR for Pong, direct 8-frame for Sonic/PP with TTA."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -26,9 +26,11 @@ def detect_game(context_frames: np.ndarray) -> str:
26
  return "sonic"
27
 
28
 
29
- class HybridModels:
30
  def __init__(self):
31
  self.models = {}
 
 
32
  self.direct_cache = None
33
  self.cache_step = 0
34
 
@@ -38,9 +40,9 @@ class HybridModels:
38
 
39
 
40
  def load_model(model_dir: str):
41
- hybrid = HybridModels()
42
 
43
- # Pong: AR model (3 outputs) from pergame-models
44
  pong = UNet(in_channels=24, out_channels=3,
45
  enc_channels=(32, 64, 128), bottleneck_channels=128,
46
  upsample_mode="bilinear").to(DEVICE)
@@ -48,19 +50,29 @@ def load_model(model_dir: str):
48
  map_location=DEVICE, weights_only=True)
49
  pong.load_state_dict({k: v.float() for k, v in sd.items()})
50
  pong.eval()
51
- hybrid.models["pong"] = pong
52
 
53
- # Sonic: direct 8-frame model (24 outputs) from direct-improved
54
- sonic = UNet(in_channels=24, out_channels=24,
55
- enc_channels=(48, 96, 192), bottleneck_channels=256,
56
- upsample_mode="bilinear").to(DEVICE)
57
- sd = torch.load(os.path.join(model_dir, "model_sonic.pt"),
58
  map_location=DEVICE, weights_only=True)
59
- sonic.load_state_dict({k: v.float() for k, v in sd.items()})
60
- sonic.eval()
61
- hybrid.models["sonic"] = sonic
 
 
 
 
 
 
 
 
 
 
62
 
63
- # PP: direct 8-frame model (24 outputs) from direct-8frame
64
  pp = UNet(in_channels=24, out_channels=24,
65
  enc_channels=(32, 64, 128), bottleneck_channels=192,
66
  upsample_mode="bilinear").to(DEVICE)
@@ -68,21 +80,25 @@ def load_model(model_dir: str):
68
  map_location=DEVICE, weights_only=True)
69
  pp.load_state_dict({k: v.float() for k, v in sd.items()})
70
  pp.eval()
71
- hybrid.models["pole_position"] = pp
72
 
73
- return hybrid
74
 
75
 
76
- def _predict_8frames(model, context_tensor, last_tensor):
77
  output = model(context_tensor) # (1, 24, 64, 64)
78
  residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
79
  last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
80
  return torch.clamp(last_expanded + residuals, 0, 1)
81
 
82
 
83
- def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
 
 
 
 
 
84
  game = detect_game(context_frames)
85
- model = hybrid.models[game]
86
  n = len(context_frames)
87
 
88
  if n < CONTEXT_FRAMES:
@@ -99,49 +115,105 @@ def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
99
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
100
 
101
  if game == "pong":
102
- # AR prediction for Pong (no TTA, no caching)
103
  with torch.no_grad():
104
  context_tensor = torch.from_numpy(context).to(DEVICE)
105
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
106
- residual = model(context_tensor)
107
- predicted = torch.clamp(last_tensor + residual, 0, 1)
108
 
109
  predicted_np = predicted[0].cpu().numpy()
110
  predicted_np = np.transpose(predicted_np, (1, 2, 0))
111
  predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
112
  return predicted_np
113
 
114
- else:
115
- # Direct 8-frame for Sonic and PP with caching
116
- if hybrid.direct_cache is not None and n > CONTEXT_FRAMES and hybrid.cache_step < PRED_FRAMES:
117
- result = hybrid.direct_cache[hybrid.cache_step]
118
- hybrid.cache_step += 1
119
- if hybrid.cache_step >= PRED_FRAMES:
120
- hybrid.reset_cache()
121
  return result
122
 
123
- # New window: predict all 8 frames with TTA
124
- hybrid.reset_cache()
125
  with torch.no_grad():
126
  context_tensor = torch.from_numpy(context).to(DEVICE)
127
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
128
 
129
- predicted_orig = _predict_8frames(model, context_tensor, last_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # TTA: horizontal flip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  context_flipped = torch.flip(context_tensor, dims=[3])
133
  last_flipped = torch.flip(last_tensor, dims=[3])
134
- predicted_flipped = _predict_8frames(model, context_flipped, last_flipped)
135
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
136
  predicted = (predicted_orig + predicted_flipped) / 2.0
137
 
138
- predicted_np = predicted[0].cpu().numpy() # (8, 3, 64, 64)
139
- hybrid.direct_cache = []
140
  for i in range(PRED_FRAMES):
141
  frame = np.transpose(predicted_np[i], (1, 2, 0))
142
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
143
- hybrid.direct_cache.append(frame)
144
 
145
- result = hybrid.direct_cache[hybrid.cache_step]
146
- hybrid.cache_step += 1
147
  return result
 
1
+ """Ensemble hybrid: AR+direct ensemble for Sonic, AR for Pong, direct for PP."""
2
  import sys
3
  import os
4
  import numpy as np
 
26
  return "sonic"
27
 
28
 
29
+ class EnsembleModels:
30
  def __init__(self):
31
  self.models = {}
32
+ self.sonic_ar = None
33
+ self.sonic_direct = None
34
  self.direct_cache = None
35
  self.cache_step = 0
36
 
 
40
 
41
 
42
  def load_model(model_dir: str):
43
+ ens = EnsembleModels()
44
 
45
+ # Pong: AR model (3 outputs)
46
  pong = UNet(in_channels=24, out_channels=3,
47
  enc_channels=(32, 64, 128), bottleneck_channels=128,
48
  upsample_mode="bilinear").to(DEVICE)
 
50
  map_location=DEVICE, weights_only=True)
51
  pong.load_state_dict({k: v.float() for k, v in sd.items()})
52
  pong.eval()
53
+ ens.models["pong"] = pong
54
 
55
+ # Sonic AR model (3 outputs)
56
+ sonic_ar = UNet(in_channels=24, out_channels=3,
57
+ enc_channels=(48, 96, 192), bottleneck_channels=256,
58
+ upsample_mode="bilinear").to(DEVICE)
59
+ sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
60
  map_location=DEVICE, weights_only=True)
61
+ sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
62
+ sonic_ar.eval()
63
+ ens.sonic_ar = sonic_ar
64
+
65
+ # Sonic direct model (24 outputs)
66
+ sonic_direct = UNet(in_channels=24, out_channels=24,
67
+ enc_channels=(48, 96, 192), bottleneck_channels=256,
68
+ upsample_mode="bilinear").to(DEVICE)
69
+ sd = torch.load(os.path.join(model_dir, "model_sonic_direct.pt"),
70
+ map_location=DEVICE, weights_only=True)
71
+ sonic_direct.load_state_dict({k: v.float() for k, v in sd.items()})
72
+ sonic_direct.eval()
73
+ ens.sonic_direct = sonic_direct
74
 
75
+ # PP: direct 8-frame model (24 outputs)
76
  pp = UNet(in_channels=24, out_channels=24,
77
  enc_channels=(32, 64, 128), bottleneck_channels=192,
78
  upsample_mode="bilinear").to(DEVICE)
 
80
  map_location=DEVICE, weights_only=True)
81
  pp.load_state_dict({k: v.float() for k, v in sd.items()})
82
  pp.eval()
83
+ ens.models["pole_position"] = pp
84
 
85
+ return ens
86
 
87
 
88
def _predict_8frames_direct(model, context_tensor, last_tensor):
    """Predict all 8 future frames with one forward pass of a direct model.

    The model emits 24 channels, interpreted as 8 frames x 3 RGB channels of
    residuals relative to the last context frame. Each residual frame is added
    to the last frame and the result is clamped to the valid [0, 1] range.

    Returns a tensor of shape (1, PRED_FRAMES, 3, 64, 64).
    """
    flat_residuals = model(context_tensor)  # (1, 24, 64, 64)
    per_frame = flat_residuals.reshape(1, PRED_FRAMES, 3, 64, 64)
    # Broadcast the single last frame across all predicted steps.
    base = last_tensor.unsqueeze(1).expand_as(per_frame)
    return (base + per_frame).clamp(0, 1)
93
 
94
 
95
+ def _predict_ar_frame(model, context_tensor, last_tensor):
96
+ residual = model(context_tensor) # (1, 3, 64, 64)
97
+ return torch.clamp(last_tensor + residual, 0, 1)
98
+
99
+
100
+ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
101
  game = detect_game(context_frames)
 
102
  n = len(context_frames)
103
 
104
  if n < CONTEXT_FRAMES:
 
115
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
116
 
117
  if game == "pong":
118
+ # AR prediction for Pong
119
  with torch.no_grad():
120
  context_tensor = torch.from_numpy(context).to(DEVICE)
121
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
122
+ predicted = _predict_ar_frame(ens.models["pong"], context_tensor, last_tensor)
 
123
 
124
  predicted_np = predicted[0].cpu().numpy()
125
  predicted_np = np.transpose(predicted_np, (1, 2, 0))
126
  predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
127
  return predicted_np
128
 
129
+ elif game == "sonic":
130
+ # Ensemble: AR + direct for Sonic with caching
131
+ if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
132
+ result = ens.direct_cache[ens.cache_step]
133
+ ens.cache_step += 1
134
+ if ens.cache_step >= PRED_FRAMES:
135
+ ens.reset_cache()
136
  return result
137
 
138
+ ens.reset_cache()
 
139
  with torch.no_grad():
140
  context_tensor = torch.from_numpy(context).to(DEVICE)
141
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
142
 
143
+ # Direct prediction with TTA
144
+ direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
145
+ context_flipped = torch.flip(context_tensor, dims=[3])
146
+ last_flipped = torch.flip(last_tensor, dims=[3])
147
+ direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
148
+ direct_flipped = torch.flip(direct_flipped, dims=[4])
149
+ direct_pred = (direct_orig + direct_flipped) / 2.0 # (1, 8, 3, 64, 64)
150
+
151
+ # AR prediction with TTA for each step
152
+ ar_preds = []
153
+ ctx = context_tensor.clone()
154
+ ctx_flip = context_flipped.clone()
155
+ last_t = last_tensor.clone()
156
+ last_f = last_flipped.clone()
157
+ for step in range(PRED_FRAMES):
158
+ ar_orig = _predict_ar_frame(ens.sonic_ar, ctx, last_t)
159
+ ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip, last_f)
160
+ ar_flip_back = torch.flip(ar_flip, dims=[3])
161
+ ar_frame = (ar_orig + ar_flip_back) / 2.0
162
+ ar_preds.append(ar_frame)
163
+ # Shift context for next AR step
164
+ ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
165
+ ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
166
+ ctx = ctx_frames.reshape(1, -1, 64, 64)
167
+ last_t = ar_orig
168
+ ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
169
+ ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
170
+ ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
171
+ last_f = ar_flip
172
+
173
+ ar_pred = torch.stack(ar_preds, dim=1) # (1, 8, 3, 64, 64)
174
+
175
+ # Ensemble: average AR and direct
176
+ predicted = (ar_pred + direct_pred) / 2.0
177
+
178
+ predicted_np = predicted[0].cpu().numpy()
179
+ ens.direct_cache = []
180
+ for i in range(PRED_FRAMES):
181
+ frame = np.transpose(predicted_np[i], (1, 2, 0))
182
+ frame = (frame * 255).clip(0, 255).astype(np.uint8)
183
+ ens.direct_cache.append(frame)
184
 
185
+ result = ens.direct_cache[ens.cache_step]
186
+ ens.cache_step += 1
187
+ return result
188
+
189
+ else:
190
+ # Direct 8-frame for PP with caching and TTA
191
+ if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
192
+ result = ens.direct_cache[ens.cache_step]
193
+ ens.cache_step += 1
194
+ if ens.cache_step >= PRED_FRAMES:
195
+ ens.reset_cache()
196
+ return result
197
+
198
+ ens.reset_cache()
199
+ with torch.no_grad():
200
+ context_tensor = torch.from_numpy(context).to(DEVICE)
201
+ last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
202
+
203
+ predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
204
  context_flipped = torch.flip(context_tensor, dims=[3])
205
  last_flipped = torch.flip(last_tensor, dims=[3])
206
+ predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
207
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
208
  predicted = (predicted_orig + predicted_flipped) / 2.0
209
 
210
+ predicted_np = predicted[0].cpu().numpy()
211
+ ens.direct_cache = []
212
  for i in range(PRED_FRAMES):
213
  frame = np.transpose(predicted_np[i], (1, 2, 0))
214
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
215
+ ens.direct_cache.append(frame)
216
 
217
+ result = ens.direct_cache[ens.cache_step]
218
+ ens.cache_step += 1
219
  return result