ojaffe commited on
Commit
f6e18f6
·
verified ·
1 Parent(s): f3b7dc1

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_sonic_ar.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0853b8b0dad0a55f126be9bfd767d2e55fcc2ea9dcb379a79f6389c997e54816
3
- size 3129452
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
3
+ size 6180566
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """Int8 ensemble: Sonic models quantized to int8, Pong/PP in fp16."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -75,12 +75,13 @@ def load_model(model_dir: str):
75
  pong_direct.eval()
76
  ens.pong_direct = pong_direct
77
 
78
- # Sonic AR (int8 quantized, 3 outputs)
79
  sonic_ar = UNet(in_channels=24, out_channels=3,
80
  enc_channels=(48, 96, 192), bottleneck_channels=256,
81
  upsample_mode="bilinear").to(DEVICE)
82
- sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_ar.pt"), DEVICE)
83
- sonic_ar.load_state_dict(sd)
 
84
  sonic_ar.eval()
85
  ens.sonic_ar = sonic_ar
86
 
@@ -182,7 +183,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
182
  return result
183
 
184
  elif game == "sonic":
185
- # Sonic: AR+direct with step blending and TTA
186
  if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
187
  result = ens.direct_cache[ens.cache_step]
188
  ens.cache_step += 1
 
1
+ """Selective int8: only Sonic direct quantized to int8, Sonic AR kept in fp16."""
2
  import sys
3
  import os
4
  import numpy as np
 
75
  pong_direct.eval()
76
  ens.pong_direct = pong_direct
77
 
78
+ # Sonic AR (fp16, 3 outputs) - kept in fp16 for AR chain quality
79
  sonic_ar = UNet(in_channels=24, out_channels=3,
80
  enc_channels=(48, 96, 192), bottleneck_channels=256,
81
  upsample_mode="bilinear").to(DEVICE)
82
+ sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
83
+ map_location=DEVICE, weights_only=True)
84
+ sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
85
  sonic_ar.eval()
86
  ens.sonic_ar = sonic_ar
87
 
 
183
  return result
184
 
185
  elif game == "sonic":
186
+ # Sonic: AR(fp16)+direct(int8) with step blending and TTA
187
  if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
188
  result = ens.direct_cache[ens.cache_step]
189
  ens.cache_step += 1