# world-model / sweep.py
# ojaffe's picture
# Upload folder using huggingface_hub
# 9633935 verified
"""Sweep PP direct residual scale."""
import subprocess
import json
import re
# Sweep driver: for each candidate residual scale, patch predict.py in place,
# run the scorer in a subprocess, record the reported scores, and finally
# put the original predict.py back no matter how the sweep ended.
predict_path = "/home/coder/experiments/2026-04-12-342000-pp-residual-tune/predict.py"

with open(predict_path, 'r') as f:
    original = f.read()

# Argument pairs of the two pole_position call sites that get patched
# (the original inputs and the flipped inputs).
call_site_args = [
    ("context_tensor", "last_tensor"),
    ("context_flipped", "last_flipped"),
]
call_sites = [
    f'_predict_8frames_direct(ens.models["pole_position"], {ctx}, {last})'
    for ctx, last in call_site_args
]

# If none of the call sites is present, every trial would score the same
# unmodified file and the "best" scale would be meaningless. Fail before
# touching anything on disk.
if not any(site in original for site in call_sites):
    raise RuntimeError(
        f"No pole_position _predict_8frames_direct call found in {predict_path}; "
        "nothing to sweep."
    )

results = {}
try:
    for scale in [0.96, 0.97, 0.98, 0.99, 1.00, 1.01, 1.02, 1.03, 1.04]:
        # Always start from the pristine text so trials do not stack.
        content = original
        for site in call_sites:
            # Drop the closing parenthesis and append the keyword argument.
            content = content.replace(
                site, f'{site[:-1]}, residual_scale={scale})'
            )

        with open(predict_path, 'w') as f:
            f.write(content)

        result = subprocess.run(
            ['python', 'task/score.py', '--model_path', '/home/coder/experiments/2026-04-12-342000-pp-residual-tune'],
            capture_output=True, text=True, cwd='/home/coder'
        )

        # The scorer is expected to print one JSON line containing "score";
        # use the first line that actually parses.
        for line in result.stdout.strip().split('\n'):
            if '"score"' not in line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # A log line that merely mentions "score"; keep looking.
                continue
            pp = data['per_game']['pole_position']['ssim']
            results[scale] = {
                'score': data['score'],
                'pp': pp
            }
            print(f"Scale {scale}: overall={data['score']:.4f} pp={pp:.4f}")
            break
        else:
            # Surface failed runs instead of silently skipping the scale.
            print(f"Scale {scale}: no score line found (exit code {result.returncode})")
            if result.stderr.strip():
                print(result.stderr.strip())
finally:
    # Restore the original predict.py even if a trial raised or was interrupted.
    with open(predict_path, 'w') as f:
        f.write(original)

print("\n=== Summary ===")
if results:
    best_scale = max(results, key=lambda s: results[s]['pp'])
    print(f"Best PP scale: {best_scale} with pp={results[best_scale]['pp']:.4f}, overall={results[best_scale]['score']:.4f}")
else:
    # max() on an empty dict would raise ValueError; report the situation instead.
    print("No scores were collected.")