al_0.5_g_0.99_x150

~150 RL agent checkpoints trained on the JaxGMG maze environment with alpha=0.5 and discount_rate=0.99. The hyperparameters were chosen to straddle the boundary between runs that do and runs that do not enter phase 2 before converging to the optimal policy (phase 3).

The models were collected from several training campaigns that ran for different numbers of environment steps:

Seeds      Env steps  Source campaign                       Config YAML
100-104    10B        jaxgmg2_3phase_fast_tight             train_10B.yaml
105-150    5B         jaxgmg2_3phase_fast_tight (extended)  train_5B.yaml
200-299    10B        jaxgmg2_phase2_edge_x100              train_10B.yaml
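The table can be read as a lookup from seed to campaign, which is handy when deciding which YAML reproduces a given checkpoint. A minimal sketch, assuming the seed ranges are inclusive on both ends; the `Campaign` type and `campaign_for_seed` helper are hypothetical and not part of the timaeus tooling:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Campaign:
    """One row of the seed table (hypothetical helper type)."""
    first_seed: int
    last_seed: int  # inclusive
    env_steps: str
    source: str
    config_yaml: str


# Rows copied from the seed table.
CAMPAIGNS = [
    Campaign(100, 104, "10B", "jaxgmg2_3phase_fast_tight", "train_10B.yaml"),
    Campaign(105, 150, "5B", "jaxgmg2_3phase_fast_tight (extended)", "train_5B.yaml"),
    Campaign(200, 299, "10B", "jaxgmg2_phase2_edge_x100", "train_10B.yaml"),
]


def campaign_for_seed(seed: int) -> Campaign:
    """Return the campaign whose inclusive seed range contains `seed`."""
    for campaign in CAMPAIGNS:
        if campaign.first_seed <= seed <= campaign.last_seed:
            return campaign
    raise KeyError(f"seed {seed} is not in any listed campaign")


print(campaign_for_seed(120).config_yaml)  # train_5B.yaml
```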

Seeds 148 and 267 are excluded because their training runs showed massive loss spikes. Their checkpoints are renamed BAD_al_0.5_g_0.99_seed_148 and BAD_al_0.5_g_0.99_seed_267.
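The seed ranges cover 151 trained seeds, so dropping the two excluded runs leaves 149 usable checkpoints, consistent with the "~150" figure. A small sketch that enumerates them, assuming the ranges are inclusive and that every non-excluded checkpoint follows the al_0.5_g_0.99_seed_{seed} naming schema; `all_seeds` and `checkpoint_name` are hypothetical helpers:

```python
# Seed ranges from the seed table (inclusive on both ends).
SEED_RANGES = [(100, 104), (105, 150), (200, 299)]
# Seeds dropped for massive loss spikes; their checkpoints carry a BAD_ prefix.
EXCLUDED_SEEDS = {148, 267}


def all_seeds() -> list[int]:
    """Every seed that was trained, in ascending order."""
    return [s for lo, hi in SEED_RANGES for s in range(lo, hi + 1)]


def checkpoint_name(seed: int) -> str:
    """Checkpoint name for a seed, with the BAD_ prefix for excluded runs."""
    name = f"al_0.5_g_0.99_seed_{seed}"
    return f"BAD_{name}" if seed in EXCLUDED_SEEDS else name


usable = [s for s in all_seeds() if s not in EXCLUDED_SEEDS]
print(len(all_seeds()), len(usable))  # 151 149
print(checkpoint_name(148))  # BAD_al_0.5_g_0.99_seed_148
```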

If you want a single YAML covering all models that ignores the 5B/10B env-step discrepancy, use train_all.yaml.

Shared Hyperparams

rl_action=train
alpha=0.5
discount_rate=0.99
lr=5e-05
num_rollout_steps=64
num_levels=9600
cheese_loc=any
env_layout=open
env_size=13
mask_type=first_episode
use_prev_action=False
grad_acc_per_chunk=4
log_optimizer_state=True
eval_schedule=0:1,250:2,500:5,2000:10
f_str_ckpt=al_{alpha}_g_{discount_rate}_seed_{seed}
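The list above is plain key=value text, so it can be loaded into a dictionary for scripting. The sketch below does that and also parses eval_schedule. Note that the meaning of eval_schedule is an assumption, not stated on this card: it is read here as `start:every` pairs, where evaluation happens every `every` units from `start` until the next pair takes over, in whatever unit the trainer counts (also unstated). `parse_hparams`, `parse_eval_schedule` and `eval_points` are hypothetical helpers:

```python
# Shared hyperparameters, copied from the list above.
SHARED_HPARAMS = """\
rl_action=train
alpha=0.5
discount_rate=0.99
lr=5e-05
num_rollout_steps=64
num_levels=9600
cheese_loc=any
env_layout=open
env_size=13
mask_type=first_episode
use_prev_action=False
grad_acc_per_chunk=4
log_optimizer_state=True
eval_schedule=0:1,250:2,500:5,2000:10
f_str_ckpt=al_{alpha}_g_{discount_rate}_seed_{seed}
"""


def parse_hparams(text: str) -> dict[str, str]:
    """Split each non-empty key=value line on the first '='; values stay strings."""
    return dict(line.split("=", 1) for line in text.splitlines() if line.strip())


def parse_eval_schedule(spec: str) -> list[tuple[int, int]]:
    """Parse 'start:every,...' into (start, every) pairs (ASSUMED format)."""
    pairs = []
    for item in spec.split(","):
        start, every = item.split(":")
        pairs.append((int(start), int(every)))
    return pairs


def eval_points(schedule: list[tuple[int, int]], horizon: int) -> list[int]:
    """ASSUMED semantics: from each start, evaluate every `every` units until the next start."""
    ends = [start for start, _ in schedule[1:]] + [horizon]
    points: list[int] = []
    for (start, every), end in zip(schedule, ends):
        points.extend(range(start, min(end, horizon), every))
    return points


hparams = parse_hparams(SHARED_HPARAMS)
schedule = parse_eval_schedule(hparams["eval_schedule"])
print(schedule)  # [(0, 1), (250, 2), (500, 5), (2000, 10)]
print(eval_points(schedule, 260)[-3:])  # [254, 256, 258]
```

Values are kept as strings on purpose, since the card does not say how the trainer types them (for example, whether use_prev_action=False is parsed as a boolean).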

Naming Schema

Checkpoints are named al_0.5_g_0.99_seed_{seed}.
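The name is the f_str_ckpt template with the shared alpha and discount_rate filled in. A sketch in Python, assuming the template uses str.format-style placeholders (which its {alpha}/{discount_rate}/{seed} syntax suggests) and that the BAD_ prefix on excluded seeds is the only variation; `ckpt_name` and `parse_ckpt_name` are hypothetical helpers:

```python
import re

# f_str_ckpt from the shared hyperparameters.
F_STR_CKPT = "al_{alpha}_g_{discount_rate}_seed_{seed}"

# Matches names produced by F_STR_CKPT, plus the BAD_ prefix used for excluded seeds.
CKPT_RE = re.compile(r"^(BAD_)?al_([0-9.]+)_g_([0-9.]+)_seed_(\d+)$")


def ckpt_name(alpha: float, discount_rate: float, seed: int) -> str:
    """Fill the naming template; str.format renders 0.5 as '0.5' and 0.99 as '0.99'."""
    return F_STR_CKPT.format(alpha=alpha, discount_rate=discount_rate, seed=seed)


def parse_ckpt_name(name: str) -> tuple[int, bool]:
    """Return (seed, is_bad) for a checkpoint name, or raise ValueError."""
    match = CKPT_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognised checkpoint name: {name!r}")
    return int(match.group(4)), match.group(1) is not None


print(ckpt_name(0.5, 0.99, 100))  # al_0.5_g_0.99_seed_100
print(parse_ckpt_name("BAD_al_0.5_g_0.99_seed_267"))  # (267, True)
```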

Reproduced with

timaeus run train_5B.yaml
timaeus run train_10B.yaml

Run both commands from the timaeus monorepo.

WandB

https://wandb.ai/devinterp/jaxgmg2_3phase_fast_tight
