Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- model_sonic_ar.pt +3 -0
- model_sonic_direct.pt +3 -0
- predict.py +111 -39
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
model_sonic_ar.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
|
| 3 |
+
size 6180566
|
model_sonic_direct.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e7e17327a6f03cb72a35bd3c48d481b4eebea5db6572ed2b3fa290b330bca304
|
| 3 |
+
size 6182614
|
predict.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -26,9 +26,11 @@ def detect_game(context_frames: np.ndarray) -> str:
|
|
| 26 |
return "sonic"
|
| 27 |
|
| 28 |
|
| 29 |
-
class
|
| 30 |
def __init__(self):
|
| 31 |
self.models = {}
|
|
|
|
|
|
|
| 32 |
self.direct_cache = None
|
| 33 |
self.cache_step = 0
|
| 34 |
|
|
@@ -38,9 +40,9 @@ class HybridModels:
|
|
| 38 |
|
| 39 |
|
| 40 |
def load_model(model_dir: str):
|
| 41 |
-
|
| 42 |
|
| 43 |
-
# Pong: AR model (3 outputs)
|
| 44 |
pong = UNet(in_channels=24, out_channels=3,
|
| 45 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 46 |
upsample_mode="bilinear").to(DEVICE)
|
|
@@ -48,19 +50,29 @@ def load_model(model_dir: str):
|
|
| 48 |
map_location=DEVICE, weights_only=True)
|
| 49 |
pong.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 50 |
pong.eval()
|
| 51 |
-
|
| 52 |
|
| 53 |
-
# Sonic
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
sd = torch.load(os.path.join(model_dir, "
|
| 58 |
map_location=DEVICE, weights_only=True)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
# PP: direct 8-frame model (24 outputs)
|
| 64 |
pp = UNet(in_channels=24, out_channels=24,
|
| 65 |
enc_channels=(32, 64, 128), bottleneck_channels=192,
|
| 66 |
upsample_mode="bilinear").to(DEVICE)
|
|
@@ -68,21 +80,25 @@ def load_model(model_dir: str):
|
|
| 68 |
map_location=DEVICE, weights_only=True)
|
| 69 |
pp.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 70 |
pp.eval()
|
| 71 |
-
|
| 72 |
|
| 73 |
-
return
|
| 74 |
|
| 75 |
|
| 76 |
-
def
|
| 77 |
output = model(context_tensor) # (1, 24, 64, 64)
|
| 78 |
residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
|
| 79 |
last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
|
| 80 |
return torch.clamp(last_expanded + residuals, 0, 1)
|
| 81 |
|
| 82 |
|
| 83 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
game = detect_game(context_frames)
|
| 85 |
-
model = hybrid.models[game]
|
| 86 |
n = len(context_frames)
|
| 87 |
|
| 88 |
if n < CONTEXT_FRAMES:
|
|
@@ -99,49 +115,105 @@ def predict_next_frame(hybrid, context_frames: np.ndarray) -> np.ndarray:
|
|
| 99 |
last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
|
| 100 |
|
| 101 |
if game == "pong":
|
| 102 |
-
# AR prediction for Pong
|
| 103 |
with torch.no_grad():
|
| 104 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 105 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 106 |
-
|
| 107 |
-
predicted = torch.clamp(last_tensor + residual, 0, 1)
|
| 108 |
|
| 109 |
predicted_np = predicted[0].cpu().numpy()
|
| 110 |
predicted_np = np.transpose(predicted_np, (1, 2, 0))
|
| 111 |
predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
|
| 112 |
return predicted_np
|
| 113 |
|
| 114 |
-
|
| 115 |
-
#
|
| 116 |
-
if
|
| 117 |
-
result =
|
| 118 |
-
|
| 119 |
-
if
|
| 120 |
-
|
| 121 |
return result
|
| 122 |
|
| 123 |
-
|
| 124 |
-
hybrid.reset_cache()
|
| 125 |
with torch.no_grad():
|
| 126 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 127 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 128 |
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 133 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 134 |
-
predicted_flipped =
|
| 135 |
predicted_flipped = torch.flip(predicted_flipped, dims=[4])
|
| 136 |
predicted = (predicted_orig + predicted_flipped) / 2.0
|
| 137 |
|
| 138 |
-
predicted_np = predicted[0].cpu().numpy()
|
| 139 |
-
|
| 140 |
for i in range(PRED_FRAMES):
|
| 141 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 142 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
| 143 |
-
|
| 144 |
|
| 145 |
-
result =
|
| 146 |
-
|
| 147 |
return result
|
|
|
|
| 1 |
+
"""Ensemble hybrid: AR+direct ensemble for Sonic, AR for Pong, direct for PP."""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
|
|
| 26 |
return "sonic"
|
| 27 |
|
| 28 |
|
| 29 |
+
class EnsembleModels:
|
| 30 |
def __init__(self):
|
| 31 |
self.models = {}
|
| 32 |
+
self.sonic_ar = None
|
| 33 |
+
self.sonic_direct = None
|
| 34 |
self.direct_cache = None
|
| 35 |
self.cache_step = 0
|
| 36 |
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def load_model(model_dir: str):
|
| 43 |
+
ens = EnsembleModels()
|
| 44 |
|
| 45 |
+
# Pong: AR model (3 outputs)
|
| 46 |
pong = UNet(in_channels=24, out_channels=3,
|
| 47 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 48 |
upsample_mode="bilinear").to(DEVICE)
|
|
|
|
| 50 |
map_location=DEVICE, weights_only=True)
|
| 51 |
pong.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 52 |
pong.eval()
|
| 53 |
+
ens.models["pong"] = pong
|
| 54 |
|
| 55 |
+
# Sonic AR model (3 outputs)
|
| 56 |
+
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 57 |
+
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 58 |
+
upsample_mode="bilinear").to(DEVICE)
|
| 59 |
+
sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
|
| 60 |
map_location=DEVICE, weights_only=True)
|
| 61 |
+
sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 62 |
+
sonic_ar.eval()
|
| 63 |
+
ens.sonic_ar = sonic_ar
|
| 64 |
+
|
| 65 |
+
# Sonic direct model (24 outputs)
|
| 66 |
+
sonic_direct = UNet(in_channels=24, out_channels=24,
|
| 67 |
+
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 68 |
+
upsample_mode="bilinear").to(DEVICE)
|
| 69 |
+
sd = torch.load(os.path.join(model_dir, "model_sonic_direct.pt"),
|
| 70 |
+
map_location=DEVICE, weights_only=True)
|
| 71 |
+
sonic_direct.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 72 |
+
sonic_direct.eval()
|
| 73 |
+
ens.sonic_direct = sonic_direct
|
| 74 |
|
| 75 |
+
# PP: direct 8-frame model (24 outputs)
|
| 76 |
pp = UNet(in_channels=24, out_channels=24,
|
| 77 |
enc_channels=(32, 64, 128), bottleneck_channels=192,
|
| 78 |
upsample_mode="bilinear").to(DEVICE)
|
|
|
|
| 80 |
map_location=DEVICE, weights_only=True)
|
| 81 |
pp.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 82 |
pp.eval()
|
| 83 |
+
ens.models["pole_position"] = pp
|
| 84 |
|
| 85 |
+
return ens
|
| 86 |
|
| 87 |
|
| 88 |
+
def _predict_8frames_direct(model, context_tensor, last_tensor):
|
| 89 |
output = model(context_tensor) # (1, 24, 64, 64)
|
| 90 |
residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
|
| 91 |
last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
|
| 92 |
return torch.clamp(last_expanded + residuals, 0, 1)
|
| 93 |
|
| 94 |
|
| 95 |
+
def _predict_ar_frame(model, context_tensor, last_tensor):
|
| 96 |
+
residual = model(context_tensor) # (1, 3, 64, 64)
|
| 97 |
+
return torch.clamp(last_tensor + residual, 0, 1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
| 101 |
game = detect_game(context_frames)
|
|
|
|
| 102 |
n = len(context_frames)
|
| 103 |
|
| 104 |
if n < CONTEXT_FRAMES:
|
|
|
|
| 115 |
last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
|
| 116 |
|
| 117 |
if game == "pong":
|
| 118 |
+
# AR prediction for Pong
|
| 119 |
with torch.no_grad():
|
| 120 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 121 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 122 |
+
predicted = _predict_ar_frame(ens.models["pong"], context_tensor, last_tensor)
|
|
|
|
| 123 |
|
| 124 |
predicted_np = predicted[0].cpu().numpy()
|
| 125 |
predicted_np = np.transpose(predicted_np, (1, 2, 0))
|
| 126 |
predicted_np = (predicted_np * 255).clip(0, 255).astype(np.uint8)
|
| 127 |
return predicted_np
|
| 128 |
|
| 129 |
+
elif game == "sonic":
|
| 130 |
+
# Ensemble: AR + direct for Sonic with caching
|
| 131 |
+
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 132 |
+
result = ens.direct_cache[ens.cache_step]
|
| 133 |
+
ens.cache_step += 1
|
| 134 |
+
if ens.cache_step >= PRED_FRAMES:
|
| 135 |
+
ens.reset_cache()
|
| 136 |
return result
|
| 137 |
|
| 138 |
+
ens.reset_cache()
|
|
|
|
| 139 |
with torch.no_grad():
|
| 140 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 141 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 142 |
|
| 143 |
+
# Direct prediction with TTA
|
| 144 |
+
direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
|
| 145 |
+
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 146 |
+
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 147 |
+
direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
|
| 148 |
+
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 149 |
+
direct_pred = (direct_orig + direct_flipped) / 2.0 # (1, 8, 3, 64, 64)
|
| 150 |
+
|
| 151 |
+
# AR prediction with TTA for each step
|
| 152 |
+
ar_preds = []
|
| 153 |
+
ctx = context_tensor.clone()
|
| 154 |
+
ctx_flip = context_flipped.clone()
|
| 155 |
+
last_t = last_tensor.clone()
|
| 156 |
+
last_f = last_flipped.clone()
|
| 157 |
+
for step in range(PRED_FRAMES):
|
| 158 |
+
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx, last_t)
|
| 159 |
+
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip, last_f)
|
| 160 |
+
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 161 |
+
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 162 |
+
ar_preds.append(ar_frame)
|
| 163 |
+
# Shift context for next AR step
|
| 164 |
+
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 165 |
+
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 166 |
+
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 167 |
+
last_t = ar_orig
|
| 168 |
+
ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 169 |
+
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 170 |
+
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 171 |
+
last_f = ar_flip
|
| 172 |
+
|
| 173 |
+
ar_pred = torch.stack(ar_preds, dim=1) # (1, 8, 3, 64, 64)
|
| 174 |
+
|
| 175 |
+
# Ensemble: average AR and direct
|
| 176 |
+
predicted = (ar_pred + direct_pred) / 2.0
|
| 177 |
+
|
| 178 |
+
predicted_np = predicted[0].cpu().numpy()
|
| 179 |
+
ens.direct_cache = []
|
| 180 |
+
for i in range(PRED_FRAMES):
|
| 181 |
+
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 182 |
+
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
| 183 |
+
ens.direct_cache.append(frame)
|
| 184 |
|
| 185 |
+
result = ens.direct_cache[ens.cache_step]
|
| 186 |
+
ens.cache_step += 1
|
| 187 |
+
return result
|
| 188 |
+
|
| 189 |
+
else:
|
| 190 |
+
# Direct 8-frame for PP with caching and TTA
|
| 191 |
+
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 192 |
+
result = ens.direct_cache[ens.cache_step]
|
| 193 |
+
ens.cache_step += 1
|
| 194 |
+
if ens.cache_step >= PRED_FRAMES:
|
| 195 |
+
ens.reset_cache()
|
| 196 |
+
return result
|
| 197 |
+
|
| 198 |
+
ens.reset_cache()
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 201 |
+
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 202 |
+
|
| 203 |
+
predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
|
| 204 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 205 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 206 |
+
predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
|
| 207 |
predicted_flipped = torch.flip(predicted_flipped, dims=[4])
|
| 208 |
predicted = (predicted_orig + predicted_flipped) / 2.0
|
| 209 |
|
| 210 |
+
predicted_np = predicted[0].cpu().numpy()
|
| 211 |
+
ens.direct_cache = []
|
| 212 |
for i in range(PRED_FRAMES):
|
| 213 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 214 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
| 215 |
+
ens.direct_cache.append(frame)
|
| 216 |
|
| 217 |
+
result = ens.direct_cache[ens.cache_step]
|
| 218 |
+
ens.cache_step += 1
|
| 219 |
return result
|