| import dm_env |
| from absl import logging |
|
|
| import rclpy |
| from sensor_msgs.msg import Image, JointState |
| from std_msgs.msg import Bool |
| from std_msgs.msg import Int32 |
| import numpy as np |
| import threading |
| import time |
| |
| import random |
| from scipy.spatial.transform import Rotation |
| from glob import glob |
| import os |
| import h5py |
| import cv2 |
|
|
class AnubisRobotEnv:
    """ROS 2-backed bimanual evaluation environment for the Anubis robot.

    Subscribes to three colour-camera streams, ``/eef_pose`` and a ``/done``
    toggle; publishes end-effector targets on ``/teleop/eef_pose`` and the
    current rollout index on ``/demo``.

    Subclasses implement ``bringup_model``, ``inference`` (called once per
    control tick while a rollout is active) and ``_observation`` (used by
    ``reset``). ``initialize`` reads ``self.model_name``, which nothing in this
    class sets, so ``bringup_model`` presumably provides it.

    Pose/action layout used throughout (20 values, two 10-value groups):
    ``[xyz(3), rot6d(6), extra(1)]`` for each group, where rot6d is the first
    two columns of the rotation matrix flattened row-major. The ROS side uses
    16 values: ``[xyz(3), quat_xyzw(4), extra(1)]`` per group.
    """

    def __init__(self, hz=20, max_timestep=1000, task_name='', num_rollout=1,
                 demo_root='/home/rllab/workspace/jellyho/demo_collection'):
        """Validate the task, start ROS 2 and the worker threads, then load the model.

        Args:
            hz: Frequency in Hz at which ``timer_callback`` runs.
            max_timestep: Control ticks after which an active rollout is ended.
            task_name: Key of ``lang_dict``; selects instruction and demo folder.
            num_rollout: Stored as ``self.num_rollout``; not read by this class.
            demo_root: Directory holding one sub-folder of ``*.hdf5`` demos per task.

        Raises:
            KeyError: ``task_name`` is not one of the known tasks.
            FileNotFoundError: no ``*.hdf5`` demo exists for the task.
        """
        # Validate before touching ROS or starting threads: the PeriodicThreads
        # are non-daemon, so raising after they start would hang the process.
        self.lang_dict = {
            'anubis_brush_to_pan': 'insert the brush to the dustpan',
            'anubis_carrot_to_bag': 'pick up the carrot and put into the bag',
            'anubis_towel_kirby': 'take the towel off the kirby doll',
        }
        self.task_name = task_name
        if task_name not in self.lang_dict:
            raise KeyError(
                f'Unknown task_name {task_name!r}; '
                f'expected one of {sorted(self.lang_dict)}'
            )
        self.instruction = self.lang_dict[task_name]
        self.demo_root = demo_root
        demo_dir = os.path.join(demo_root, task_name)
        # NOTE(review): glob order is filesystem-dependent. It is deliberately
        # left unsorted in case whatever consumes /demo indexes the same listing.
        self.data_list = glob(os.path.join(demo_dir, '*.hdf5'))
        if not self.data_list:
            raise FileNotFoundError(f'No *.hdf5 demos found in {demo_dir}')

        rclpy.init()
        self._node = rclpy.create_node('anubis_robot_env_node')
        self._subscriber_bringup()
        print('ROS2 node created')

        # Rollout bookkeeping shared by ROS callbacks and the control thread.
        self.window = None
        self.start = False
        self.thread_done = False
        self.hz = hz
        self.action_counter = 0
        self.num_rollout = num_rollout
        self.rollout_counter = 0
        # Defined before the control thread starts so it is never missing.
        self.curr_timestep = 0
        self.overlay_img = None
        self.max_timestep = max_timestep

        # 16-value JointState; not used within this class (presumably read by
        # subclasses).
        self.init_action = JointState()
        self.init_action.position = [
            0.20620185010895048,
            0.16183641523267392,
            0.2277105000367078,
            -0.42093861525667453,
            0.6546518510233503,
            -0.5770953981378887,
            0.24739146627474096,
            -1.6,
            0.21136149716403216,
            -0.16027684481842075,
            0.21879985782478842,
            0.6606782591766969,
            -0.428768621033297,
            0.2340722378552696,
            -0.569975345900049,
            -1.6,
        ]

        print('Initializing Anubis Robot Environment')

        self.thread = PeriodicThread(1 / self.hz, self.timer_callback)
        self.thread.start()

        self.video_thread = PeriodicThread(1 / 30, self.video_timer_callback)
        self.video_thread.start()

        # rclpy.spin blocks, so it runs on a daemon thread to deliver callbacks.
        self.timer_thread = threading.Thread(target=rclpy.spin, args=(self._node,), daemon=True)
        self.timer_thread.start()
        print('Threads started')

        self.bringup_model()
        self.initialize()
        logging.set_verbosity(logging.INFO)
        logging.info('AnubisRobotEnv successfully initialized.')

    def init_robot_pose(self, demo):
        """Publish the first recorded eef action of demo ``demo % len(data_list)``."""
        index = demo % len(self.data_list)
        print('Initializing robot pose', index)
        # Slicing an h5py dataset returns an in-memory array, so the file can
        # be closed before publishing.
        with h5py.File(self.data_list[index], 'r') as root:
            first_action = root['action']['eef_pose'][0]
        self.publish_action(first_action)

    def initialize(self):
        """Prepare the next rollout: reset the step count, set up video, announce the demo, move the robot."""
        self.curr_timestep = 0
        if self.window is None:
            # Project-local helper, imported lazily on first use.
            from visualize_utils import window
            self.window = window('ENV Observation', video_path=f'{self.model_name}-{self.task_name}', video_fps=30, video_size=(640, 480), show=False)
        else:
            self.window.init_video()
        self.send_demo(self.rollout_counter)
        self.init_robot_pose(self.rollout_counter)

    def reset(self):
        """Wait for the control thread to finish a tick, then return ``dm_env.restart``."""
        # timer_callback sets thread_done at the end of every tick.
        while not self.thread_done:
            time.sleep(0.01)
        self.thread_done = False
        return dm_env.restart(observation=self._observation())

    def bringup_model(self):
        """Load the policy model; must be implemented by subclasses."""
        raise NotImplementedError

    def inference(self):
        """Run one policy step; must be implemented by subclasses."""
        raise NotImplementedError

    def _subscriber_bringup(self):
        """Create all subscriptions and publishers and zero-fill ``self.obs``.

        Subscribes to the centre/right/left colour images, ``/eef_pose`` and
        ``/done``; publishes Int32 on ``/demo`` and JointState on
        ``/teleop/eef_pose``. Every endpoint uses queue depth 10.
        """
        self.obs = {}
        self.action = {}

        # Camera streams; zero images keep obs well-formed before the first frame.
        self._node.create_subscription(Image, '/camera_center/camera/color/image_raw', self.agentview_image_callback, 10)
        self.obs['agentview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8)

        self._node.create_subscription(Image, '/camera_right/camera/color/image_raw', self.rightview_image_callback, 10)
        self.obs['rightview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8)

        self._node.create_subscription(Image, '/camera_left/camera/color/image_raw', self.leftview_image_callback, 10)
        self.obs['leftview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8)

        # End-effector state, stored in the 20-value 6D-rotation layout.
        self._node.create_subscription(JointState, '/eef_pose', self.eef_pose_callback, 10)
        self.obs['eef_pose'] = np.zeros(shape=(20,), dtype=np.float64)

        self.obs['language_instruction'] = ''

        # Any message on /done toggles rollout start/end (msg.data is not read).
        self._node.create_subscription(Bool, '/done', self.done_callback, 10)

        self.demo_pub = self._node.create_publisher(Int32, '/demo', 10)
        self.action_pub = self._node.create_publisher(JointState, '/teleop/eef_pose', 10)

    def send_demo(self, num):
        """Publish ``num`` on ``/demo``."""
        demo_msg = Int32()
        demo_msg.data = num
        self.demo_pub.publish(demo_msg)

    @staticmethod
    def _image_from_msg(msg):
        """Reshape the flat image payload to (480, 640, 3)."""
        # NOTE(review): assumes every camera publishes 640x480, 3 channels.
        return np.reshape(msg.data, (480, 640, 3))

    def agentview_image_callback(self, msg):
        """Store the centre-camera frame."""
        self.obs['agentview_image'] = self._image_from_msg(msg)

    def rightview_image_callback(self, msg):
        """Store the right-camera frame rotated by 180 degrees."""
        # Presumably the right camera is mounted upside down — confirm.
        self.obs['rightview_image'] = np.rot90(self._image_from_msg(msg), 2)

    def leftview_image_callback(self, msg):
        """Store the left-camera frame."""
        self.obs['leftview_image'] = self._image_from_msg(msg)

    def _position_to_pose(self, position):
        """Convert 16 ROS values (xyz, quat xyzw, extra per group) to the 20-value layout."""
        data = np.asarray(position, dtype=np.float64)
        pose = np.zeros(20, dtype=np.float64)
        # (offset in ROS vector, offset in 20-value vector) for each group.
        for src, dst in ((0, 0), (8, 10)):
            pose[dst:dst + 3] = data[src:src + 3]
            pose[dst + 3:dst + 9] = self.quat_to_6d(data[src + 3:src + 7], scalar_first=False)
            pose[dst + 9] = data[src + 7]
        return pose

    def _action_to_position(self, action):
        """Convert a 20-value action (any shape squeezing to (20,)) to 16 Python floats."""
        act = np.asarray(action, dtype=np.float64).squeeze()
        position = np.zeros(16)
        # (offset in 20-value vector, offset in ROS vector) for each group.
        for src, dst in ((0, 0), (10, 8)):
            position[dst:dst + 3] = act[src:src + 3]
            position[dst + 3:dst + 7] = self.sixd_to_quat(act[src + 3:src + 9])
            position[dst + 7] = act[src + 9]
        return position.astype(float).tolist()

    def eef_pose_callback(self, msg):
        """Store the received end-effector state in the 20-value layout."""
        self.obs['eef_pose'] = self._position_to_pose(msg.position)

    def send_action(self, act):
        """Publish ``act`` on ``/teleop/eef_pose`` only while a rollout is active."""
        if not self.start:
            return
        self.publish_action(act)

    def publish_action(self, action):
        """Publish a 20-value action on ``/teleop/eef_pose`` unconditionally."""
        action_msg = JointState()
        action_msg.position = self._action_to_position(action)
        self.action_pub.publish(action_msg)

    def _end_rollout(self):
        """Stop the active rollout, finalize its video and prepare the next one."""
        self.start = False
        if self.window.video_recording:
            self.window.video_stop()
        self.rollout_counter += 1
        self.action_counter = 0
        self.curr_timestep = 0
        self.initialize()

    def done_callback(self, msg):
        """Toggle between starting a rollout (with recording) and ending it."""
        if self.window is None:
            # Still inside __init__ (e.g. bringup_model); dereferencing the
            # window here would raise and kill the rclpy.spin thread.
            print('Environment not initialized yet; ignoring /done')
            return
        if not self.start:
            print('Inference & Video Recording Start')
            self.start = True
            self.window.video_start()
        else:
            self._end_rollout()
            print('Next Inference Ready')

    def timer_callback(self):
        """Control tick: run inference while active and end the rollout at ``max_timestep``."""
        if self.start:
            self.inference()
            self.curr_timestep += 1
            if self.curr_timestep >= self.max_timestep:
                print("Max timestep reached, resetting environment.")
                self._end_rollout()
        # Signals reset() that at least one tick has completed.
        self.thread_done = True

    def video_timer_callback(self):
        """Call ``window.video_write()`` while a rollout is active and recording."""
        if self.start and self.window.video_recording:
            self.window.video_write()

    def quat_to_6d(self, quat, scalar_first=False):
        """Return the first two rotation-matrix columns, flattened row-major (6 values)."""
        q = np.asarray(quat, dtype=np.float64)
        if scalar_first:
            # (w, x, y, z) -> (x, y, z, w); avoids needing SciPy >= 1.14.
            q = np.roll(q, -1, axis=-1)
        mat = Rotation.from_quat(q).as_matrix()
        return mat[:, :2].flatten()

    def sixd_to_quat(self, sixd, scalar_first=False):
        """Decode a 6D rotation (inverse of ``quat_to_6d``) into a quaternion.

        The two columns are orthonormalised with Gram-Schmidt before the third
        is formed by a cross product, so non-orthonormal inputs (e.g. model
        predictions) still decode to a proper rotation whose first axis keeps
        the direction of the first column.
        """
        cols = np.asarray(sixd, dtype=np.float64).reshape(3, 2)
        a1, a2 = cols[:, 0], cols[:, 1]
        b1 = a1 / np.linalg.norm(a1)
        a2_perp = a2 - np.dot(b1, a2) * b1
        b2 = a2_perp / np.linalg.norm(a2_perp)
        b3 = np.cross(b1, b2)
        mat = np.stack([b1, b2, b3], axis=1)
        quat = Rotation.from_matrix(mat).as_quat()
        if scalar_first:
            # (x, y, z, w) -> (w, x, y, z)
            quat = np.roll(quat, 1, axis=-1)
        return quat

    def ros_close(self):
        """Stop the worker threads, finalize any open video and shut down ROS 2."""
        current = threading.current_thread()
        # Stop periodic work first so no inference/video write races teardown.
        for worker in (self.thread, self.video_thread):
            worker.stop()
        for worker in (self.thread, self.video_thread):
            if worker is not current:
                worker.join(timeout=1.0)
        if self.window is not None and self.window.video_recording:
            self.window.video_stop()
        self._node.destroy_node()
        # threading.Thread has no stop(); shutting down the rclpy context is
        # what lets the rclpy.spin loop in timer_thread return.
        rclpy.shutdown()
        if self.timer_thread is not current:
            self.timer_thread.join(timeout=1.0)
|
|
class PeriodicThread(threading.Thread):
    """Thread that calls ``function(*args, **kwargs)`` repeatedly at a fixed period.

    After each call the thread waits for whatever is left of ``interval``
    seconds (not at all if the call overran). ``stop()`` interrupts that wait,
    so the thread exits right after any call that is in progress instead of
    sleeping out the remainder of the period.
    """

    def __init__(self, interval, function, *args, **kwargs):
        super().__init__()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.stop_event = threading.Event()
        # Guards `interval`, which change_period may update from another thread.
        self._lock = threading.Lock()

    def run(self):
        while not self.stop_event.is_set():
            # monotonic() is immune to wall-clock adjustments (NTP, manual changes).
            started = time.monotonic()
            self.function(*self.args, **self.kwargs)
            with self._lock:
                interval = self.interval
            remaining = max(0.0, interval - (time.monotonic() - started))
            # Event.wait rather than time.sleep so stop() takes effect at once.
            self.stop_event.wait(remaining)

    def stop(self):
        """Ask the loop to exit after the current call (if any) returns."""
        self.stop_event.set()

    def change_period(self, new_interval):
        """Set a new period in seconds; used from the next wait onward."""
        with self._lock:
            self.interval = new_interval