ojaffe committed on
Commit
f3b7dc1
·
verified ·
1 Parent(s): 7f065fa

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_sonic_ar.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:413d9fcfa15f30c74cdfda5f7d7c9dba8958fe027dfc09de563e6209c78378f5
3
- size 6180566
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0853b8b0dad0a55f126be9bfd767d2e55fcc2ea9dcb379a79f6389c997e54816
3
+ size 3129452
model_sonic_direct.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7e17327a6f03cb72a35bd3c48d481b4eebea5db6572ed2b3fa290b330bca304
3
- size 6182614
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c51f9fb740cc1cb8dc93f119252905a47034bb4cab73b30f347e47de20ad3d6d
3
+ size 3131348
predict.py CHANGED
@@ -1,4 +1,4 @@
1
- """FP16 Pong ensemble: AR+direct for Pong, AR+direct for Sonic, direct for PP."""
2
  import sys
3
  import os
4
  import numpy as np
@@ -26,6 +26,18 @@ def detect_game(context_frames: np.ndarray) -> str:
26
  return "sonic"
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class EnsembleModels:
30
  def __init__(self):
31
  self.models = {}
@@ -43,7 +55,7 @@ class EnsembleModels:
43
  def load_model(model_dir: str):
44
  ens = EnsembleModels()
45
 
46
- # Pong AR (3 outputs)
47
  pong = UNet(in_channels=24, out_channels=3,
48
  enc_channels=(32, 64, 128), bottleneck_channels=128,
49
  upsample_mode="bilinear").to(DEVICE)
@@ -53,7 +65,7 @@ def load_model(model_dir: str):
53
  pong.eval()
54
  ens.models["pong"] = pong
55
 
56
- # Pong direct (24 outputs)
57
  pong_direct = UNet(in_channels=24, out_channels=24,
58
  enc_channels=(32, 64, 128), bottleneck_channels=128,
59
  upsample_mode="bilinear").to(DEVICE)
@@ -63,27 +75,25 @@ def load_model(model_dir: str):
63
  pong_direct.eval()
64
  ens.pong_direct = pong_direct
65
 
66
- # Sonic AR (3 outputs)
67
  sonic_ar = UNet(in_channels=24, out_channels=3,
68
  enc_channels=(48, 96, 192), bottleneck_channels=256,
69
  upsample_mode="bilinear").to(DEVICE)
70
- sd = torch.load(os.path.join(model_dir, "model_sonic_ar.pt"),
71
- map_location=DEVICE, weights_only=True)
72
- sonic_ar.load_state_dict({k: v.float() for k, v in sd.items()})
73
  sonic_ar.eval()
74
  ens.sonic_ar = sonic_ar
75
 
76
- # Sonic direct (24 outputs)
77
  sonic_direct = UNet(in_channels=24, out_channels=24,
78
  enc_channels=(48, 96, 192), bottleneck_channels=256,
79
  upsample_mode="bilinear").to(DEVICE)
80
- sd = torch.load(os.path.join(model_dir, "model_sonic_direct.pt"),
81
- map_location=DEVICE, weights_only=True)
82
- sonic_direct.load_state_dict({k: v.float() for k, v in sd.items()})
83
  sonic_direct.eval()
84
  ens.sonic_direct = sonic_direct
85
 
86
- # PP compact direct (24 outputs)
87
  pp = UNet(in_channels=24, out_channels=24,
88
  enc_channels=(24, 48, 96), bottleneck_channels=128,
89
  upsample_mode="bilinear").to(DEVICE)
@@ -126,7 +136,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
126
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
127
 
128
  if game == "pong":
129
- # Pong: AR+direct ensemble with float32 caching, no TTA
130
  if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
131
  result = ens.direct_cache[ens.cache_step]
132
  ens.cache_step += 1
@@ -135,21 +145,17 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
135
  return result
136
 
137
  ens.reset_cache()
138
- model_ar = ens.models["pong"]
139
- model_direct = ens.pong_direct
140
  with torch.no_grad():
141
  context_tensor = torch.from_numpy(context).to(DEVICE)
142
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
143
 
144
- # Direct prediction
145
- direct_pred = _predict_8frames_direct(model_direct, context_tensor, last_tensor)
146
 
147
- # AR prediction in float32
148
  ar_preds = []
149
  ctx = context_tensor.clone()
150
  last_t = last_tensor.clone()
151
  for step in range(PRED_FRAMES):
152
- predicted = _predict_ar_frame(model_ar, ctx, last_t)
153
  ar_preds.append(predicted)
154
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
155
  ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
@@ -158,7 +164,6 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
158
 
159
  ar_pred = torch.stack(ar_preds, dim=1)
160
 
161
- # Step-dependent blending: AR 0.7 -> 0.3
162
  predicted = torch.zeros_like(direct_pred)
163
  for step in range(PRED_FRAMES):
164
  ar_weight = 0.7 - (step / (PRED_FRAMES - 1)) * 0.4
 
1
+ """Int8 ensemble: Sonic models quantized to int8, Pong/PP in fp16."""
2
  import sys
