"""Utils for evaluating policies in real-world ALOHA environments.""" import os import imageio import numpy as np from PIL import Image from experiments.robot.aloha.real_env import make_real_env from experiments.robot.robot_utils import ( DATE, DATE_TIME, ) def get_next_task_label(task_label): """Prompt the user to input the next task.""" if task_label == "": user_input = "" while user_input == "": user_input = input("Enter the task name: ") task_label = user_input else: user_input = input("Enter the task name (or leave blank to repeat the previous task): ") if user_input == "": pass # Do nothing -> Let task_label be the same else: task_label = user_input print(f"Task: {task_label}") return task_label def get_aloha_env(): """Initializes and returns the ALOHA environment.""" env = make_real_env(init_node=True) return env def resize_image_for_preprocessing(img): """ Takes numpy array corresponding to a single image and resizes to 256x256, exactly as done in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS. """ ALOHA_PREPROCESS_SIZE = 256 img = np.array( Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC) ) # BICUBIC is default; specify explicitly to make it clear return img def get_aloha_image(obs): """Extracts third-person image from observations and preprocesses it.""" # obs: dm_env._environment.TimeStep img = obs.observation["images"]["cam_high"] img = resize_image_for_preprocessing(img) return img def get_aloha_wrist_images(obs): """Extracts both wrist camera images from observations and preprocesses them.""" # obs: dm_env._environment.TimeStep left_wrist_img = obs.observation["images"]["cam_left_wrist"] right_wrist_img = obs.observation["images"]["cam_right_wrist"] left_wrist_img = resize_image_for_preprocessing(left_wrist_img) right_wrist_img = resize_image_for_preprocessing(right_wrist_img) return left_wrist_img, right_wrist_img def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None): """Saves an MP4 replay of an episode.""" rollout_dir = f"./rollouts/{DATE}" os.makedirs(rollout_dir, exist_ok=True) processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}" if notes is not None: filetag += f"--{notes}" mp4_path = f"{filetag}.mp4" video_writer = imageio.get_writer(mp4_path, fps=25) for img in rollout_images: video_writer.append_data(img) video_writer.close() print(f"Saved rollout MP4 at path {mp4_path}") if log_file is not None: log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") return mp4_path