| """ |
| Author: Minh Pham-Dinh |
| Created: Feb 4th, 2024 |
| Last Modified: Feb 6th, 2024 |
| Email: mhpham26@colby.edu |
| |
| Description: |
| Imagination file. Run this file to generate dream sequences |
| """ |
|
|
| import sys |
| import argparse |
| from utils.wrappers import DMCtoGymWrapper, AtariPreprocess |
| from addict import Dict |
| import yaml |
| import gymnasium as gym |
| import torch |
| from tqdm import tqdm |
| import numpy as np |
| import glob |
|
|
# Command-line interface: the run directory to read config/models from and
# the number of imagination steps to roll out.
parser = argparse.ArgumentParser(description='Process configuration file path.')
parser.add_argument(
    '--runpath',
    type=str,
    required=True,
    help='Path to the run file.',
)
parser.add_argument(
    '--horizon',
    type=int,
    default=15,
    help='number of imagination steps.',
)

args = parser.parse_args()

# Module-level aliases consumed by the rest of the script.
run_path, HORIZON = args.runpath, args.horizon
|
|
# Locate the single YAML config saved alongside the run. glob.escape keeps
# run paths that contain glob metacharacters (e.g. '[' or '*') working.
config_files = glob.glob(glob.escape(run_path) + '/config/*.yml')

# Exactly one config is expected. With none, the indexing below would fail
# with an opaque IndexError; with several, glob order would pick one
# arbitrarily. Stop with a clear message instead of continuing.
if len(config_files) != 1:
    sys.exit(
        'there should only be 1 config file in config directory '
        f'(found {len(config_files)} in {run_path}/config)'
    )

# NOTE: yaml.FullLoader is kept for compatibility with the run's own config
# files; prefer yaml.safe_load if run directories may come from untrusted
# sources.
with open(config_files[0], 'r', encoding='utf-8') as file:
    config = Dict(yaml.load(file, Loader=yaml.FullLoader))
|
|
# Build the environment the warm-up observations come from. Ids containing
# 'ALE' are created through gymnasium and wrapped in AtariPreprocess; any
# other id is handed to DMCtoGymWrapper together with config.env.task.
env_id = config.env.env_id

if 'ALE' not in env_id:
    task = config.env.task
    env = DMCtoGymWrapper(
        env_id,
        task,
        resize=config.env.new_obs_size,
        record=False,
    )
else:
    raw_env = gym.make(env_id, render_mode='rgb_array')
    env = AtariPreprocess(raw_env, config.env.new_obs_size, False)
|
|
print("start imagining")

_CPU = torch.device('cpu')


def _load_model(name):
    """Load the saved model ``<run_path>/models/<name>`` onto the CPU.

    The checkpoints are loaded as complete callable objects (not state
    dicts), so ``weights_only=False`` is needed on torch >= 2.6, where the
    default became True. This unpickles arbitrary code: only load run
    directories you trust.

    The model is switched to eval mode because this script only runs
    inference (disables dropout / batch-norm statistic updates, if any).
    """
    model = torch.load(run_path + '/models/' + name, map_location=_CPU,
                       weights_only=False)
    model.eval()
    return model


encode = _load_model('encoder')
decoder = _load_model('decoder')
rssm = _load_model('rssm_model')
actor = _load_model('actor')
|
|
# Advance the real environment with random actions so the imagined rollout
# starts from the observation reached after WARMUP_STEPS steps rather than
# from the reset frame.
WARMUP_STEPS = 100

obs, _ = env.reset()
for _ in range(WARMUP_STEPS):
    obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
    # gymnasium requires reset() once an episode has ended; stepping a
    # finished episode is undefined, so start a fresh one instead.
    if terminated or truncated:
        obs, _ = env.reset()

# Initial latent state: a zero deterministic state, and the posterior that
# the representation model returns for the encoded last observation.
deterministic = torch.zeros((1, config.main.deterministic_size))
with torch.no_grad():  # inference only; no autograd graph needed
    e_obs = encode(torch.from_numpy(obs).to(dtype=torch.float))
    _, posterior = rssm.representation(e_obs, deterministic)
| |
from PIL import Image

# Roll the world model forward in latent space for HORIZON steps (the
# --horizon option). Each step:
#   1. the actor picks an action from the current latent state;
#   2. rssm.recurrent advances the deterministic state;
#   3. rssm.transition yields the next stochastic state from it;
#   4. the decoder renders the latent state back to an image.
frames = []

with torch.no_grad():  # inference only; no autograd graph needed
    for _ in tqdm(range(HORIZON)):
        actions = actor(posterior, deterministic)
        deterministic = rssm.recurrent(posterior, actions, deterministic)
        _, posterior = rssm.transition(deterministic)
        decoded = decoder(posterior, deterministic)
        # Use the mean of the decoder's output as the image. Assumes a single
        # CHW image with pixels scaled to [-0.5, 0.5] -- TODO confirm
        # against the training-time preprocessing.
        chw = decoded.mean.squeeze().detach().numpy()
        rgb = ((chw.transpose([1, 2, 0]) + 0.5) * 255).clip(0, 255).astype(np.uint8)
        frames.append(Image.fromarray(rgb, "RGB"))
|
|
print("saving gif")

# Nothing was imagined (e.g. --horizon 0): indexing frames[0] would crash.
if not frames:
    sys.exit('no frames were imagined; nothing to save (is --horizon > 0?)')

# Pillow writes the calling image as the first frame and then appends
# ``append_images`` after it, so only the remaining frames are passed;
# passing the whole list would show the first frame twice.
first_frame, *rest_frames = frames
first_frame.save(run_path + "/imagine.gif", format="GIF",
                 append_images=rest_frames, save_all=True, duration=30, loop=0)
print("finished")