ojaffe commited on
Commit
9633935
·
verified ·
1 Parent(s): 571442e

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +4 -4
  3. sweep.py +30 -30
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -106,11 +106,11 @@ def load_model(model_dir: str):
106
  return ens
107
 
108
 
109
- def _predict_8frames_direct(model, context_tensor, last_tensor):
110
  output = model(context_tensor)
111
  residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
112
  last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
113
- return torch.clamp(last_expanded + residuals, 0, 1)
114
 
115
 
116
  def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
@@ -263,10 +263,10 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
263
  context_tensor = torch.from_numpy(context).to(DEVICE)
264
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
265
 
266
- predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
267
  context_flipped = torch.flip(context_tensor, dims=[3])
268
  last_flipped = torch.flip(last_tensor, dims=[3])
269
- predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
270
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
271
  predicted = (predicted_orig + predicted_flipped) / 2.0
272
 
 
106
  return ens
107
 
108
 
109
+ def _predict_8frames_direct(model, context_tensor, last_tensor, residual_scale=1.0):
110
  output = model(context_tensor)
111
  residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
112
  last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
113
+ return torch.clamp(last_expanded + residual_scale * residuals, 0, 1)
114
 
115
 
116
  def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
 
263
  context_tensor = torch.from_numpy(context).to(DEVICE)
264
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
265
 
266
+ predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale=0.97)
267
  context_flipped = torch.flip(context_tensor, dims=[3])
268
  last_flipped = torch.flip(last_tensor, dims=[3])
269
+ predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale=0.97)
270
  predicted_flipped = torch.flip(predicted_flipped, dims=[4])
271
  predicted = (predicted_orig + predicted_flipped) / 2.0
272
 
sweep.py CHANGED
@@ -1,50 +1,50 @@
1
- """Sweep rounding bias."""
2
  import subprocess
3
  import json
4
  import re
5
 
6
- predict_path = "/home/coder/experiments/2026-04-12-332000-bias-resweep/predict.py"
 
 
 
7
 
8
  results = {}
9
- for bias in [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.40, 0.50]:
10
- with open(predict_path, 'r') as f:
11
- content = f.read()
12
-
13
- if bias == 0.0:
14
- content = re.sub(
15
- r'np\.round\(frame \* 255 \+ [\d.]+\)',
16
- 'np.round(frame * 255)',
17
- content
18
- )
19
- else:
20
- # First handle the case where there's already a bias
21
- content = re.sub(
22
- r'np\.round\(frame \* 255 \+ [\d.]+\)',
23
- f'np.round(frame * 255 + {bias})',
24
- content
25
- )
26
- # Then handle the case where there's no bias (from bias=0.0 step)
27
- content = re.sub(
28
- r'np\.round\(frame \* 255\)\.clip',
29
- f'np.round(frame * 255 + {bias}).clip',
30
- content
31
- )
32
 
33
  with open(predict_path, 'w') as f:
34
  f.write(content)
35
 
36
  result = subprocess.run(
37
- ['python', 'task/score.py', '--model_path', '/home/coder/experiments/2026-04-12-332000-bias-resweep'],
38
  capture_output=True, text=True, cwd='/home/coder'
39
  )
40
 
41
  for line in result.stdout.strip().split('\n'):
42
  if '"score"' in line:
43
  data = json.loads(line)
44
- results[bias] = data['score']
45
- print(f"Bias {bias:.2f}: overall={data['score']:.4f} pong={data['per_game']['pong']['ssim']:.4f} sonic={data['per_game']['sonic']['ssim']:.4f} pp={data['per_game']['pole_position']['ssim']:.4f}")
 
 
 
46
  break
47
 
 
 
 
 
48
  print("\n=== Summary ===")
49
- best_bias = max(results.keys(), key=lambda b: results[b])
50
- print(f"Best bias: {best_bias} with overall={results[best_bias]:.4f}")
 
1
+ """Sweep PP direct residual scale."""
2
  import subprocess
3
  import json
4
  import re
5
 
6
+ predict_path = "/home/coder/experiments/2026-04-12-342000-pp-residual-tune/predict.py"
7
+
8
+ with open(predict_path, 'r') as f:
9
+ original = f.read()
10
 
11
  results = {}
12
+ for scale in [0.96, 0.97, 0.98, 0.99, 1.00, 1.01, 1.02, 1.03, 1.04]:
13
+ content = original
14
+ # Replace PP direct calls (both orig and flipped)
15
+ content = re.sub(
16
+ r'_predict_8frames_direct\(ens\.models\["pole_position"\], context_tensor, last_tensor\)',
17
+ f'_predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale={scale})',
18
+ content
19
+ )
20
+ content = re.sub(
21
+ r'_predict_8frames_direct\(ens\.models\["pole_position"\], context_flipped, last_flipped\)',
22
+ f'_predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale={scale})',
23
+ content
24
+ )
 
 
 
 
 
 
 
 
 
 
25
 
26
  with open(predict_path, 'w') as f:
27
  f.write(content)
28
 
29
  result = subprocess.run(
30
+ ['python', 'task/score.py', '--model_path', '/home/coder/experiments/2026-04-12-342000-pp-residual-tune'],
31
  capture_output=True, text=True, cwd='/home/coder'
32
  )
33
 
34
  for line in result.stdout.strip().split('\n'):
35
  if '"score"' in line:
36
  data = json.loads(line)
37
+ results[scale] = {
38
+ 'score': data['score'],
39
+ 'pp': data['per_game']['pole_position']['ssim']
40
+ }
41
+ print(f"Scale {scale}: overall={data['score']:.4f} pp={data['per_game']['pole_position']['ssim']:.4f}")
42
  break
43
 
44
+ # Restore
45
+ with open(predict_path, 'w') as f:
46
+ f.write(original)
47
+
48
  print("\n=== Summary ===")
49
+ best_scale = max(results.keys(), key=lambda s: results[s]['pp'])
50
+ print(f"Best PP scale: {best_scale} with pp={results[best_scale]['pp']:.4f}, overall={results[best_scale]['score']:.4f}")