3
  import os
4
  import numpy as np
 
26
  return "sonic"
27
 
28
 
29
def load_int8_state_dict(path, device):
    """Load an int8-quantized checkpoint and return a plain tensor state dict.

    The checkpoint at ``path`` is expected to map each parameter name to a
    small dict in one of two forms:

    * ``{'int8': tensor, 'scale': scale}`` -- dequantized as
      ``int8.float() * scale``.
    * ``{'float': tensor}`` -- passed through unchanged (no dtype cast).

    Args:
        path: Checkpoint file read with ``torch.load``.
        device: Device the returned tensors are moved to.

    Returns:
        A dict mapping parameter names to tensors on ``device``.

    Raises:
        KeyError: If an entry has neither an ``'int8'`` nor a ``'float'``
            key, or an ``'int8'`` entry is missing its ``'scale'``.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so this
    # must only be pointed at trusted checkpoint files.
    # The file is read onto the CPU first; each tensor is moved to ``device``
    # only after it has been dequantized.
    quantized = torch.load(path, map_location='cpu', weights_only=False)

    sd = {}
    for name, entry in quantized.items():
        if 'int8' in entry:
            tensor = entry['int8'].float() * entry['scale']
        elif 'float' in entry:
            tensor = entry['float']
        else:
            # Same exception type as a bare lookup would raise, but naming the
            # offending parameter so a bad checkpoint is easy to diagnose.
            raise KeyError(
                f"state dict entry {name!r} has neither 'int8' nor 'float' data"
            )
        sd[name] = tensor.to(device)
    return sd
39
+
40
+
41
  class EnsembleModels:
42
  def __init__(self):
43
  self.models = {}
 
55
  def load_model(model_dir: str):
56
  ens = EnsembleModels()
57
 
58
+ # Pong AR (fp16, 3 outputs)
59
  pong = UNet(in_channels=24, out_channels=3,
60
  enc_channels=(32, 64, 128), bottleneck_channels=128,
61
  upsample_mode="bilinear").to(DEVICE)
 
65
  pong.eval()
66
  ens.models["pong"] = pong
67
 
68
+ # Pong direct (fp16, 24 outputs)
69
  pong_direct = UNet(in_channels=24, out_channels=24,
70
  enc_channels=(32, 64, 128), bottleneck_channels=128,
71
  upsample_mode="bilinear").to(DEVICE)
 
75
  pong_direct.eval()
76
  ens.pong_direct = pong_direct
77
 
78
+ # Sonic AR (int8 quantized, 3 outputs)
79
  sonic_ar = UNet(in_channels=24, out_channels=3,
80
  enc_channels=(48, 96, 192), bottleneck_channels=256,
81
  upsample_mode="bilinear").to(DEVICE)
82
+ sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_ar.pt"), DEVICE)
83
+ sonic_ar.load_state_dict(sd)
 
84
  sonic_ar.eval()
85
  ens.sonic_ar = sonic_ar
86
 
87
+ # Sonic direct (int8 quantized, 24 outputs)
88
  sonic_direct = UNet(in_channels=24, out_channels=24,
89
  enc_channels=(48, 96, 192), bottleneck_channels=256,
90
  upsample_mode="bilinear").to(DEVICE)
91
+ sd = load_int8_state_dict(os.path.join(model_dir, "model_sonic_direct.pt"), DEVICE)
92
+ sonic_direct.load_state_dict(sd)
 
93
  sonic_direct.eval()
94
  ens.sonic_direct = sonic_direct
95
 
96
+ # PP compact direct (fp16, 24 outputs)
97
  pp = UNet(in_channels=24, out_channels=24,
98
  enc_channels=(24, 48, 96), bottleneck_channels=128,
99
  upsample_mode="bilinear").to(DEVICE)
 
136
  last_frame_t = np.transpose(last_frame, (2, 0, 1))[np.newaxis]
137
 
138
  if game == "pong":
139
+ # Pong: AR+direct ensemble, float32 caching, no TTA
140
  if ens.direct_cache is not None and n > CONTEXT_FRAMES and ens.cache_step < PRED_FRAMES:
141
  result = ens.direct_cache[ens.cache_step]
142
  ens.cache_step += 1
 
145
  return result
146
 
147
  ens.reset_cache()
 
 
148
  with torch.no_grad():
149
  context_tensor = torch.from_numpy(context).to(DEVICE)
150
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
151
 
152
+ direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
 
153
 
 
154
  ar_preds = []
155
  ctx = context_tensor.clone()
156
  last_t = last_tensor.clone()
157
  for step in range(PRED_FRAMES):
158
+ predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t)
159
  ar_preds.append(predicted)
160
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
161
  ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
 
164
 
165
  ar_pred = torch.stack(ar_preds, dim=1)
166
 
 
167
  predicted = torch.zeros_like(direct_pred)
168
  for step in range(PRED_FRAMES):
169
  ar_weight = 0.7 - (step / (PRED_FRAMES - 1)) * 0.4