Upload folder using huggingface_hub
Browse files
- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +19 -5
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -155,8 +155,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 155 |
ctx = context_tensor.clone()
|
| 156 |
last_t = last_tensor.clone()
|
| 157 |
for step in range(PRED_FRAMES):
|
| 158 |
-
|
| 159 |
-
predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=pong_scale)
|
| 160 |
ar_preds.append(predicted)
|
| 161 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 162 |
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
|
@@ -214,9 +213,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 214 |
for step in range(PRED_FRAMES):
|
| 215 |
ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
|
| 216 |
ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scale)
|
| 220 |
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 221 |
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 222 |
ar_preds_run.append(ar_frame)
|
|
@@ -232,6 +230,22 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 232 |
|
| 233 |
ar_pred = sum(all_ar_runs) / len(all_ar_runs)
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
predicted = torch.zeros_like(direct_pred)
|
| 236 |
for step in range(PRED_FRAMES):
|
| 237 |
ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
|
|
|
|
| 155 |
ctx = context_tensor.clone()
|
| 156 |
last_t = last_tensor.clone()
|
| 157 |
for step in range(PRED_FRAMES):
|
| 158 |
+
predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=1.03)
|
|
|
|
| 159 |
ar_preds.append(predicted)
|
| 160 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 161 |
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
|
|
|
| 213 |
for step in range(PRED_FRAMES):
|
| 214 |
ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
|
| 215 |
ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
|
| 216 |
+
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=1.08)
|
| 217 |
+
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=1.08)
|
|
|
|
| 218 |
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 219 |
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 220 |
ar_preds_run.append(ar_frame)
|
|
|
|
| 230 |
|
| 231 |
ar_pred = sum(all_ar_runs) / len(all_ar_runs)
|
| 232 |
|
| 233 |
+
# Apply mild Gaussian blur to AR predictions
|
| 234 |
+
import torch.nn.functional as F
|
| 235 |
+
# Create 3x3 Gaussian kernel with sigma=0.5
|
| 236 |
+
kernel_size = 3
|
| 237 |
+
sigma = 0.5
|
| 238 |
+
x = torch.arange(kernel_size, dtype=torch.float32, device=DEVICE) - kernel_size // 2
|
| 239 |
+
gauss = torch.exp(-x**2 / (2 * sigma**2))
|
| 240 |
+
kernel_2d = gauss[:, None] * gauss[None, :]
|
| 241 |
+
kernel_2d = kernel_2d / kernel_2d.sum()
|
| 242 |
+
kernel_2d = kernel_2d.view(1, 1, kernel_size, kernel_size).expand(3, 1, -1, -1)
|
| 243 |
+
pad = kernel_size // 2
|
| 244 |
+
for s in range(PRED_FRAMES):
|
| 245 |
+
frame = ar_pred[:, s] # [1, 3, 64, 64]
|
| 246 |
+
frame_padded = F.pad(frame, (pad, pad, pad, pad), mode='reflect')
|
| 247 |
+
ar_pred[:, s] = F.conv2d(frame_padded, kernel_2d, groups=3)
|
| 248 |
+
|
| 249 |
predicted = torch.zeros_like(direct_pred)
|
| 250 |
for step in range(PRED_FRAMES):
|
| 251 |
ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
|