ojaffe committed on
Commit
339daa1
·
verified ·
1 Parent(s): 3af9de9

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +18 -15
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -195,17 +195,24 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
195
  context_tensor = torch.from_numpy(context).to(DEVICE)
196
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
197
 
198
- direct_orig = _predict_8frames_direct(ens.sonic_direct, context_tensor, last_tensor)
199
  context_flipped = torch.flip(context_tensor, dims=[3])
200
  last_flipped = torch.flip(last_tensor, dims=[3])
201
- direct_flipped = _predict_8frames_direct(ens.sonic_direct, context_flipped, last_flipped)
202
- direct_flipped = torch.flip(direct_flipped, dims=[4])
203
- direct_pred = (direct_orig + direct_flipped) / 2.0
204
 
205
- # Multi-run AR with noise diversity - collect all 6 predictions per step
206
- # (3 noise levels x 2 TTA directions) and take per-pixel median
207
- all_step_preds = [[] for _ in range(PRED_FRAMES)]
 
 
 
 
 
 
 
 
 
 
208
  for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
 
209
  ctx = context_tensor.clone()
210
  ctx_flip = context_flipped.clone()
211
  last_t = last_tensor.clone()
@@ -216,8 +223,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
216
  ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
217
  ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
218
  ar_flip_back = torch.flip(ar_flip, dims=[3])
219
- all_step_preds[step].append(ar_orig)
220
- all_step_preds[step].append(ar_flip_back)
221
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
222
  ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
223
  ctx = ctx_frames.reshape(1, -1, 64, 64)
@@ -226,13 +233,9 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
226
  ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
227
  ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
228
  last_f = ar_flip
 
229
 
230
- ar_pred_list = []
231
- for step in range(PRED_FRAMES):
232
- stacked = torch.stack(all_step_preds[step], dim=0) # [6, 1, 3, 64, 64]
233
- median_val = torch.median(stacked, dim=0).values
234
- ar_pred_list.append(median_val)
235
- ar_pred = torch.stack(ar_pred_list, dim=1) # [1, 8, 3, 64, 64]
236
 
237
  predicted = torch.zeros_like(direct_pred)
238
  for step in range(PRED_FRAMES):
 
195
  context_tensor = torch.from_numpy(context).to(DEVICE)
196
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
197
 
 
198
  context_flipped = torch.flip(context_tensor, dims=[3])
199
  last_flipped = torch.flip(last_tensor, dims=[3])
 
 
 
200
 
201
+ # Multi-run direct with noise diversity
202
+ all_direct_runs = []
203
+ for noise_std in [0.0, 0.5/255.0, 1.0/255.0]:
204
+ ctx_in = context_tensor if noise_std == 0 else torch.clamp(context_tensor + torch.randn_like(context_tensor) * noise_std, 0, 1)
205
+ ctx_flip_in = context_flipped if noise_std == 0 else torch.clamp(context_flipped + torch.randn_like(context_flipped) * noise_std, 0, 1)
206
+ direct_orig = _predict_8frames_direct(ens.sonic_direct, ctx_in, last_tensor)
207
+ direct_flipped = _predict_8frames_direct(ens.sonic_direct, ctx_flip_in, last_flipped)
208
+ direct_flipped = torch.flip(direct_flipped, dims=[4])
209
+ all_direct_runs.append((direct_orig + direct_flipped) / 2.0)
210
+ direct_pred = sum(all_direct_runs) / len(all_direct_runs)
211
+
212
+ # Multi-run AR with noise diversity
213
+ all_ar_runs = []
214
  for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
215
+ ar_preds_run = []
216
  ctx = context_tensor.clone()
217
  ctx_flip = context_flipped.clone()
218
  last_t = last_tensor.clone()
 
223
  ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
224
  ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
225
  ar_flip_back = torch.flip(ar_flip, dims=[3])
226
+ ar_frame = (ar_orig + ar_flip_back) / 2.0
227
+ ar_preds_run.append(ar_frame)
228
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
229
  ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
230
  ctx = ctx_frames.reshape(1, -1, 64, 64)
 
233
  ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
234
  ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
235
  last_f = ar_flip
236
+ all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
237
 
238
+ ar_pred = sum(all_ar_runs) / len(all_ar_runs)
 
 
 
 
 
239
 
240
  predicted = torch.zeros_like(direct_pred)
241
  for step in range(PRED_FRAMES):