Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +40 -44
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -190,14 +190,6 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 190 |
ens.reset_cache()
|
| 191 |
return result
|
| 192 |
|
| 193 |
-
# Detect extreme scene transitions (threshold 80 on 0-255 scale)
|
| 194 |
-
scene_transition = False
|
| 195 |
-
for i in range(len(frames) - 1):
|
| 196 |
-
diff = np.abs(frames[i].astype(np.float32) - frames[i + 1].astype(np.float32)).mean()
|
| 197 |
-
if diff > 80.0 / 255.0:
|
| 198 |
-
scene_transition = True
|
| 199 |
-
break
|
| 200 |
-
|
| 201 |
ens.reset_cache()
|
| 202 |
with torch.no_grad():
|
| 203 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
|
@@ -210,51 +202,55 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 210 |
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 211 |
direct_pred = (direct_orig + direct_flipped) / 2.0
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
ctx = context_tensor.clone()
|
| 222 |
-
ctx_flip = context_flipped.clone()
|
| 223 |
-
last_t = last_tensor.clone()
|
| 224 |
-
last_f = last_flipped.clone()
|
| 225 |
-
for step in range(PRED_FRAMES):
|
| 226 |
-
ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
|
| 227 |
-
ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
|
| 228 |
-
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
|
| 229 |
-
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
|
| 230 |
-
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 231 |
-
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 232 |
-
ar_preds_run.append(ar_frame)
|
| 233 |
-
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 234 |
-
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 235 |
-
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 236 |
-
last_t = ar_orig
|
| 237 |
-
ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 238 |
-
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 239 |
-
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 240 |
-
last_f = ar_flip
|
| 241 |
-
all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
|
| 242 |
-
|
| 243 |
-
ar_pred = sum(all_ar_runs) / len(all_ar_runs)
|
| 244 |
-
|
| 245 |
-
predicted = torch.zeros_like(direct_pred)
|
| 246 |
for step in range(PRED_FRAMES):
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
predicted_np = predicted[0].cpu().numpy()
|
| 252 |
ens.direct_cache = []
|
|
|
|
|
|
|
| 253 |
for i in range(PRED_FRAMES):
|
| 254 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 255 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
| 256 |
ens.direct_cache.append(frame)
|
| 257 |
|
|
|
|
|
|
|
|
|
|
| 258 |
result = ens.direct_cache[ens.cache_step]
|
| 259 |
ens.cache_step += 1
|
| 260 |
return result
|
|
|
|
| 190 |
ens.reset_cache()
|
| 191 |
return result
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
ens.reset_cache()
|
| 194 |
with torch.no_grad():
|
| 195 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
|
|
|
| 202 |
direct_flipped = torch.flip(direct_flipped, dims=[4])
|
| 203 |
direct_pred = (direct_orig + direct_flipped) / 2.0
|
| 204 |
|
| 205 |
+
# Multi-run AR with noise diversity
|
| 206 |
+
all_ar_runs = []
|
| 207 |
+
for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
|
| 208 |
+
ar_preds_run = []
|
| 209 |
+
ctx = context_tensor.clone()
|
| 210 |
+
ctx_flip = context_flipped.clone()
|
| 211 |
+
last_t = last_tensor.clone()
|
| 212 |
+
last_f = last_flipped.clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
for step in range(PRED_FRAMES):
|
| 214 |
+
ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
|
| 215 |
+
ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
|
| 216 |
+
ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
|
| 217 |
+
ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
|
| 218 |
+
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 219 |
+
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 220 |
+
ar_preds_run.append(ar_frame)
|
| 221 |
+
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 222 |
+
ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
|
| 223 |
+
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 224 |
+
last_t = ar_orig
|
| 225 |
+
ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 226 |
+
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 227 |
+
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 228 |
+
last_f = ar_flip
|
| 229 |
+
all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
|
| 230 |
+
|
| 231 |
+
ar_pred = sum(all_ar_runs) / len(all_ar_runs)
|
| 232 |
+
|
| 233 |
+
predicted = torch.zeros_like(direct_pred)
|
| 234 |
+
for step in range(PRED_FRAMES):
|
| 235 |
+
ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
|
| 236 |
+
direct_weight = 1.0 - ar_weight
|
| 237 |
+
predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
|
| 238 |
|
| 239 |
predicted_np = predicted[0].cpu().numpy()
|
| 240 |
ens.direct_cache = []
|
| 241 |
+
last_ctx_uint8 = (last_frame * 255).clip(0, 255).astype(np.uint8)
|
| 242 |
+
catastrophic = False
|
| 243 |
for i in range(PRED_FRAMES):
|
| 244 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 245 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
| 246 |
+
diff = np.abs(frame.astype(np.float32) - last_ctx_uint8.astype(np.float32)).mean()
|
| 247 |
+
if diff > 100:
|
| 248 |
+
catastrophic = True
|
| 249 |
ens.direct_cache.append(frame)
|
| 250 |
|
| 251 |
+
if catastrophic:
|
| 252 |
+
ens.direct_cache = [last_ctx_uint8.copy() for _ in range(PRED_FRAMES)]
|
| 253 |
+
|
| 254 |
result = ens.direct_cache[ens.cache_step]
|
| 255 |
ens.cache_step += 1
|
| 256 |
return result
|