| from node import InferenceNode |
| import json |
| import torch |
| from PIL import Image as IMG |
| import numpy as np |
| from std_msgs.msg import String, Bool |
| import argparse |
| import h5py |
| import os, pickle |
| from einops import rearrange |
| import numpy as np |
| from PIL import Image |
| import time |
| """ |
| #!/usr/bin/python3 |
| """ |
|
|
| import argparse |
| import sys |
| import threading |
| import time |
| import yaml |
| from collections import deque |
|
|
| import numpy as np |
| import torch |
| from cv_bridge import CvBridge |
| from geometry_msgs.msg import Twist |
| from nav_msgs.msg import Odometry |
| from std_msgs.msg import Header |
| import cv2 |
|
|
| from scripts.agilex_model import create_model |
|
|
class RDTNode(InferenceNode):
    """ROS inference node serving an RDT (Robotics Diffusion Transformer) policy.

    Builds on :class:`InferenceNode` (camera/joint buffers, pub/sub plumbing),
    runs the policy to predict chunks of joint-space actions, and replays one
    action per control tick.  Supports single-arm and dual-arm configurations.
    """

    def __init__(self, action_chunk, instruction, ckpt_dir, unnorm_key, hz=20,
                 max_timestep=1000, dataset_name=None, single_arm=True,
                 lang_embed_name=''):
        """
        Args:
            action_chunk: number of actions consumed from one policy call
                before re-running inference.
            instruction: natural-language task instruction shown in the UI.
            ckpt_dir: path to the pretrained RDT checkpoint directory.
            unnorm_key: dataset key used for action un-normalization.
            hz: control loop frequency passed to the base node.
            max_timestep: episode length limit passed to the base node.
            dataset_name: optional dataset used to seed the start pose.
            single_arm: True for one-arm setups, False for dual-arm.
            lang_embed_name: basename of the precomputed language-embedding
                file under ``outs/`` (``outs/<name>.pt``).
        """
        self.ckpt_dir = ckpt_dir
        # Precomputed language embedding produced offline; loaded in bringup_model().
        self.lang_embed_name = f'outs/{lang_embed_name}.pt'
        self.run_name = f'rdt_{ckpt_dir.split("/")[-1]}'
        self.single_arm = single_arm
        super().__init__(hz=hz, max_timestep=max_timestep,
                         dataset_name=dataset_name, single_arm=single_arm)
        self.obs['language_instruction'] = f'{instruction}'
        self.action_chunk = action_chunk
        # Index into the currently cached action chunk; 0 triggers re-inference.
        self.action_counter = 0
        self.unnorm_key = unnorm_key
        # BUG FIX: the subscription handle used to be stored as `self.prompt_sub`,
        # shadowing the `prompt_sub` callback method of the same name.  Keep the
        # handle under a distinct private attribute instead.
        self._prompt_subscription = self._node.create_subscription(
            String, '/vla/prompt', self.prompt_sub, 1)
        self.attn = None

    def prompt_sub(self, msg):
        """Callback for ``/vla/prompt``: run prompt inference on the current frame.

        Args:
            msg: ``std_msgs/String`` carrying the prompt text in ``msg.data``.
        """
        # BUG FIX: guard with getattr — `self.policy` does not exist until
        # bringup_model() has run, and a prompt may arrive before that.
        policy = getattr(self, 'policy', None)
        if policy is not None:
            pil_image = Image.fromarray(self.obs['image'])
            print(policy.inference_prompt(pil_image, msg.data))

    def bringup_model(self):
        """Instantiate the RDT policy and load the precomputed language embedding."""
        with open('configs/base.yaml', "r") as fp:
            config = yaml.safe_load(fp)
        self.policy = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.ckpt_dir,
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20,
            single_arm=self.single_arm,
        )
        # Embedding file is produced by an offline text-encoder pass; only the
        # "embeddings" tensor is needed at inference time.
        self.lang_embeddings = torch.load(self.lang_embed_name)["embeddings"]

    def inference_fn(self):
        """Run one forward pass of the policy on the latest observations.

        Returns:
            np.ndarray of shape (chunk, action_dim): the predicted action chunk.
        """
        # Image slots: previous and current frames per camera; `None` marks an
        # absent camera.  Slot order (ext/left/wrist x {t-1, t}) is presumably
        # fixed by the policy's preprocessing — TODO confirm against create_model.
        if self.single_arm:
            image_arrs = [
                self.frame_buffer[-2], None, None,
                self.frame_buffer[-1], None, None,
            ]
        else:
            image_arrs = [
                self.frame_buffer[-2], self.left_frame_buffer[-2], None,
                self.frame_buffer[-1], self.left_frame_buffer[-1], None,
            ]
        images = [Image.fromarray(arr) if arr is not None else None
                  for arr in image_arrs]

        # Single-arm proprio drops the first 7 joints (the unused arm);
        # dual-arm passes the full 14-dim joint vector.
        if self.single_arm:
            proprio = torch.tensor(self.joint_pos_buffer[-1][7:]).unsqueeze(0)
        else:
            proprio = torch.tensor(self.joint_pos_buffer[-1]).unsqueeze(0)

        actions = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=self.lang_embeddings,
        ).squeeze(0).cpu().numpy()

        return actions

    def inference(self):
        """Control-tick entry point: refresh the action chunk if exhausted, then
        execute the next action from it."""
        if self.action_counter == 0:
            with torch.inference_mode():
                start_time = time.time()
                self.actions = self.inference_fn()
                end_time = time.time()
                print(f'{end_time - start_time:.6f} sec')

        action = self.actions[self.action_counter]

        if self.single_arm:
            self.joint_action(None, action)
        else:
            # First 7 dims drive the left arm, the rest the right arm.
            self.joint_action(action[:7], action[7:])

        self.action_counter += 1
        if self.action_counter == self.action_chunk:
            self.action_counter = 0

    def done_callback(self, msg):
        """Toggle an inference episode: start (move to start pose, begin video)
        or stop (reset robot and counters, stop video)."""
        if not self.start:
            if self.data_list is not None:
                # Seed the start pose from a recorded demonstration.
                # BUG FIX: open the HDF5 file with a context manager so the
                # handle is always released (it was previously never closed).
                with h5py.File(self.data_list[self.num], 'r') as root:
                    skip = 5  # skip the first frames, which may be stationary
                    if self.single_arm:
                        self.target_joint_right = root['observation']['joint_pos'][skip, :7]
                        self.joint_action(None, self.target_joint_right)
                    else:
                        self.target_joint_left = root['observation']['joint_pos'][skip, :7]
                        self.target_joint_right = root['observation']['joint_pos'][skip, 7:]
                        self.joint_action(self.target_joint_left, self.target_joint_right)
                time.sleep(2)  # give the robot time to reach the start pose
            else:
                self.target_ee_left = self.obs['left_pose']
                self.target_ee_right = self.obs['right_pose']
            print('Inference & Video Recording Start')
            self.start = True
            sync_msg = Bool()
            sync_msg.data = True
            self.sync_pub.publish(sync_msg)
            self.window.video_start()
        else:
            self.start = False
            sync_msg = Bool()
            sync_msg.data = False
            self.sync_pub.publish(sync_msg)
            self.init_robot()
            self.action_counter = 0
            if self.window.video_recording:
                self.window.video_stop()
            self.initialize()
            print('Next Inference Ready')
