import os
import pickle
import random
from collections import deque
from pathlib import Path

import einops
import hydra
import numpy as np
import torch
import tqdm
import wandb
from omegaconf import OmegaConf

from datasets.core import TrajectoryEmbeddingDataset, split_traj_datasets
from datasets.vqbet_repro import TrajectorySlicerDataset
from utils.video import VideoRecorder

# Default to EGL so MuJoCo can render headlessly unless the user overrides it.
if "MUJOCO_GL" not in os.environ:
    os.environ["MUJOCO_GL"] = "egl"


def seed_everything(random_seed: int):
    """Seed NumPy, PyTorch (CPU and CUDA), and Python's RNG for reproducibility."""
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)


@hydra.main(config_path="eval_configs", version_base="1.2")
| def main(cfg): |
| print(OmegaConf.to_yaml(cfg)) |
    seed_everything(cfg.seed)

    encoder = hydra.utils.instantiate(cfg.encoder)
    encoder = encoder.to(cfg.device).eval()

    dataset = hydra.utils.instantiate(cfg.dataset)
| train_data, test_data = split_traj_datasets( |
| dataset, |
| train_fraction=cfg.train_fraction, |
| random_seed=cfg.seed, |
| ) |
| use_libero_goal = cfg.data.get("use_libero_goal", False) |
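    # Wrap the raw trajectories so observations (and optionally goals) are
    # served as encoder embeddings rather than raw frames.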
| train_data = TrajectoryEmbeddingDataset( |
| encoder, train_data, device=cfg.device, embed_goal=use_libero_goal |
| ) |
| test_data = TrajectoryEmbeddingDataset( |
| encoder, test_data, device=cfg.device, embed_goal=use_libero_goal |
| ) |
| traj_slicer_kwargs = { |
| "window": cfg.data.window_size, |
| "action_window": cfg.data.action_window_size, |
| "vqbet_get_future_action_chunk": cfg.data.vqbet_get_future_action_chunk, |
| "future_conditional": (cfg.data.goal_conditional == "future"), |
| "min_future_sep": cfg.data.action_window_size, |
| "future_seq_len": cfg.data.future_seq_len, |
| "use_libero_goal": use_libero_goal, |
| } |
| train_data = TrajectorySlicerDataset(train_data, **traj_slicer_kwargs) |
| test_data = TrajectorySlicerDataset(test_data, **traj_slicer_kwargs) |
| train_loader = torch.utils.data.DataLoader( |
| train_data, batch_size=cfg.batch_size, shuffle=True, pin_memory=False |
| ) |
| test_loader = torch.utils.data.DataLoader( |
| test_data, batch_size=cfg.batch_size, shuffle=False, pin_memory=False |
| ) |
| for param in encoder.parameters(): |
| param.requires_grad = False |
    encoder.eval()

    # Behavior-transformer model; the optimizer groups are configured by the model itself.
    cbet_model = hydra.utils.instantiate(cfg.model).to(cfg.device)
| optimizer = cbet_model.configure_optimizers( |
| weight_decay=cfg.optim.weight_decay, |
| learning_rate=cfg.optim.lr, |
| betas=cfg.optim.betas, |
| ) |
| env = hydra.utils.instantiate(cfg.env.gym) |
    if use_libero_goal:
        with torch.no_grad():
            # Cache one goal embedding per task: trajectories are stored
            # contiguously (stride 50 per task), and the last frame of each
            # task's first demo serves as that task's goal.
            goals_cache = []
            for i in range(10):
                idx = i * 50
| last_obs, _, _ = dataset.get_frames(idx, [-1]) |
| last_obs = last_obs.to(cfg.device) |
| embd = encoder(last_obs)[0] |
| embd = einops.rearrange(embd, "V E -> (V E)") |
                goals_cache.append(embd)

        def goal_fn(goal_idx):
            return goals_cache[goal_idx]
| else: |
        empty_tensor = torch.zeros(1)

        def goal_fn(goal_idx):
            return empty_tensor

    run = wandb.init(
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=OmegaConf.to_container(cfg, resolve=True),
    )
| run_name = run.name or "Offline" |
| save_path = Path(cfg.save_path) / run_name |
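    # exist_ok=False: fail fast rather than overwrite a previous run's outputs.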
| save_path.mkdir(parents=True, exist_ok=False) |
    video = VideoRecorder(dir_name=save_path)

    @torch.no_grad()
| def eval_on_env( |
| cfg, |
| num_evals=cfg.num_env_evals, |
| num_eval_per_goal=1, |
| videorecorder=None, |
| epoch=None, |
| ): |
| def embed(enc, obs): |
| obs = ( |
| torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(cfg.device) |
| ) |
| result = enc(obs) |
| result = einops.rearrange(result, "1 V E -> (V E)") |
            return result

        avg_reward = 0
| action_list = [] |
| completion_id_list = [] |
| avg_max_coverage = [] |
| avg_final_coverage = [] |
| env.seed(cfg.seed) |
| for goal_idx in range(num_evals): |
| if videorecorder is not None: |
| videorecorder.init(enabled=True) |
| for i in range(num_eval_per_goal): |
| obs_stack = deque(maxlen=cfg.eval_window_size) |
| this_obs = env.reset(goal_idx=goal_idx) |
| assert ( |
| this_obs.min() >= 0 and this_obs.max() <= 1 |
| ), "expect 0-1 range observation" |
| this_obs_enc = embed(encoder, this_obs) |
| obs_stack.append(this_obs_enc) |
| done, step, total_reward = False, 0, 0 |
| goal = goal_fn(goal_idx) |
| while not done: |
| obs = torch.stack(tuple(obs_stack)).float().to(cfg.device) |
                    goal = torch.as_tensor(goal, dtype=torch.float32, device=cfg.device)
                    # Tile the goal across the observation window.
                    goal = goal.unsqueeze(0).repeat(cfg.eval_window_size, 1)
| action, _, _ = cbet_model(obs.unsqueeze(0), goal.unsqueeze(0), None) |
| action = action[0] |
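                    # Temporal ensembling over overlapping action chunks: average
                    # the cached chunks' predictions for the current step, then
                    # shift each chunk forward one step for the next iteration.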
| if cfg.action_window_size > 1: |
| action_list.append(action[-1].cpu().detach().numpy()) |
| if len(action_list) > cfg.action_window_size: |
| action_list = action_list[1:] |
| curr_action = np.array(action_list) |
| curr_action = ( |
| np.sum(curr_action, axis=0)[0] / curr_action.shape[0] |
| ) |
| new_action_list = [] |
| for a_chunk in action_list: |
| new_action_list.append( |
| np.concatenate( |
| (a_chunk[1:], np.zeros((1, a_chunk.shape[1]))) |
| ) |
| ) |
| action_list = new_action_list |
| else: |
                        curr_action = action[-1, 0, :].cpu().detach().numpy()

                    this_obs, reward, done, info = env.step(curr_action)
| this_obs_enc = embed(encoder, this_obs) |
                    obs_stack.append(this_obs_enc)

                    if videorecorder is not None and videorecorder.enabled:
                        videorecorder.record(info["image"])
| step += 1 |
| total_reward += reward |
| goal = goal_fn(goal_idx) |
| avg_reward += total_reward |
                if cfg.env.gym.id == "pusht":
                    # Bump the env's internal seed so successive rollouts differ.
                    env.env._seed += 1
                    avg_max_coverage.append(info["max_coverage"])
                    avg_final_coverage.append(info["final_coverage"])
                elif cfg.env.gym.id == "blockpush":
                    avg_max_coverage.append(info["moved"])
                    avg_final_coverage.append(info["entered"])
                completion_id_list.append(info["all_completions_ids"])
            if videorecorder is not None:
                videorecorder.save(f"eval_{epoch}_{goal_idx}.mp4")
| return ( |
| avg_reward / (num_evals * num_eval_per_goal), |
| completion_id_list, |
| avg_max_coverage, |
| avg_final_coverage, |
| ) |
|
|
| metrics_history = [] |
| reward_history = [] |
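    # Per epoch: optionally roll out in the env and evaluate on held-out data,
    # then run one training pass over the training set.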
| for epoch in tqdm.trange(cfg.epochs): |
| cbet_model.eval() |
| if epoch % cfg.eval_on_env_freq == 0: |
| avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env( |
| cfg, |
| videorecorder=video, |
| epoch=epoch, |
| num_eval_per_goal=cfg.num_final_eval_per_goal, |
| ) |
| reward_history.append(avg_reward) |
            with open(f"{save_path}/completion_idx_{epoch}.pkl", "wb") as fp:
                pickle.dump(completion_id_list, fp)
| wandb.log({"eval_on_env": avg_reward}) |
| if cfg.env.gym.id in ["pusht", "blockpush"]: |
| metric_final = ( |
| "final coverage" if cfg.env.gym.id == "pusht" else "entered" |
| ) |
| metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved" |
| metrics = { |
| f"{metric_final} mean": sum(final_coverage) / len(final_coverage), |
| f"{metric_final} max": max(final_coverage), |
| f"{metric_final} min": min(final_coverage), |
| f"{metric_max} mean": sum(max_coverage) / len(max_coverage), |
| f"{metric_max} max": max(max_coverage), |
| f"{metric_max} min": min(max_coverage), |
| } |
| wandb.log(metrics) |
                metrics_history.append(metrics)

        if epoch % cfg.eval_freq == 0:
| total_loss = 0 |
| action_diff = 0 |
| action_diff_tot = 0 |
| action_diff_mean_res1 = 0 |
| action_diff_mean_res2 = 0 |
| action_diff_max = 0 |
| with torch.no_grad(): |
| for data in test_loader: |
| obs, act, goal = (x.to(cfg.device) for x in data) |
| assert obs.ndim == 4, "expect N T V E here" |
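                    # Flatten per-view embeddings into one feature vector per timestep.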
| obs = einops.rearrange(obs, "N T V E -> N T (V E)") |
| goal = einops.rearrange(goal, "N T V E -> N T (V E)") |
| predicted_act, loss, loss_dict = cbet_model(obs, goal, act) |
| total_loss += loss.item() |
| wandb.log({"eval/{}".format(x): y for (x, y) in loss_dict.items()}) |
| action_diff += loss_dict["action_diff"] |
| action_diff_tot += loss_dict["action_diff_tot"] |
| action_diff_mean_res1 += loss_dict["action_diff_mean_res1"] |
| action_diff_mean_res2 += loss_dict["action_diff_mean_res2"] |
| action_diff_max += loss_dict["action_diff_max"] |
| print(f"Test loss: {total_loss / len(test_loader)}") |
            wandb.log(
                {
                    "eval/epoch_wise_action_diff": action_diff,
                    "eval/epoch_wise_action_diff_tot": action_diff_tot,
                    "eval/epoch_wise_action_diff_mean_res1": action_diff_mean_res1,
                    "eval/epoch_wise_action_diff_mean_res2": action_diff_mean_res2,
                    "eval/epoch_wise_action_diff_max": action_diff_max,
                }
            )

        cbet_model.train()
| for data in tqdm.tqdm(train_loader): |
| optimizer.zero_grad() |
| obs, act, goal = (x.to(cfg.device) for x in data) |
| obs = einops.rearrange(obs, "N T V E -> N T (V E)") |
| goal = einops.rearrange(goal, "N T V E -> N T (V E)") |
| predicted_act, loss, loss_dict = cbet_model(obs, goal, act) |
| wandb.log({"train/{}".format(x): y for (x, y) in loss_dict.items()}) |
| loss.backward() |
            optimizer.step()

    # Final, larger evaluation pass after training completes.
    avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env(
| cfg, |
| num_evals=cfg.num_final_evals, |
| num_eval_per_goal=cfg.num_final_eval_per_goal, |
| videorecorder=video, |
| epoch=cfg.epochs, |
| ) |
| reward_history.append(avg_reward) |
| if cfg.env.gym.id in ["pusht", "blockpush"]: |
| metric_final = "final coverage" if cfg.env.gym.id == "pusht" else "entered" |
| metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved" |
| metrics = { |
| f"{metric_final} mean": sum(final_coverage) / len(final_coverage), |
| f"{metric_final} max": max(final_coverage), |
| f"{metric_final} min": min(final_coverage), |
| f"{metric_max} mean": sum(max_coverage) / len(max_coverage), |
| f"{metric_max} max": max(max_coverage), |
| f"{metric_max} min": min(max_coverage), |
| } |
| wandb.log(metrics) |
        metrics_history.append(metrics)

    with open(f"{save_path}/completion_idx_final.pkl", "wb") as fp:
        pickle.dump(completion_id_list, fp)
    if cfg.env.gym.id == "pusht":
        final_eval_on_env = max(x["final coverage mean"] for x in metrics_history)
    elif cfg.env.gym.id == "blockpush":
        final_eval_on_env = max(x["entered mean"] for x in metrics_history)
    elif cfg.env.gym.id == "libero_goal":
        final_eval_on_env = max(reward_history)
    else:
        # Covers kitchen-v0 and any other env: report the last rollout's reward.
        final_eval_on_env = avg_reward
| wandb.log({"final_eval_on_env": final_eval_on_env}) |
    return final_eval_on_env


if __name__ == "__main__":
    main()