haofuly's picture
Add files using upload-large-folder tool
b23769d verified
raw
history blame
3.14 kB
"""Utils for evaluating policies in real-world ALOHA environments."""
import os
import imageio
import numpy as np
from PIL import Image
from experiments.robot.aloha.real_env import make_real_env
from experiments.robot.robot_utils import (
DATE,
DATE_TIME,
)
def get_next_task_label(task_label, *, input_fn=input):
    """Prompt the user for the next task label and return it.

    Args:
        task_label: The previous task label, or "" (or None) if no task has been
            entered yet.
        input_fn: Callable used to read one line of user input. Defaults to the
            built-in ``input``; injectable so the prompt logic can be exercised
            without a terminal.

    Returns:
        The label to use for the next task. If there is no previous label, the
        user is re-prompted until a non-blank name is entered. Otherwise a blank
        reply keeps ``task_label`` unchanged. Replies are stripped of surrounding
        whitespace, so whitespace-only input counts as blank.
    """
    if not task_label:
        user_input = ""
        # Keep asking until we get a real (non-whitespace) task name.
        while not user_input:
            user_input = input_fn("Enter the task name: ").strip()
        task_label = user_input
    else:
        user_input = input_fn("Enter the task name (or leave blank to repeat the previous task): ").strip()
        if user_input:
            task_label = user_input
        # Blank reply -> keep the previous task_label.
    print(f"Task: {task_label}")
    return task_label
def get_aloha_env():
    """Build the real-robot ALOHA environment and hand it back.

    Delegates to ``make_real_env`` with ``init_node=True``.
    """
    return make_real_env(init_node=True)
def resize_image_for_preprocessing(img):
    """
    Resize a single image (numpy array) to 256x256 using bicubic resampling.

    This reproduces the resize done by the ALOHA data preprocessing script,
    which runs before the dataset is converted to RLDS. The result is returned
    as a numpy array.
    """
    target_side = 256  # ALOHA preprocessing resolution (square)
    pil_image = Image.fromarray(img)
    # BICUBIC is Pillow's default here; it is passed explicitly to make the choice obvious.
    resized_image = pil_image.resize((target_side, target_side), resample=Image.BICUBIC)
    return np.array(resized_image)
def get_aloha_image(obs):
    """Return the preprocessed third-person ("cam_high") image from an observation.

    ``obs`` is a ``dm_env._environment.TimeStep``. Its ``observation["images"]``
    is indexed by camera name, and the selected image is passed through
    ``resize_image_for_preprocessing``.
    """
    camera_frames = obs.observation["images"]
    return resize_image_for_preprocessing(camera_frames["cam_high"])
def get_aloha_wrist_images(obs):
    """Return the preprocessed (left, right) wrist-camera images from an observation.

    ``obs`` is a ``dm_env._environment.TimeStep``. The "cam_left_wrist" and
    "cam_right_wrist" entries of ``observation["images"]`` are each passed
    through ``resize_image_for_preprocessing``. They are returned as a 2-tuple
    in that order.
    """
    camera_frames = obs.observation["images"]
    return tuple(
        resize_image_for_preprocessing(camera_frames[camera_name])
        for camera_name in ("cam_left_wrist", "cam_right_wrist")
    )
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None):
    """Save an MP4 replay of an episode under ``./rollouts/<DATE>/``.

    Args:
        rollout_images: Iterable of frames, each passed to the video writer's
            ``append_data`` (presumably HxWx3 uint8 arrays -- confirm with callers).
        idx: Episode index, embedded in the filename.
        success: Episode outcome, embedded in the filename.
        task_description: Task text. It is lower-cased, has spaces, newlines,
            dots and path separators replaced by "_", and is truncated to 50
            characters before being used in the filename.
        log_file: Optional writable text file; the saved path is also logged there.
        notes: Optional extra tag appended to the filename.

    Returns:
        The path of the written MP4 file.
    """
    rollout_dir = f"./rollouts/{DATE}"
    os.makedirs(rollout_dir, exist_ok=True)
    # Spaces/newlines/dots are replaced for readability. Path separators must be
    # replaced too, otherwise a task like "cup/tray" points into a
    # nonexistent subdirectory and the write fails.
    to_underscore = str.maketrans({ch: "_" for ch in " \n./\\"})
    processed_task_description = task_description.lower().translate(to_underscore)[:50]
    filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}"
    if notes is not None:
        filetag += f"--{notes}"
    mp4_path = f"{filetag}.mp4"
    # The context manager closes the writer even if appending a frame raises.
    with imageio.get_writer(mp4_path, fps=25) as video_writer:
        for img in rollout_images:
            video_writer.append_data(img)
    print(f"Saved rollout MP4 at path {mp4_path}")
    if log_file is not None:
        log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
    return mp4_path