"""Restore trained network parameters from an Orbax checkpoint.

The script builds an Impala network, initialises a parameter template from
an example level, restores the parameters saved at step ``CHECKPOINT_NUMBER``
under ``CHECKPOINT_FOLDER``, wraps them in a ``TrainState``, and prints it.
"""

import os

import jax
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp

from jaxgmg.procgen import maze_generation
from jaxgmg.environments import cheese_in_the_corner
from jaxgmg.baselines import networks


# --- Configuration ----------------------------------------------------------

SEED: int = 42

# Checkpoint directory (made absolute before it is handed to Orbax) and the
# step number of the checkpoint to restore.
CHECKPOINT_FOLDER: str = "2ypu90e7"
CHECKPOINT_NUMBER: int = 7168

ENV = cheese_in_the_corner.Env(
    obs_level_of_detail=0,
    img_level_of_detail=1,
    penalize_time=False,
    terminate_after_cheese_and_corner=False,
)

# In the visible part of the script this generator is only used to sample one
# example level, from which the observation type for network init is derived.
_maze_generator_cls = maze_generation.get_generator_class_from_name("blocks")
ARBITRARY_LEVEL_GENERATOR = cheese_in_the_corner.LevelGenerator(
    height=15,
    width=15,
    maze_generator=_maze_generator_cls(),
    corner_size=1,
)

# Architecture settings forwarded to networks.Impala.
# NOTE(review): these presumably have to match the settings the checkpoint
# was trained with -- confirm against the training run.
NET_CNN_TYPE: str = "large"
NET_RNN_TYPE: str = "ff"
NET_WIDTH: int = 256


# --- Random number generators -----------------------------------------------

rng = jax.random.PRNGKey(seed=SEED)
# rng_eval is not consumed in the visible part of the script.
rng_setup, rng_eval = jax.random.split(rng)


# --- Network ----------------------------------------------------------------

net = networks.Impala(
    num_actions=ENV.num_actions,
    cnn_type=NET_CNN_TYPE,
    rnn_type=NET_RNN_TYPE,
    width=NET_WIDTH,
)


# --- Parameter template -----------------------------------------------------

# The freshly initialised parameters serve as the target structure for the
# restore below; their values are replaced by the checkpointed ones.
rng_model_init, rng_setup = jax.random.split(rng_setup)
rng_example_level, rng_setup = jax.random.split(rng_setup)
example_level = ARBITRARY_LEVEL_GENERATOR.sample(rng_example_level)
_example_obs_type = ENV.obs_type(level=example_level)
# net_init_state is not used in the visible part of the script.
net_init_params, net_init_state = net.init_params_and_state(
    rng=rng_model_init,
    obs_type=_example_obs_type,
)


# --- Checkpoint manager -----------------------------------------------------

_checkpoint_dir = os.path.abspath(CHECKPOINT_FOLDER)
_manager_options = ocp.CheckpointManagerOptions(
    max_to_keep=None,
    save_interval_steps=1,
)
# NOTE(review): the manager is never closed in the visible code -- confirm
# whether a later part of the script calls checkpoint_manager.close().
checkpoint_manager = ocp.CheckpointManager(
    directory=_checkpoint_dir,
    options=_manager_options,
)


# --- Restore parameters -----------------------------------------------------

# Restore arguments are derived from the initialised parameter tree, which is
# also passed as the item describing the tree to restore into.
_restore_args = ocp.checkpoint_utils.construct_restore_args(net_init_params)
net_params = checkpoint_manager.restore(
    CHECKPOINT_NUMBER,
    args=ocp.args.PyTreeRestore(
        net_init_params,
        restore_args=_restore_args,
    ),
)


# --- Train state ------------------------------------------------------------

# SGD with a learning rate of zero: presumably a placeholder optimiser, since
# only the restored parameters are of interest here -- confirm.
_optimiser = optax.sgd(learning_rate=0)
train_state = TrainState.create(
    apply_fn=net.apply,
    params=net_params,
    tx=_optimiser,
)


print(train_state)