ojaffe committed on
Commit
d4efc46
·
verified ·
1 Parent(s): 5ab9cda

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +40 -44
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -190,14 +190,6 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
190
  ens.reset_cache()
191
  return result
192
 
193
- # Detect extreme scene transitions (threshold 80 on 0-255 scale)
194
- scene_transition = False
195
- for i in range(len(frames) - 1):
196
- diff = np.abs(frames[i].astype(np.float32) - frames[i + 1].astype(np.float32)).mean()
197
- if diff > 80.0 / 255.0:
198
- scene_transition = True
199
- break
200
-
201
  ens.reset_cache()
202
  with torch.no_grad():
203
  context_tensor = torch.from_numpy(context).to(DEVICE)
@@ -210,51 +202,55 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
210
  direct_flipped = torch.flip(direct_flipped, dims=[4])
211
  direct_pred = (direct_orig + direct_flipped) / 2.0
212
 
213
- if scene_transition:
214
- # Extreme scene transition: direct-only
215
- predicted = direct_pred
216
- else:
217
- # Normal: full AR+direct blend with noise diversity
218
- all_ar_runs = []
219
- for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
220
- ar_preds_run = []
221
- ctx = context_tensor.clone()
222
- ctx_flip = context_flipped.clone()
223
- last_t = last_tensor.clone()
224
- last_f = last_flipped.clone()
225
- for step in range(PRED_FRAMES):
226
- ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
227
- ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
228
- ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
229
- ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
230
- ar_flip_back = torch.flip(ar_flip, dims=[3])
231
- ar_frame = (ar_orig + ar_flip_back) / 2.0
232
- ar_preds_run.append(ar_frame)
233
- ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
234
- ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
235
- ctx = ctx_frames.reshape(1, -1, 64, 64)
236
- last_t = ar_orig
237
- ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
238
- ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
239
- ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
240
- last_f = ar_flip
241
- all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
242
-
243
- ar_pred = sum(all_ar_runs) / len(all_ar_runs)
244
-
245
- predicted = torch.zeros_like(direct_pred)
246
  for step in range(PRED_FRAMES):
247
- ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
248
- direct_weight = 1.0 - ar_weight
249
- predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  predicted_np = predicted[0].cpu().numpy()
252
  ens.direct_cache = []
 
 
253
  for i in range(PRED_FRAMES):
254
  frame = np.transpose(predicted_np[i], (1, 2, 0))
255
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
 
 
 
256
  ens.direct_cache.append(frame)
257
 
 
 
 
258
  result = ens.direct_cache[ens.cache_step]
259
  ens.cache_step += 1
260
  return result
 
190
  ens.reset_cache()
191
  return result
192
 
 
 
 
 
 
 
 
 
193
  ens.reset_cache()
194
  with torch.no_grad():
195
  context_tensor = torch.from_numpy(context).to(DEVICE)
 
202
  direct_flipped = torch.flip(direct_flipped, dims=[4])
203
  direct_pred = (direct_orig + direct_flipped) / 2.0
204
 
205
+ # Multi-run AR with noise diversity
206
+ all_ar_runs = []
207
+ for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
208
+ ar_preds_run = []
209
+ ctx = context_tensor.clone()
210
+ ctx_flip = context_flipped.clone()
211
+ last_t = last_tensor.clone()
212
+ last_f = last_flipped.clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  for step in range(PRED_FRAMES):
214
+ ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
215
+ ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
216
+ ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
217
+ ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
218
+ ar_flip_back = torch.flip(ar_flip, dims=[3])
219
+ ar_frame = (ar_orig + ar_flip_back) / 2.0
220
+ ar_preds_run.append(ar_frame)
221
+ ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
222
+ ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
223
+ ctx = ctx_frames.reshape(1, -1, 64, 64)
224
+ last_t = ar_orig
225
+ ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
226
+ ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
227
+ ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
228
+ last_f = ar_flip
229
+ all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
230
+
231
+ ar_pred = sum(all_ar_runs) / len(all_ar_runs)
232
+
233
+ predicted = torch.zeros_like(direct_pred)
234
+ for step in range(PRED_FRAMES):
235
+ ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
236
+ direct_weight = 1.0 - ar_weight
237
+ predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
238
 
239
  predicted_np = predicted[0].cpu().numpy()
240
  ens.direct_cache = []
241
+ last_ctx_uint8 = (last_frame * 255).clip(0, 255).astype(np.uint8)
242
+ catastrophic = False
243
  for i in range(PRED_FRAMES):
244
  frame = np.transpose(predicted_np[i], (1, 2, 0))
245
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
246
+ diff = np.abs(frame.astype(np.float32) - last_ctx_uint8.astype(np.float32)).mean()
247
+ if diff > 100:
248
+ catastrophic = True
249
  ens.direct_cache.append(frame)
250
 
251
+ if catastrophic:
252
+ ens.direct_cache = [last_ctx_uint8.copy() for _ in range(PRED_FRAMES)]
253
+
254
  result = ens.direct_cache[ens.cache_step]
255
  ens.cache_step += 1
256
  return result