|
|
|
|
| import contextlib
|
| import dataclasses
|
| import datetime
|
| import faulthandler
|
| import os
|
| import signal
|
| import time
|
| from moviepy.editor import ImageSequenceClip
|
| import numpy as np
|
| from openpi_client import image_tools
|
| from openpi_client import websocket_client_policy
|
| import pandas as pd
|
| from PIL import Image
|
| from droid.robot_env import RobotEnv
|
| import tqdm
|
| import tyro
|
|
|
# Dump the Python traceback to stderr if the process dies from a fatal signal
# (e.g. a segfault), which ordinary exception handling would never report.
faulthandler.enable()

# Rate, in Hz, at which the rollout loop in main() issues actions: each step is
# padded with a sleep so it lasts at least 1 / DROID_CONTROL_FREQUENCY seconds.
DROID_CONTROL_FREQUENCY = 15
|
|
|
|
|
| @dataclasses.dataclass
|
| class Args:
|
|
|
| left_camera_id: str = "<your_camera_id>"
|
| right_camera_id: str = "<your_camera_id>"
|
| wrist_camera_id: str = "<your_camera_id>"
|
|
|
|
|
| external_camera: str | None = (
|
| None
|
| )
|
|
|
|
|
| max_timesteps: int = 600
|
|
|
|
|
| open_loop_horizon: int = 8
|
|
|
|
|
| remote_host: str = "0.0.0.0"
|
| remote_port: int = (
|
| 8000
|
| )
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Defer Ctrl+C until the protected ``with`` body has finished.

    While the body runs, SIGINT is routed to a handler that only records that
    it fired. On exit the previous handler is reinstalled and, if a SIGINT was
    recorded, ``KeyboardInterrupt`` is raised so the caller still sees it.

    Raises:
        KeyboardInterrupt: On exit, if SIGINT arrived inside the body.
    """
    previous_handler = signal.getsignal(signal.SIGINT)
    # Non-empty once the temporary handler has seen at least one SIGINT.
    deferred_signals = []

    def _remember(signum, frame):
        deferred_signals.append(signum)

    signal.signal(signal.SIGINT, _remember)
    try:
        yield
    finally:
        # Always put the caller's handler back, even if the body raised.
        signal.signal(signal.SIGINT, previous_handler)
        if deferred_signals:
            raise KeyboardInterrupt
|
|
|
|
|
def _parse_success_rating(raw: str) -> float | None:
    """Convert the operator's success answer into a fraction in [0, 1].

    Accepts "y" (1.0), "n" (0.0) or a number between 0 and 100, which is
    interpreted as a percentage. Returns None for anything else, including
    NaN, so the caller can ask again.
    """
    answer = raw.strip().lower()
    if answer == "y":
        return 1.0
    if answer == "n":
        return 0.0
    try:
        percent = float(answer)
    except ValueError:
        return None
    # NaN fails both comparisons, so it is rejected here as well.
    if not 0 <= percent <= 100:
        return None
    return percent / 100


def main(args: Args):
    """Run interactive policy-evaluation rollouts on the robot and save the results.

    For each rollout the operator types an instruction. The loop then reads an
    observation every step, asks the remote policy server for a new action
    chunk whenever the previous one has been consumed for
    ``args.open_loop_horizon`` steps, and executes one action per step at
    ``DROID_CONTROL_FREQUENCY`` Hz until ``args.max_timesteps`` is reached or
    Ctrl+C is pressed. The frames of the selected external camera are written
    to ``video_<timestamp>.mp4`` and the operator is asked to rate the rollout.
    Once the operator declines another rollout, all ratings are written to
    ``results/eval_<timestamp>.csv``.

    Args:
        args: Parsed command-line configuration.

    Raises:
        ValueError: If ``args.external_camera`` is not "left" or "right".
    """
    # Only one external camera is fed to the policy, so the user must choose it.
    # Raised explicitly (rather than asserted) so the check survives `python -O`.
    if args.external_camera not in ("left", "right"):
        raise ValueError(
            f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
        )

    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    # Rows are collected in a plain list and turned into a DataFrame once at
    # the end; `DataFrame.append` no longer exists in pandas >= 2.0.
    result_columns = ["success", "duration", "video_filename"]
    results = []

    while True:
        instruction = input("Enter instruction: ")

        # Position inside the current action chunk; 0 forces a new prediction.
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        t_step = 0  # keeps "duration" defined even if max_timesteps <= 0
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    # Write the camera views to disk once per rollout (first step only).
                    save_to_disk=t_step == 0,
                )
                external_image = curr_obs[f"{args.external_camera}_image"]
                video.append(external_image)

                # Ask the server for a fresh chunk at the start of the rollout and
                # whenever `open_loop_horizon` actions of the current one were used.
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(external_image, 224, 224),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Delay Ctrl+C until the server call has returned so the
                    # request is never abandoned half-way.
                    with prevent_keyboard_interrupt():
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    # Presumably 10 steps x (7 joint velocities + 1 gripper position)
                    # -- TODO confirm against the policy server. Note that
                    # open_loop_horizon must therefore not exceed 10.
                    assert pred_action_chunk.shape == (10, 8)

                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize the last (gripper) dimension at 0.5, then clip the
                # whole action to [-1, 1].
                gripper = np.ones((1,)) if action[-1].item() > 0.5 else np.zeros((1,))
                action = np.concatenate([action[:-1], gripper])
                action = np.clip(action, -1, 1)

                env.step(action)

                # Pad the step so the loop runs at DROID_CONTROL_FREQUENCY.
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        save_filename = "video_" + timestamp
        if video:
            frames = np.stack(video)
            ImageSequenceClip(list(frames), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
        else:
            # Interrupted before the first frame was captured: np.stack([])
            # would raise, and there is nothing to encode anyway.
            print("No frames were recorded; skipping video export.")
            save_filename = None

        # Keep asking until the answer is valid. "y"/"n" map straight to 1.0/0.0
        # and are NOT divided by 100 again; only numeric percentages are scaled.
        success = None
        while success is None:
            raw_answer = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
            )
            success = _parse_success_rating(raw_answer)
            if success is None:
                print(f"Success must be 'y', 'n', or a number in [0, 100] but got: {raw_answer!r}")

        results.append(
            {
                "success": success,
                "duration": t_step,
                "video_filename": save_filename,
            }
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    pd.DataFrame(results, columns=result_columns).to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")
|
|
|
|
|
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    """Pull the camera images and robot state out of a raw environment observation.

    Args:
        args: Configuration providing ``left_camera_id``, ``right_camera_id``
            and ``wrist_camera_id``.
        obs_dict: Observation as returned by the environment. This function
            reads ``obs_dict["image"]`` (image arrays indexed by string key)
            and ``obs_dict["robot_state"]``.
        save_to_disk: If true, the left, wrist and right images are placed
            side by side and written to ``robot_camera_views.png``.

    Returns:
        A dict with ``left_image``, ``right_image``, ``wrist_image``,
        ``cartesian_position``, ``joint_position`` and ``gripper_position``.

    Raises:
        ValueError: If no image key matches one of the configured camera ids.
    """
    image_observations = obs_dict["image"]

    # Output name -> configured camera id. The order reproduces the precedence
    # left > right > wrist that applies when several ids match the same key.
    camera_ids = {
        "left_image": args.left_camera_id,
        "right_image": args.right_camera_id,
        "wrist_image": args.wrist_camera_id,
    }
    images = dict.fromkeys(camera_ids)
    for key in image_observations:
        # Only keys containing "left" are used for every camera -- presumably
        # the left frame of each stereo pair; TODO confirm against the env.
        if "left" not in key:
            continue
        for name, camera_id in camera_ids.items():
            if camera_id in key:
                images[name] = image_observations[key]
                break

    # Fail with an actionable message instead of the opaque
    # "'NoneType' object is not subscriptable" that slicing None would raise.
    missing = [name for name, frame in images.items() if frame is None]
    if missing:
        raise ValueError(
            f"No image found for {missing}; available image keys: {list(image_observations)}. "
            "Check the configured camera ids."
        )

    # Keep the first three channels (presumably dropping alpha) and reverse the
    # channel order (presumably BGR -> RGB) -- TODO confirm the camera format.
    left_image, right_image, wrist_image = (
        images[name][..., :3][..., ::-1] for name in ("left_image", "right_image", "wrist_image")
    )

    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    # Wrapped in a list so the gripper position becomes a 1-element array.
    gripper_position = np.array([robot_state["gripper_position"]])

    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        Image.fromarray(combined_image).save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }
|
|
|
|
|
# Script entry point: build the configuration from the command line and start
# the interactive evaluation loop.
if __name__ == "__main__":
    cli_args: Args = tyro.cli(Args)
    main(cli_args)
|
|
|