| """Utils for evaluating policies in real-world ALOHA environments."""
|
|
|
| import os
|
|
|
| import imageio
|
| import numpy as np
|
| from PIL import Image
|
|
|
| from experiments.robot.aloha.real_env import make_real_env
|
| from experiments.robot.robot_utils import (
|
| DATE,
|
| DATE_TIME,
|
| )
|
|
|
|
|
def get_next_task_label(task_label):
    """Ask the operator on stdin which task to run next.

    Args:
        task_label: Label of the previous task, or "" when there is none yet.

    Returns:
        The label entered by the operator. When a previous label exists and
        the operator submits a blank line, the previous label is returned
        unchanged.
    """
    if task_label != "":
        # A previous task is available, so a blank entry means "repeat it".
        entered = input("Enter the task name (or leave blank to repeat the previous task): ")
        if entered != "":
            task_label = entered
    else:
        # No previous task to fall back on: keep asking until something is typed.
        entered = input("Enter the task name: ")
        while entered == "":
            entered = input("Enter the task name: ")
        task_label = entered

    print(f"Task: {task_label}")
    return task_label
|
|
|
|
|
def get_aloha_env():
    """Build the real-world ALOHA environment and hand it back to the caller.

    Returns:
        Whatever `make_real_env` returns when called with `init_node=True`.
    """
    return make_real_env(init_node=True)
|
|
|
|
|
def resize_image_for_preprocessing(img):
    """Resize a single image array to 256x256 with bicubic resampling.

    Per the original author's note, this mirrors the resize performed by the
    ALOHA data preprocessing script that runs before the dataset is converted
    to RLDS.

    Args:
        img: Numpy array holding one image, in a layout accepted by
            `PIL.Image.fromarray`.

    Returns:
        The resized image as a numpy array.
    """
    ALOHA_PREPROCESS_SIZE = 256
    target_size = (ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE)

    # Round-trip through PIL so the resampling filter is PIL's bicubic one.
    pil_image = Image.fromarray(img)
    resized = pil_image.resize(target_size, resample=Image.BICUBIC)
    return np.array(resized)
|
|
|
|
|
def get_aloha_image(obs):
    """Fetch the "cam_high" camera frame from an observation and resize it.

    The original author describes this camera as the third-person view.

    Args:
        obs: Object exposing `obs.observation["images"]["cam_high"]`.

    Returns:
        The frame after `resize_image_for_preprocessing` has been applied.
    """
    camera_frames = obs.observation["images"]
    return resize_image_for_preprocessing(camera_frames["cam_high"])
|
|
|
|
|
def get_aloha_wrist_images(obs):
    """Fetch both wrist camera frames from an observation and resize them.

    Args:
        obs: Object exposing `obs.observation["images"]` with the keys
            "cam_left_wrist" and "cam_right_wrist".

    Returns:
        A `(left, right)` tuple of frames, each passed through
        `resize_image_for_preprocessing`.
    """
    # Read both raw frames first so a missing camera fails before any resizing.
    left_raw = obs.observation["images"]["cam_left_wrist"]
    right_raw = obs.observation["images"]["cam_right_wrist"]

    left_resized = resize_image_for_preprocessing(left_raw)
    right_resized = resize_image_for_preprocessing(right_raw)
    return left_resized, right_resized
|
|
|
|
|
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None):
    """Save an MP4 replay of an episode under ./rollouts/<DATE>/.

    Args:
        rollout_images: Iterable of frames, each passed to the imageio
            writer's `append_data`.
        idx: Episode index, embedded in the filename.
        success: Episode outcome, embedded in the filename.
        task_description: Task text; lowercased, sanitized and truncated to
            50 characters before being embedded in the filename.
        log_file: Optional object with a `write` method that receives the
            same "saved" message that is printed.
        notes: Optional suffix appended verbatim to the filename.

    Returns:
        The path of the written MP4 file.
    """
    rollout_dir = f"./rollouts/{DATE}"
    os.makedirs(rollout_dir, exist_ok=True)

    # Spaces, newlines and periods become underscores, as before. "/" is
    # replaced too: left in place it would be read as a directory separator
    # and point the writer at a directory that does not exist.
    unsafe_chars = str.maketrans({" ": "_", "\n": "_", ".": "_", "/": "_"})
    processed_task_description = task_description.lower().translate(unsafe_chars)[:50]

    filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}"
    if notes is not None:
        filetag += f"--{notes}"
    mp4_path = f"{filetag}.mp4"

    # The context manager closes the writer even if a frame fails to encode,
    # so the underlying encoder/file handle is not leaked on error.
    with imageio.get_writer(mp4_path, fps=25) as video_writer:
        for img in rollout_images:
            video_writer.append_data(img)

    print(f"Saved rollout MP4 at path {mp4_path}")
    if log_file is not None:
        log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
    return mp4_path
|
|
|