Project: RL1/RL2 (obsolete)
Collection: older models that are no longer useful for anything in RL1 or RL2, or that are unused because the experiments were discontinued (16 items).
A series of models in which the network was duplicated: all states where the mouse was south-east of the cheese were routed to one copy, and all remaining states were routed to the other. The idea was to see whether each half of the two-headed (Janus) model would learn a different algorithm. This was judged a dead end and the line of work was abandoned.
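The routing rule above can be sketched as follows. This is a minimal illustration, not the code in main_train_janus.py: it assumes positions are (row, col) with rows growing southwards and columns growing eastwards, that "south-east" means strictly south and strictly east of the cheese, and it treats the two network copies as hypothetical callables `se_head` and `other_head`.

```python
from typing import Callable, List, Sequence, Tuple

Pos = Tuple[int, int]  # (row, col); assumption: row grows southwards, col grows eastwards
Policy = Callable[[object], object]  # hypothetical: one copy of the network


def is_south_east(mouse: Pos, cheese: Pos) -> bool:
    # Assumption: "south-east" means strictly south AND strictly east of the cheese.
    return mouse[0] > cheese[0] and mouse[1] > cheese[1]


def janus_route(obs, mouse: Pos, cheese: Pos, se_head: Policy, other_head: Policy):
    """Evaluate the state with se_head if the mouse is south-east of the cheese, else other_head."""
    head = se_head if is_south_east(mouse, cheese) else other_head
    return head(obs)


def split_batch(mice: Sequence[Pos], cheeses: Sequence[Pos]) -> Tuple[List[int], List[int]]:
    """Partition batch indices by which network copy each state is routed to."""
    se, other = [], []
    for i, (m, c) in enumerate(zip(mice, cheeses)):
        (se if is_south_east(m, c) else other).append(i)
    return se, other


if __name__ == "__main__":
    print(janus_route("obs", (5, 7), (2, 3), lambda o: "se", lambda o: "other"))  # se
    print(split_batch([(5, 7), (0, 0), (3, 3)], [(2, 3), (1, 1), (3, 2)]))  # ([0], [1, 2])
```

In a batched JAX setting a vectorised variant would more likely evaluate both copies and select per sample with a boolean mask; the per-state version above just keeps the rule readable.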
Trained on commit 7f1a28d99612f08bee9474306f907c8197c009eb
Wandb: https://wandb.ai/devinterp/jaxgmg_janus
Ran with
wandb sweep sweep.yaml
Contents of sweep.yaml
command:
  - env
  - WANDB_AGENT_MAX_INITIAL_FAILURES=1
  - PYTHONPATH=/root/timaeus/projects/rl
  - /root/timaeus/.venv/bin/python
  - ${program}
  - ${args}
  - --use-wandb
  - --use-hf
  - --no-compile-pad
  - --no-trim-episodes
entity: devinterp
method: grid
parameters:
  alpha:
    value: 1
  cheese-loc:
    value: any
  ckpt-dir:
    value: jaxgmg_janus
  discount-rate:
    values:
      - 0.9
      - 0.87
      - 0.85
      - 0.8
  eval-schedule:
    value: "0:20"
  f-str-ckpt:
    value: al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}
  grad-acc-per-chunk:
    value: 4
  lr:
    value: 5e-05
  num-levels:
    value: 9600
  num-prev-actions:
    value: 1
  num-rollout-steps:
    value: 64
  num-total-env-steps:
    value: 5000000000
  run-id:
    values:
      - 0
      - 1
      - 2
      - 3
      - 4
  seed-formula:
    value: "{int(discount_rate*100):02d}{run_id:02d}"
  wandb-project:
    value: jaxgmg_janus
program: /root/timaeus/projects/rl/janus/main_train_janus.py
project: jaxgmg_janus
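With `method: grid`, the sweep launches one run per combination of the multi-valued `values:` lists, so 4 discount rates times 5 run ids gives 20 runs. The sketch below expands only those two swept parameters; it assumes `${args}` renders parameters as `--key=value` flags, and it does not reproduce the order in which W&B schedules grid points or the fixed `value:` parameters the agent also passes.

```python
import itertools

# Multi-valued parameters copied from sweep.yaml; single "value:" entries
# do not multiply the grid, so they are left out of this sketch.
SWEPT = {
    "discount-rate": [0.9, 0.87, 0.85, 0.8],
    "run-id": [0, 1, 2, 3, 4],
}


def expand_grid(swept):
    """Yield one parameter dict per point of the Cartesian product."""
    keys = list(swept)
    for combo in itertools.product(*(swept[k] for k in keys)):
        yield dict(zip(keys, combo))


def as_cli_args(params):
    """Render params as --key=value flags (assumed shape of ${args})."""
    return [f"--{k}={v}" for k, v in params.items()]


runs = list(expand_grid(SWEPT))
print(len(runs))  # 20
print(as_cli_args(runs[0]))  # ['--discount-rate=0.9', '--run-id=0']
```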
Hyperparams swept
discount_rate = 0.9, 0.87, 0.85, 0.8
run_id = 0, 1, 2, 3, 4
seed = derived from discount_rate and run_id via seed_formula (e.g. discount_rate=0.85, run_id=1 gives seed=8501)
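The seed and checkpoint name can be reproduced from the `seed_formula` and `f_str_ckpt` templates in the config. The helpers below are hypothetical re-implementations, assuming `alpha` is passed as the float 1.0 (as in the config dump) and that `ckpt_path` is `ckpt_dir` joined to the checkpoint name with a slash.

```python
F_STR_CKPT = "al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}"


def derive_seed(discount_rate: float, run_id: int) -> int:
    """Mirror seed_formula = {int(discount_rate*100):02d}{run_id:02d}."""
    return int(f"{int(discount_rate * 100):02d}{run_id:02d}")


def checkpoint_name(alpha: float, discount_rate: float, run_id: int) -> str:
    """Fill f_str_ckpt with the run's hyperparameters and derived seed."""
    seed = derive_seed(discount_rate, run_id)
    return F_STR_CKPT.format(alpha=alpha, discount_rate=discount_rate, run_id=run_id, seed=seed)


def ckpt_path(ckpt_dir: str, alpha: float, discount_rate: float, run_id: int) -> str:
    # Assumption: ckpt_path = "<ckpt_dir>/<checkpoint name>".
    return f"{ckpt_dir}/{checkpoint_name(alpha, discount_rate, run_id)}"


print(derive_seed(0.85, 1))  # 8501
print(ckpt_path("jaxgmg_janus", 1.0, 0.85, 1))  # jaxgmg_janus/al_1.0_g_0.85_id_1_seed_8501
```

Because `int()` truncates, a rate such as 0.57 would yield 56 (0.57 * 100 evaluates to 56.99999999999999); the four swept rates all map cleanly to 90, 87, 85 and 80.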
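Shared Hyperparams Used (dumped from run al_1.0_g_0.85_id_1_seed_8501; checkpoint, seed and ckpt_path are specific to that run)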
rl_action=train
num_rollout_steps=64
lr=5e-05
eff_horizon=None
eval_every=1
use_wandb=True
use_hf=True
use_log=True
num_total_env_steps=5000000000
checkpoint=al_1.0_g_0.85_id_1_seed_8501
render_sixel=False
sixel_idx=60
seed=8501
seed_formula={int(discount_rate*100):02d}{run_id:02d}
mask_type=first_episode
penalize_time=False
optim=adam
live_monitor=False
use_bf16=False
deterministic=True
eval_schedule=0:20
grad_acc_per_chunk=4
num_rollout_chunks=1
cheese_loc=any
env_rule=None
env_layout=open
alpha=1.0
env_size=13
num_levels=9600
f_str_ckpt=al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}
wandb_project=jaxgmg_janus
ckpt_dir=jaxgmg_janus
duplication_factor=-1
smoke=False
compile=True
num_chains=6
num_draws=3000
num_steps_bw_draws=1
on_policy=True
llc_nbeta=3000
localization=10
exact_solver_each_draw=False
llc_optimizer=sgld
iw_clip_eps=None
rmsprop_burnin_steps=20
llc_data_file=llc_scan_open_reinforce.pkl
llc_checkpoint_index=None
llc_checkpoint_number=None
sink=None
repo_id=davidquarel/jaxgmg_ckpt_zip
use_shuffled_checkpoints=False
force_re_download=False
off_distribution_data=False
weight_restrictions=None
weight_restrictions_invert=False
evaluate_every_position=False
num_prev_actions=1
ntfy=None
vis_average_state=False
trim_episodes=False
use_prev_action=False
ckpt_path=jaxgmg_janus/al_1.0_g_0.85_id_1_seed_8501
env_steps_per_loop=614400
total_loops=8138
eff_acc_steps=4
env_steps_per_microbatch=153600
chunk_size=9600
compile_pad=False
compile_bucket_size=1024
resume=None
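The last few entries in the dump appear to be derived from the others. The relations below are inferred from the logged numbers, not taken from the training code: in particular, it is assumed that all `num_levels` levels are rolled out in parallel each loop, and that accumulation steps multiply across rollout chunks.

```python
def derived_quantities(num_levels: int, num_rollout_steps: int, num_total_env_steps: int,
                       grad_acc_per_chunk: int, num_rollout_chunks: int) -> dict:
    # Assumption: each loop collects num_rollout_steps steps in every level.
    env_steps_per_loop = num_levels * num_rollout_steps
    # Assumption: accumulation steps multiply across rollout chunks.
    eff_acc_steps = grad_acc_per_chunk * num_rollout_chunks
    return {
        "env_steps_per_loop": env_steps_per_loop,
        # Assumption: the loop count is floored, so any remainder of the budget is not run.
        "total_loops": num_total_env_steps // env_steps_per_loop,
        "eff_acc_steps": eff_acc_steps,
        "env_steps_per_microbatch": env_steps_per_loop // eff_acc_steps,
        "chunk_size": num_levels // num_rollout_chunks,
    }


print(derived_quantities(9600, 64, 5_000_000_000, 4, 1))
# {'env_steps_per_loop': 614400, 'total_loops': 8138, 'eff_acc_steps': 4,
#  'env_steps_per_microbatch': 153600, 'chunk_size': 9600}
```

These inputs reproduce the five logged values. Since num_rollout_chunks=1, the dump cannot distinguish dividing by grad_acc_per_chunk from dividing by eff_acc_steps; and under floor division, 5000000000 - 8138 * 614400 = 12800 env steps of the budget would go unused.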