ojaffe commited on
Commit
0155fbe
·
verified ·
1 Parent(s): ed16acf

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_pole_position.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:62d218b9859acd4d19cfcfe6b3aa93ae129485a872175632ed32d6441ae9c7f6
3
- size 1580934
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e0affcef8e533a29037751e27948a3eb0f2fda2792ce2b3dfc876cadb09e281
3
+ size 2971526
model_pong_direct.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ada51c7d09e003a2bea134bfe7be0e762756f10cebaab73c904135c5a4e33cf
3
- size 2437546
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:381e3abbf7c5c308e985f53a916379140be588383760fa117009775b6bc79281
3
+ size 1262522
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """Optimized blend: Pong AR weight 0.85->0.65, Sonic unchanged 0.7->0.3."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -65,13 +65,12 @@ def load_model(model_dir: str):
65
  pong.eval()
66
  ens.models["pong"] = pong
67
 
68
- # Pong direct (fp16, 24 outputs)
69
  pong_direct = UNet(in_channels=24, out_channels=24,
70
  enc_channels=(32, 64, 128), bottleneck_channels=128,
71
  upsample_mode="bilinear").to(DEVICE)
72
- sd = torch.load(os.path.join(model_dir, "model_pong_direct.pt"),
73
- map_location=DEVICE, weights_only=True)
74
- pong_direct.load_state_dict({k: v.float() for k, v in sd.items()})
75
  pong_direct.eval()
76
  ens.pong_direct = pong_direct
77
 
@@ -94,9 +93,9 @@ def load_model(model_dir: str):
94
  sonic_direct.eval()
95
  ens.sonic_direct = sonic_direct
96
 
97
- # PP compact direct (fp16, 24 outputs)
98
  pp = UNet(in_channels=24, out_channels=24,
99
- enc_channels=(24, 48, 96), bottleneck_channels=128,
100
  upsample_mode="bilinear").to(DEVICE)
101
  sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
102
  map_location=DEVICE, weights_only=True)
 
1
+ """Full PP swap: Pong direct int8, full PP model, Sonic AR fp16 + direct int8."""
2
  import sys
3
  import os
4
  import numpy as np
 
65
  pong.eval()
66
  ens.models["pong"] = pong
67
 
68
+ # Pong direct (int8 quantized, 24 outputs)
69
  pong_direct = UNet(in_channels=24, out_channels=24,
70
  enc_channels=(32, 64, 128), bottleneck_channels=128,
71
  upsample_mode="bilinear").to(DEVICE)
72
+ sd = load_int8_state_dict(os.path.join(model_dir, "model_pong_direct.pt"), DEVICE)
73
+ pong_direct.load_state_dict(sd)
 
74
  pong_direct.eval()
75
  ens.pong_direct = pong_direct
76
 
 
93
  sonic_direct.eval()
94
  ens.sonic_direct = sonic_direct
95
 
96
+ # PP full direct (fp16, 24 outputs)
97
  pp = UNet(in_channels=24, out_channels=24,
98
+ enc_channels=(32, 64, 128), bottleneck_channels=192,
99
  upsample_mode="bilinear").to(DEVICE)
100
  sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
101
  map_location=DEVICE, weights_only=True)