Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- model_sonic_ar.pt +2 -2
- predict.py +6 -5
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
model_sonic_ar.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
|
| 3 |
+
size 6180566
|
predict.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -75,12 +75,13 @@ def load_model(model_dir: str):
|
|
| 75 |
pong_direct.eval()
|
| 76 |
ens.pong_direct = pong_direct
|
| 77 |
|
| 78 |
-
# Sonic AR (
|
| 79 |
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 80 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 81 |
upsample_mode="bilinear").to(DEVICE)
|
| 82 |
-
sd =
|
| 83 |
-
|
|
|
|
| 84 |
sonic_ar.eval()
|
| 85 |
ens.sonic_ar = sonic_ar
|
| 86 |
|
|
@@ -182,7 +183,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 182 |
return result
|
| 183 |
|
| 184 |
elif game == "sonic":
|
| 185 |
-
# Sonic: AR+direct with step blending and TTA
|
| 186 |
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 187 |
result = ens.direct_cache[ens.cache_step]
|
| 188 |
ens.cache_step += 1
|
|
|
|
| 1 |
+
"""Selective int8: only Sonic direct quantized to int8, Sonic AR kept in fp16."""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
|
|
| 75 |
pong_direct.eval()
|
| 76 |
ens.pong_direct = pong_direct
|
| 77 |
|
| 78 |
+
# Sonic AR (fp16, 3 outputs) - kept in fp16 for AR chain quality
|
| 79 |
sonic_ar = UNet(in_channels=24, out_channels=3,
|
| 80 |
enc_channels=(48, 96, 192), bottleneck_channels=256,
|
| 81 |
upsample_mode="bilinear").to(DEVICE)
|
| 82 |
+
sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
|
| 83 |
+
map_location=DEVICE, weights_only=True)
|
| 84 |
+
sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 85 |
sonic_ar.eval()
|
| 86 |
ens.sonic_ar = sonic_ar
|
| 87 |
|
|
|
|
| 183 |
return result
|
| 184 |
|
| 185 |
elif game == "sonic":
|
| 186 |
+
# Sonic: AR(fp16)+direct(int8) with step blending and TTA
|
| 187 |
if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
|
| 188 |
result = ens.direct_cache[ens.cache_step]
|
| 189 |
ens.cache_step += 1
|