Upload folder using huggingface_hub
Browse files
- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +18 -15
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -195,17 +195,24 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 195 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 196 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 197 |
|
| 198 |
-
direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
|
| 199 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 200 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 201 |
-
direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
|
| 202 |
-
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 203 |
-
direct_pred = (direct_orig + direct_flipped) / 2.0
|
| 204 |
|
| 205 |
-
# Multi-run
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
|
|
|
|
| 209 |
ctx = context_tensor.clone()
|
| 210 |
ctx_flip = context_flipped.clone()
|
| 211 |
last_t = last_tensor.clone()
|
|
@@ -216,8 +223,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 216 |
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
|
| 217 |
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
|
| 218 |
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 219 |
-
|
| 220 |
-
|
| 221 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 222 |
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 223 |
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
|
@@ -226,13 +233,9 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 226 |
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 227 |
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 228 |
last_f = ar_flip
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
for step in range(PRED_FRAMES):
|
| 232 |
-
stacked = torch.stack(all_step_preds[step], dim=0) # [6, 1, 3, 64, 64]
|
| 233 |
-
median_val = torch.median(stacked, dim=0).values
|
| 234 |
-
ar_pred_list.append(median_val)
|
| 235 |
-
ar_pred = torch.stack(ar_pred_list, dim=1) # [1, 8, 3, 64, 64]
|
| 236 |
|
| 237 |
predicted = torch.zeros_like(direct_pred)
|
| 238 |
for step in range(PRED_FRAMES):
|
|
|
|
| 195 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 196 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 197 |
|
|
|
|
| 198 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 199 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
+
# Multi-run direct with noise diversity
|
| 202 |
+
all_direct_runs = []
|
| 203 |
+
for noise_std in [0.0, 0.5/255.0, 1.0/255.0]:
|
| 204 |
+
ctx_in = context_tensor if noise_std == 0 else torch.clamp(context_tensor + torch.randn_like(context_tensor) * noise_std, 0, 1)
|
| 205 |
+
ctx_flip_in = context_flipped if noise_std == 0 else torch.clamp(context_flipped + torch.randn_like(context_flipped) * noise_std, 0, 1)
|
| 206 |
+
direct_orig = _predict_8frames_direct(ens.sonic_direct, ctx_in, last_tensor)
|
| 207 |
+
direct_flipped = _predict_8frames_direct(ens.sonic_direct, ctx_flip_in, last_flipped)
|
| 208 |
+
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 209 |
+
all_direct_runs.append((direct_orig + direct_flipped) / 2.0)
|
| 210 |
+
direct_pred = sum(all_direct_runs) / len(all_direct_runs)
|
| 211 |
+
|
| 212 |
+
# Multi-run AR with noise diversity
|
| 213 |
+
all_ar_runs = []
|
| 214 |
for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
|
| 215 |
+
ar_preds_run = []
|
| 216 |
ctx = context_tensor.clone()
|
| 217 |
ctx_flip = context_flipped.clone()
|
| 218 |
last_t = last_tensor.clone()
|
|
|
|
| 223 |
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
|
| 224 |
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
|
| 225 |
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 226 |
+
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 227 |
+
ar_preds_run.append(ar_frame)
|
| 228 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 229 |
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 230 |
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
|
|
|
| 233 |
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 234 |
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 235 |
last_f = ar_flip
|
| 236 |
+
all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
|
| 237 |
|
| 238 |
+
ar_pred = sum(all_ar_runs) / len(all_ar_runs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
predicted = torch.zeros_like(direct_pred)
|
| 241 |
for step in range(PRED_FRAMES):
|