|
|
if __name__ == "__main__":
    # ---- configuration -------------------------------------------------
    ckpt_dir = '/home/univ/workspace/rdt-ckpts/checkpoint-38000'

    action_chunk = 64   # actions consumed per policy call
    hz = 20             # control loop frequency (Hz)

    instruction = 'handover the stuffed doll'
    unnorm_key = 'handover_kirby'
    single_arm = False
    dataset_name = [
        'vla_upright_mug',
        'vla_sweep_screws',
        'vla_pick_ball_place_bin',
        'twinvla_handover_kirby',
        'twinvla_put_bottle',
        'twinvla_detach_ball',
        'twinvla_tear_paper_towel',
    ]
    lang_embed_name = [
        'upright_mug',
        'sweep_screws',
        'pick_ball_place_bin',
        'handover_kirby',
    ]
    # Task index into the two lists above.  NOTE(review): lang_embed_name has
    # only 4 entries while dataset_name has 7, so num must stay <= 3 — confirm.
    num = 3

    node = RDTNode(
        action_chunk=action_chunk,
        instruction=instruction,
        ckpt_dir=ckpt_dir,
        unnorm_key=unnorm_key,
        hz=hz,
        max_timestep=1000,
        dataset_name=dataset_name[num],
        lang_embed_name=lang_embed_name[num],
        single_arm=single_arm,
    )

    # Display loop: grab the latest frame(s), convert BGR -> RGB, and show
    # them in the node's window until interrupted.
    while True:
        try:
            if node.single_arm:
                img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
            else:
                left_img = cv2.cvtColor(node.obs['leftview_image'], cv2.COLOR_BGR2RGB)
                right_img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
                img = cv2.hconcat([left_img, right_img])
            if node.start:
                node.window.show(img, overlay_img=None,
                                 text=node.obs['language_instruction'])
            else:
                node.boundary_query()
                node.window.show(img, overlay_img=node.overlay_img,
                                 text=node.obs['language_instruction'],
                                 grid=node.grid)
        except KeyboardInterrupt:
            node.ros_close()
            # BUG FIX: previously the loop kept running after ros_close();
            # exit cleanly on Ctrl-C instead.
            break
        except Exception as e:
            # Best-effort: keep the display loop alive across transient errors
            # (e.g. a frame not yet available at startup).
            print(f"An error occurred: {e}")
|
|
| |