Upload folder using huggingface_hub
Browse files
- __pycache__/predict.cpython-311.pyc +0 -0
- model_sonic_ar.pt +2 -2
- model_sonic_direct.pt +2 -2
- predict.py +25 -20
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
model_sonic_ar.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0853b8b0dad0a55f126be9bfd767d2e55fcc2ea9dcb379a79f6389c997e54816
|
| 3 |
+
size 3129452
|
model_sonic_direct.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c51f9fb740cc1cb8dc93f119252905a47034bb4cab73b30f347e47de20ad3d6d
|
| 3 |
+
size 3131348
|
predict.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -26,6 +26,18 @@ def detect_game(context_frames: np.ndarray) -> str:
|
|
| 26 |
return "sonic"
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
class EnsembleModels:
|
| 30 |
def __init__(self):
|
| 31 |
self.models = {}
|
|
@@ -43,7 +55,7 @@ class EnsembleModels:
|
|
| 43 |
def load_model(model_dir: str):
|
| 44 |
ens = EnsembleModels()
|
| 45 |
|
| 46 |
-
# Pong AR (3 outputs)
|
| 47 |
pong = UNet(in_channels=24, out_channels=3,
|
| 48 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 49 |
upsample_mode="bilinear").to(DEVICE)
|
|
@@ -53,7 +65,7 @@ def load_model(model_dir: str):
|
|
| 53 |
pong.eval()
|
| 54 |
ens.models["pong"] = pong
|
| 55 |
|
| 56 |
-
# Pong direct (24 outputs)
|
| 57 |
pong_direct = UNet(in_channels=24, out_channels=24,
|
| 58 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 59 |
upsample_mode="bilinear").to(DEVICE)
|
|
@@ -63,27 +75,25 @@ def load_model(model_dir: str):
|
|
| 63 |
pong_direct.eval()
|
| 64 |
ens.pong_direct = pong_direct
|
| 65 |
|
| 66 |
-
# Sonic AR (3 outputs)
|
| 67 |
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 68 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 69 |
upsample_mode="bilinear").to(DEVICE)
|
| 70 |
-
sd =
|
| 71 |
-
|
| 72 |
-
sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 73 |
sonic_ar.eval()
|
| 74 |
ens.sonic_ar = sonic_ar
|
| 75 |
|
| 76 |
-
# Sonic direct (24 outputs)
|
| 77 |
sonic_direct = UNet(in_channels=24, out_channels=24,
|
| 78 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 79 |
upsample_mode="bilinear").to(DEVICE)
|
| 80 |
-
sd =
|
| 81 |
-
|
| 82 |
-
sonic_direct.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 83 |
sonic_direct.eval()
|
| 84 |
ens.sonic_direct = sonic_direct
|
| 85 |
|
| 86 |
-
# PP compact direct (24 outputs)
|
| 87 |
pp = UNet(in_channels=24, out_channels=24,
|
| 88 |
enc_channels=(24, 48, 96), bottleneck_channels=128,
|
| 89 |
upsample_mode="bilinear").to(DEVICE)
|
|
@@ -126,7 +136,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 126 |
last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
|
| 127 |
|
| 128 |
if game == "pong":
|
| 129 |
-
# Pong: AR+direct ensemble
|
| 130 |
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 131 |
result = ens.direct_cache[ens.cache_step]
|
| 132 |
ens.cache_step += 1
|
|
@@ -135,21 +145,17 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 135 |
return result
|
| 136 |
|
| 137 |
ens.reset_cache()
|
| 138 |
-
model_ar = ens.models["pong"]
|
| 139 |
-
model_direct = ens.pong_direct
|
| 140 |
with torch.no_grad():
|
| 141 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 142 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
direct_pred = _predict_8frames_direct(model_direct, context_tensor, last_tensor)
|
| 146 |
|
| 147 |
-
# AR prediction in float32
|
| 148 |
ar_preds = []
|
| 149 |
ctx = context_tensor.clone()
|
| 150 |
last_t = last_tensor.clone()
|
| 151 |
for step in range(PRED_FRAMES):
|
| 152 |
-
predicted = _predict_ar_frame(
|
| 153 |
ar_preds.append(predicted)
|
| 154 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 155 |
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
|
@@ -158,7 +164,6 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 158 |
|
| 159 |
ar_pred = torch.stack(ar_preds, dim=1)
|
| 160 |
|
| 161 |
-
# Step-dependent blending: AR 0.7 -> 0.3
|
| 162 |
predicted = torch.zeros_like(direct_pred)
|
| 163 |
for step in range(PRED_FRAMES):
|
| 164 |
ar_weight = 0.7 - (step / (PRED_FRAMES - 1)) * 0.4
|
|
|
|
| 1 |
+
"""Int8 ensemble: Sonic models quantized to int8, Pong/PP in fp16."""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
|
|
| 26 |
return "sonic"
|
| 27 |
|
| 28 |
|
| 29 |
+
def load_int8_state_dict(path, device):
    """Load an int8-quantized checkpoint and return a dequantized state dict.

    The checkpoint at ``path`` is expected to be a mapping from parameter
    name to one of:

    * ``{'int8': tensor, 'scale': scale}`` -- dequantized as
      ``int8.float() * scale`` (the result is float32);
    * ``{'float': tensor}`` -- passed through unchanged, keeping whatever
      dtype it was stored with;
    * a bare tensor -- passed through unchanged.

    Args:
        path: Location of the checkpoint file, forwarded to ``torch.load``.
        device: Device every returned tensor is moved to.

    Returns:
        A dict mapping each parameter name to a tensor on ``device``.

    Raises:
        KeyError: If an entry carries neither ``'int8'`` nor ``'float'`` data
            (or an ``'int8'`` entry has no ``'scale'``).
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only use this on checkpoints from a
    # trusted source; switch to weights_only=True once the stored format is
    # confirmed to contain nothing but tensors and plain Python containers.
    quantized = torch.load(path, map_location='cpu', weights_only=False)

    sd = {}
    for k, v in quantized.items():
        if torch.is_tensor(v):
            # A tensor stored without the wrapper dict; `'int8' in v` would
            # raise on a tensor, so handle it before the membership tests.
            sd[k] = v.to(device)
        elif 'int8' in v:
            # Dequantize on the CPU first, then move the float32 result.
            sd[k] = (v['int8'].float() * v['scale']).to(device)
        elif 'float' in v:
            sd[k] = v['float'].to(device)
        else:
            raise KeyError(
                f"checkpoint entry {k!r} has neither 'int8' nor 'float' data"
            )
    return sd
|
| 39 |
+
|
| 40 |
+
|
| 41 |
class EnsembleModels:
|
| 42 |
def __init__(self):
|
| 43 |
self.models = {}
|
|
|
|
| 55 |
def load_model(model_dir: str):
|
| 56 |
ens = EnsembleModels()
|
| 57 |
|
| 58 |
+
# Pong AR (fp16, 3 outputs)
|
| 59 |
pong = UNet(in_channels=24, out_channels=3,
|
| 60 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 61 |
upsample_mode="bilinear").to(DEVICE)
|
|
|
|
| 65 |
pong.eval()
|
| 66 |
ens.models["pong"] = pong
|
| 67 |
|
| 68 |
+
# Pong direct (fp16, 24 outputs)
|
| 69 |
pong_direct = UNet(in_channels=24, out_channels=24,
|
| 70 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 71 |
upsample_mode="bilinear").to(DEVICE)
|
|
|
|
| 75 |
pong_direct.eval()
|
| 76 |
ens.pong_direct = pong_direct
|
| 77 |
|
| 78 |
+
# Sonic AR (int8 quantized, 3 outputs)
|
| 79 |
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 80 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 81 |
upsample_mode="bilinear").to(DEVICE)
|
| 82 |
+
sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_ar.pt"), DEVICE)
|
| 83 |
+
sonic_ar.load_state_dict(sd)
|
|
|
|
| 84 |
sonic_ar.eval()
|
| 85 |
ens.sonic_ar = sonic_ar
|
| 86 |
|
| 87 |
+
# Sonic direct (int8 quantized, 24 outputs)
|
| 88 |
sonic_direct = UNet(in_channels=24, out_channels=24,
|
| 89 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 90 |
upsample_mode="bilinear").to(DEVICE)
|
| 91 |
+
sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_direct.pt"), DEVICE)
|
| 92 |
+
sonic_direct.load_state_dict(sd)
|
|
|
|
| 93 |
sonic_direct.eval()
|
| 94 |
ens.sonic_direct = sonic_direct
|
| 95 |
|
| 96 |
+
# PP compact direct (fp16, 24 outputs)
|
| 97 |
pp = UNet(in_channels=24, out_channels=24,
|
| 98 |
enc_channels=(24, 48, 96), bottleneck_channels=128,
|
| 99 |
upsample_mode="bilinear").to(DEVICE)
|
|
|
|
| 136 |
last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
|
| 137 |
|
| 138 |
if game == "pong":
|
| 139 |
+
# Pong: AR+direct ensemble, float32 caching, no TTA
|
| 140 |
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 141 |
result = ens.direct_cache[ens.cache_step]
|
| 142 |
ens.cache_step += 1
|
|
|
|
| 145 |
return result
|
| 146 |
|
| 147 |
ens.reset_cache()
|
|
|
|
|
|
|
| 148 |
with torch.no_grad():
|
| 149 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 150 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 151 |
|
| 152 |
+
direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
|
|
|
|
| 153 |
|
|
|
|
| 154 |
ar_preds = []
|
| 155 |
ctx = context_tensor.clone()
|
| 156 |
last_t = last_tensor.clone()
|
| 157 |
for step in range(PRED_FRAMES):
|
| 158 |
+
predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t)
|
| 159 |
ar_preds.append(predicted)
|
| 160 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 161 |
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
|
|
|
| 164 |
|
| 165 |
ar_pred = torch.stack(ar_preds, dim=1)
|
| 166 |
|
|
|
|
| 167 |
predicted = torch.zeros_like(direct_pred)
|
| 168 |
for step in range(PRED_FRAMES):
|
| 169 |
ar_weight = 0.7 - (step / (PRED_FRAMES - 1)) * 0.4
|