Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +4 -4
- sweep.py +30 -30
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -106,11 +106,11 @@ def load_model(model_dir: str):
|
|
| 106 |
return ens
|
| 107 |
|
| 108 |
|
| 109 |
-
def _predict_8frames_direct(model, context_tensor, last_tensor):
|
| 110 |
output = model(context_tensor)
|
| 111 |
residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
|
| 112 |
last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
|
| 113 |
-
return torch.clamp(last_expanded + residuals, 0, 1)
|
| 114 |
|
| 115 |
|
| 116 |
def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
|
|
@@ -263,10 +263,10 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 263 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 264 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 265 |
|
| 266 |
-
predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
|
| 267 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 268 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 269 |
-
predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
|
| 270 |
predicted_flipped = torch.flip(predicted_flipped, dims=[4])
|
| 271 |
predicted = (predicted_orig + predicted_flipped) / 2.0
|
| 272 |
|
|
|
|
| 106 |
return ens
|
| 107 |
|
| 108 |
|
| 109 |
+
def _predict_8frames_direct(model, context_tensor, last_tensor, residual_scale=1.0):
    """Predict all PRED_FRAMES future frames in a single forward pass.

    The model emits per-frame residuals relative to the last observed
    frame; the prediction is ``last + residual_scale * residual``,
    clamped to the valid [0, 1] pixel range.

    Args:
        model: callable mapping context_tensor to a flat residual tensor
            reshapeable to (1, PRED_FRAMES, 3, 64, 64).
        context_tensor: stacked context frames fed to the model
            (shape assumed model-compatible — not validated here).
        last_tensor: last observed frame, broadcast across the
            PRED_FRAMES time axis via expand.
        residual_scale: multiplier on the predicted residuals;
            1.0 reproduces the original unscaled behaviour.

    Returns:
        Tensor of shape (1, PRED_FRAMES, 3, 64, 64), values in [0, 1].
    """
    residuals = model(context_tensor).reshape(1, PRED_FRAMES, 3, 64, 64)
    base = last_tensor.unsqueeze(1).expand_as(residuals)
    return (base + residual_scale * residuals).clamp(0, 1)
|
| 114 |
|
| 115 |
|
| 116 |
def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
|
|
|
|
| 263 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 264 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 265 |
|
| 266 |
+
predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale=0.97)
|
| 267 |
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 268 |
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 269 |
+
predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale=0.97)
|
| 270 |
predicted_flipped = torch.flip(predicted_flipped, dims=[4])
|
| 271 |
predicted = (predicted_orig + predicted_flipped) / 2.0
|
| 272 |
|
sweep.py
CHANGED
|
@@ -1,50 +1,50 @@
|
|
| 1 |
-
"""Sweep
|
| 2 |
import subprocess
|
| 3 |
import json
|
| 4 |
import re
|
| 5 |
|
| 6 |
-
predict_path = "/home/coder/experiments/2026-04-12-
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
results = {}
|
| 9 |
-
for
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
r'np\.round\(frame \* 255 \+ [\d.]+\)',
|
| 23 |
-
f'np.round(frame * 255 + {bias})',
|
| 24 |
-
content
|
| 25 |
-
)
|
| 26 |
-
# Then handle the case where there's no bias (from bias=0.0 step)
|
| 27 |
-
content = re.sub(
|
| 28 |
-
r'np\.round\(frame \* 255\)\.clip',
|
| 29 |
-
f'np.round(frame * 255 + {bias}).clip',
|
| 30 |
-
content
|
| 31 |
-
)
|
| 32 |
|
| 33 |
with open(predict_path, 'w') as f:
|
| 34 |
f.write(content)
|
| 35 |
|
| 36 |
result = subprocess.run(
|
| 37 |
-
['python', 'task/score.py', '--model_path', '/home/coder/experiments/2026-04-12-
|
| 38 |
capture_output=True, text=True, cwd='/home/coder'
|
| 39 |
)
|
| 40 |
|
| 41 |
for line in result.stdout.strip().split('\n'):
|
| 42 |
if '"score"' in line:
|
| 43 |
data = json.loads(line)
|
| 44 |
-
results[
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
break
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
print("\n=== Summary ===")
|
| 49 |
-
|
| 50 |
-
print(f"Best
|
|
|
|
| 1 |
+
"""Sweep PP direct residual scale."""
|
| 2 |
import subprocess
|
| 3 |
import json
|
| 4 |
import re
|
| 5 |
|
| 6 |
+
predict_path = "/home/coder/experiments/2026-04-12-342000-pp-residual-tune/predict.py"
|
| 7 |
+
|
| 8 |
+
with open(predict_path, 'r') as f:
|
| 9 |
+
original = f.read()
|
| 10 |
|
| 11 |
results = {}
|
| 12 |
+
for scale in [0.96, 0.97, 0.98, 0.99, 1.00, 1.01, 1.02, 1.03, 1.04]:
|
| 13 |
+
content = original
|
| 14 |
+
# Replace PP direct calls (both orig and flipped)
|
| 15 |
+
content = re.sub(
|
| 16 |
+
r'_predict_8frames_direct\(ens\.models\["pole_position"\], context_tensor, last_tensor\)',
|
| 17 |
+
f'_predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale={scale})',
|
| 18 |
+
content
|
| 19 |
+
)
|
| 20 |
+
content = re.sub(
|
| 21 |
+
r'_predict_8frames_direct\(ens\.models\["pole_position"\], context_flipped, last_flipped\)',
|
| 22 |
+
f'_predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale={scale})',
|
| 23 |
+
content
|
| 24 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
with open(predict_path, 'w') as f:
|
| 27 |
f.write(content)
|
| 28 |
|
| 29 |
result = subprocess.run(
|
| 30 |
+
['python', 'task/score.py', '--model_path', '/home/coder/experiments/2026-04-12-342000-pp-residual-tune'],
|
| 31 |
capture_output=True, text=True, cwd='/home/coder'
|
| 32 |
)
|
| 33 |
|
| 34 |
for line in result.stdout.strip().split('\n'):
|
| 35 |
if '"score"' in line:
|
| 36 |
data = json.loads(line)
|
| 37 |
+
results[scale] = {
|
| 38 |
+
'score': data['score'],
|
| 39 |
+
'pp': data['per_game']['pole_position']['ssim']
|
| 40 |
+
}
|
| 41 |
+
print(f"Scale {scale}: overall={data['score']:.4f} pp={data['per_game']['pole_position']['ssim']:.4f}")
|
| 42 |
break
|
| 43 |
|
| 44 |
+
# Restore
|
| 45 |
+
with open(predict_path, 'w') as f:
|
| 46 |
+
f.write(original)
|
| 47 |
+
|
| 48 |
print("\n=== Summary ===")
|
| 49 |
+
best_scale = max(results.keys(), key=lambda s: results[s]['pp'])
|
| 50 |
+
print(f"Best PP scale: {best_scale} with pp={results[best_scale]['pp']:.4f}, overall={results[best_scale]['score']:.4f}")
|