ojaffe committed on
Commit
07239aa
·
verified ·
1 Parent(s): 65aa516

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +19 -5
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -155,8 +155,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
155
  ctx = context_tensor.clone()
156
  last_t = last_tensor.clone()
157
  for step in range(PRED_FRAMES):
158
- pong_scale = 1.06 if step >= 4 else 1.0
159
- predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=pong_scale)
160
  ar_preds.append(predicted)
161
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
162
  ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
@@ -214,9 +213,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
214
  for step in range(PRED_FRAMES):
215
  ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
216
  ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
217
- sonic_scale = 1.12 if step >= 3 else 1.0
218
- ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=sonic_scale)
219
- ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scale)
220
  ar_flip_back = torch.flip(ar_flip, dims=[3])
221
  ar_frame = (ar_orig + ar_flip_back) / 2.0
222
  ar_preds_run.append(ar_frame)
@@ -232,6 +230,22 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
232
 
233
  ar_pred = sum(all_ar_runs) / len(all_ar_runs)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  predicted = torch.zeros_like(direct_pred)
236
  for step in range(PRED_FRAMES):
237
  ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
 
155
  ctx = context_tensor.clone()
156
  last_t = last_tensor.clone()
157
  for step in range(PRED_FRAMES):
158
+ predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=1.03)
 
159
  ar_preds.append(predicted)
160
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
161
  ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
 
213
  for step in range(PRED_FRAMES):
214
  ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
215
  ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
216
+ ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=1.08)
217
+ ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=1.08)
 
218
  ar_flip_back = torch.flip(ar_flip, dims=[3])
219
  ar_frame = (ar_orig + ar_flip_back) / 2.0
220
  ar_preds_run.append(ar_frame)
 
230
 
231
  ar_pred = sum(all_ar_runs) / len(all_ar_runs)
232
 
233
+ # Apply mild Gaussian blur to AR predictions
234
+ import torch.nn.functional as F
235
+ # Create 3x3 Gaussian kernel with sigma=0.5
236
+ kernel_size = 3
237
+ sigma = 0.5
238
+ x = torch.arange(kernel_size, dtype=torch.float32, device=DEVICE) - kernel_size // 2
239
+ gauss = torch.exp(-x**2 / (2 * sigma**2))
240
+ kernel_2d = gauss[:, None] * gauss[None, :]
241
+ kernel_2d = kernel_2d / kernel_2d.sum()
242
+ kernel_2d = kernel_2d.view(1, 1, kernel_size, kernel_size).expand(3, 1, -1, -1)
243
+ pad = kernel_size // 2
244
+ for s in range(PRED_FRAMES):
245
+ frame = ar_pred[:, s] # [1, 3, 64, 64]
246
+ frame_padded = F.pad(frame, (pad, pad, pad, pad), mode='reflect')
247
+ ar_pred[:, s] = F.conv2d(frame_padded, kernel_2d, groups=3)
248
+
249
  predicted = torch.zeros_like(direct_pred)
250
  for step in range(PRED_FRAMES):
251
  ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3