# jaxgmg2_3phase_unique / train.yaml
# Commit 77cc104 (dquarel): Update train.yaml with full
# alpha/discount_rate grid, add train_missing.yaml
# Training sweep configuration for the jaxgmg2_3phase_unique project.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped. It assumes `parameters` and `sweep` are
# sibling top-level keys -- confirm against the sweep launcher's schema.

# Settings listed once (presumably shared by every run in the sweep).
parameters:
  project_name: jaxgmg2_3phase_unique
  action: rl
  rl_action: train
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as PyYAML
  # resolve a bare `5e-5` as the string "5e-5" rather than a float.
  lr: 5.0e-5
  cheese_loc: any
  env_layout: open
  mask_type: first_episode
  use_prev_action: false
  log_optimizer_state: false
  # Underscore digit separators are YAML 1.1 integer syntax; a strict
  # YAML 1.2 core-schema loader would return this value as a string.
  num_total_env_steps: 10_000_000_000
  num_levels: 9600
  grad_acc_per_chunk: 5
  num_rollout_steps: 64
  # Template over the swept variables discount_rate, alpha and run_id.
  # If rendered as a Python format expression, discount_rate=0.97,
  # alpha=0.4, run_id=3 gives "970403".
  # NOTE(review): int() truncates, so a product that lands just below an
  # integer (e.g. 0.29 * 100) would shift that field by one -- re-check
  # before adding new sweep values.
  seed_formula: "{int(discount_rate*100):02d}{int(alpha*10):02d}{run_id:02d}"
  ckpt_dir: jaxgmg2_3phase_unique
  # Template over alpha, discount_rate, run_id and seed.
  f_str_ckpt: "al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}"
  eval_schedule: "0:1,250:2,500:5,2000:10"
  wandb_project: jaxgmg2_3phase_unique
  use_wandb: true
  use_hf: true
  no_tqdm: true
  ntfy: david_jaxgmg

# One inner list per swept variable; each entry is a single-key mapping
# giving one candidate value (5 alpha x 3 discount_rate x 15 run_id).
# NOTE(review): presumably expanded as a full cartesian product
# (225 runs) -- confirm with the launcher.
sweep:
  - - alpha: 0.4
    - alpha: 0.5
    - alpha: 0.6
    - alpha: 0.7
    - alpha: 1.0
  - - discount_rate: 0.97
    - discount_rate: 0.98
    - discount_rate: 0.99
  - - run_id: 0
    - run_id: 1
    - run_id: 2
    - run_id: 3
    - run_id: 4
    - run_id: 5
    - run_id: 6
    - run_id: 7
    - run_id: 8
    - run_id: 9
    - run_id: 10
    - run_id: 11
    - run_id: 12
    - run_id: 13
    - run_id: 14