Commit ·
1501ed7
1
Parent(s): 342fd2c
update
Browse files- equidiff/equi_diffpo/common/robomimic_config_util.py +47 -0
- equidiff/equi_diffpo/common/robomimic_util.py +240 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_image_dataset.py +377 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_image_sym_dataset.py +90 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_dataset.py +169 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_sym_dataset.py +73 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_point_cloud_dataset.py +407 -0
- equidiff/equi_diffpo/dataset/robomimic_replay_voxel_sym_dataset.py +452 -0
- equidiff/equi_diffpo/env/robomimic/robomimic_image_wrapper.py +170 -0
- equidiff/equi_diffpo/env/robomimic/robomimic_lowdim_wrapper.py +133 -0
- equidiff/equi_diffpo/env_runner/robomimic_image_runner.py +378 -0
- equidiff/equi_diffpo/env_runner/robomimic_lowdim_runner.py +405 -0
- equidiff/equi_diffpo/policy/robomimic_image_policy.py +142 -0
- equidiff/equi_diffpo/scripts/robomimic_dataset_action_comparison.py +51 -0
- equidiff/equi_diffpo/scripts/robomimic_dataset_conversion.py +103 -0
- equidiff/equi_diffpo/scripts/robomimic_dataset_obs_conversion.py +69 -0
equidiff/equi_diffpo/common/robomimic_config_util.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import OmegaConf
|
| 2 |
+
from robomimic.config import config_factory
|
| 3 |
+
import robomimic.scripts.generate_paper_configs as gpc
|
| 4 |
+
from robomimic.scripts.generate_paper_configs import (
|
| 5 |
+
modify_config_for_default_image_exp,
|
| 6 |
+
modify_config_for_default_low_dim_exp,
|
| 7 |
+
modify_config_for_dataset,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
def get_robomimic_config(
        algo_name='bc_rnn',
        hdf5_type='low_dim',
        task_name='square',
        dataset_type='ph'
    ):
    """Assemble a robomimic training config for the given task/dataset.

    Starts from the algorithm's factory defaults, applies the paper's
    default observation-modality settings (low-dim vs. image), then the
    dataset-specific and algorithm-specific modifiers from
    ``robomimic.scripts.generate_paper_configs``. The dataset directory
    passed down is a placeholder — no data is read here.
    """
    placeholder_dataset_dir = '/tmp/null'

    # choose low-dim vs. image observation defaults
    low_dim_variants = ("low_dim", "low_dim_sparse", "low_dim_dense")
    if hdf5_type in low_dim_variants:
        obs_modifier = modify_config_for_default_low_dim_exp
    else:
        obs_modifier = modify_config_for_default_image_exp

    # bc_rnn is configured as a variant of the plain "bc" algo config
    base_algo = "bc" if algo_name == "bc_rnn" else algo_name
    config = obs_modifier(config_factory(algo_name=base_algo))

    # layer in dataset-specific settings
    config = modify_config_for_dataset(
        config=config,
        task_name=task_name,
        dataset_type=dataset_type,
        hdf5_type=hdf5_type,
        base_dataset_dir=placeholder_dataset_dir,
        filter_key=None,
    )

    # layer in algo hyper-parameters tuned for this dataset
    dataset_algo_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset')
    config = dataset_algo_modifier(
        config=config,
        task_name=task_name,
        dataset_type=dataset_type,
        hdf5_type=hdf5_type,
    )
    return config
|
| 46 |
+
|
| 47 |
+
|
equidiff/equi_diffpo/common/robomimic_util.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import copy
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import robomimic.utils.obs_utils as ObsUtils
|
| 6 |
+
import robomimic.utils.file_utils as FileUtils
|
| 7 |
+
import robomimic.utils.env_utils as EnvUtils
|
| 8 |
+
import robomimic.utils.tensor_utils as TensorUtils
|
| 9 |
+
from scipy.spatial.transform import Rotation
|
| 10 |
+
|
| 11 |
+
from robomimic.config import config_factory
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RobomimicAbsoluteActionConverter:
    """Convert robomimic delta end-effector actions to absolute goal poses.

    Builds two simulator environments from the same dataset metadata: one
    with the original delta controller and one with
    ``control_delta=False`` (absolute control), plus a read-only handle to
    the source hdf5 dataset. The delta env's controller is used to recover
    the goal pose implied by each delta action; the absolute env is used to
    verify the converted actions reproduce the recorded trajectory.

    Fix over the original implementation: the hdf5 handle is now closable
    (``close()`` / ``with`` support) instead of leaking until interpreter
    exit.
    """

    def __init__(self, dataset_path, algo_name='bc'):
        # default BC config — only used to initialize observation-modality
        # metadata, not for training
        config = config_factory(algo_name=algo_name)

        # read config to set up metadata for observation modalities
        # (e.g. detecting rgb observations); must run before creating envs
        ObsUtils.initialize_obs_utils_with_config(config)

        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
        abs_env_meta = copy.deepcopy(env_meta)
        # same env, but with an absolute-pose controller
        abs_env_meta['env_kwargs']['controller_configs']['control_delta'] = False

        env = EnvUtils.create_env_from_metadata(
            env_meta=env_meta,
            render=False,
            render_offscreen=False,
            use_image_obs=False,
        )
        # single-arm or dual-arm setups only
        assert len(env.env.robots) in (1, 2)

        abs_env = EnvUtils.create_env_from_metadata(
            env_meta=abs_env_meta,
            render=False,
            render_offscreen=False,
            use_image_obs=False,
        )
        assert not abs_env.env.robots[0].controller.use_delta

        self.env = env
        self.abs_env = abs_env
        self.file = h5py.File(dataset_path, 'r')

    def close(self):
        """Close the underlying hdf5 file (idempotent)."""
        if getattr(self, 'file', None) is not None:
            self.file.close()
            self.file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __len__(self):
        # number of demos in the dataset
        return len(self.file['data'])

    def convert_actions(self,
            states: np.ndarray,
            actions: np.ndarray) -> np.ndarray:
        """
        Given state and delta action sequence
        generate equivalent goal position and orientation for each step
        keep the original gripper action intact.
        """
        # in case of multi robot
        # reshape (N,14) to (N,2,7)
        # or (N,7) to (N,1,7)
        stacked_actions = actions.reshape(*actions.shape[:-1], -1, 7)

        env = self.env
        # buffers for the absolute goal components
        action_goal_pos = np.zeros(
            stacked_actions.shape[:-1] + (3,),
            dtype=stacked_actions.dtype)
        action_goal_ori = np.zeros(
            stacked_actions.shape[:-1] + (3,),
            dtype=stacked_actions.dtype)
        # gripper command is passed through unchanged
        action_gripper = stacked_actions[..., [-1]]
        for i in range(len(states)):
            _ = env.reset_to({'states': states[i]})

            # taken from robot_env.py L#454
            for idx, robot in enumerate(env.env.robots):
                # run controller goal generator
                robot.control(stacked_actions[i, idx], policy_step=True)

                # read the absolute goal pose computed by the controller
                controller = robot.controller
                action_goal_pos[i, idx] = controller.goal_pos
                action_goal_ori[i, idx] = Rotation.from_matrix(
                    controller.goal_ori).as_rotvec()

        stacked_abs_actions = np.concatenate([
            action_goal_pos,
            action_goal_ori,
            action_gripper
        ], axis=-1)
        abs_actions = stacked_abs_actions.reshape(actions.shape)
        return abs_actions

    def convert_idx(self, idx):
        """Convert demo ``idx``'s delta actions to absolute actions."""
        file = self.file
        demo = file[f'data/demo_{idx}']
        # input
        states = demo['states'][:]
        actions = demo['actions'][:]

        # generate abs actions
        abs_actions = self.convert_actions(states, actions)
        return abs_actions

    def convert_and_eval_idx(self, idx):
        """Convert demo ``idx`` and report max rollout error for both the
        original delta actions and the converted absolute actions.

        Returns ``(abs_actions, info)`` where ``info`` holds the two
        error dicts from :meth:`evaluate_rollout_error`.
        """
        env = self.env
        abs_env = self.abs_env
        file = self.file
        # first step have high error for some reason, not representative
        eval_skip_steps = 1

        demo = file[f'data/demo_{idx}']
        # input
        states = demo['states'][:]
        actions = demo['actions'][:]

        # generate abs actions
        abs_actions = self.convert_actions(states, actions)

        # verify against the recorded end-effector trajectory
        robot0_eef_pos = demo['obs']['robot0_eef_pos'][:]
        robot0_eef_quat = demo['obs']['robot0_eef_quat'][:]

        delta_error_info = self.evaluate_rollout_error(
            env, states, actions, robot0_eef_pos, robot0_eef_quat,
            metric_skip_steps=eval_skip_steps)
        abs_error_info = self.evaluate_rollout_error(
            abs_env, states, abs_actions, robot0_eef_pos, robot0_eef_quat,
            metric_skip_steps=eval_skip_steps)

        info = {
            'delta_max_error': delta_error_info,
            'abs_max_error': abs_error_info
        }
        return abs_actions, info

    @staticmethod
    def evaluate_rollout_error(env,
            states, actions,
            robot0_eef_pos,
            robot0_eef_quat,
            metric_skip_steps=1):
        """Replay ``actions`` from each recorded state and measure the max
        one-step deviation from the recorded trajectory.

        Returns a dict with max state diff, max eef position distance and
        max eef rotation distance (radians), skipping the first
        ``metric_skip_steps`` steps.
        """
        # first step have high error for some reason, not representative

        # roll out one step from every recorded state
        rollout_next_states = list()
        rollout_next_eef_pos = list()
        rollout_next_eef_quat = list()
        obs = env.reset_to({'states': states[0]})
        for i in range(len(states)):
            obs = env.reset_to({'states': states[i]})
            obs, reward, done, info = env.step(actions[i])
            obs = env.get_observation()
            rollout_next_states.append(env.get_state()['states'])
            rollout_next_eef_pos.append(obs['robot0_eef_pos'])
            rollout_next_eef_quat.append(obs['robot0_eef_quat'])
        rollout_next_states = np.array(rollout_next_states)
        rollout_next_eef_pos = np.array(rollout_next_eef_pos)
        rollout_next_eef_quat = np.array(rollout_next_eef_quat)

        # compare the simulated next step with the recorded next step
        next_state_diff = states[1:] - rollout_next_states[:-1]
        max_next_state_diff = np.max(np.abs(next_state_diff[metric_skip_steps:]))

        next_eef_pos_diff = robot0_eef_pos[1:] - rollout_next_eef_pos[:-1]
        next_eef_pos_dist = np.linalg.norm(next_eef_pos_diff, axis=-1)
        max_next_eef_pos_dist = next_eef_pos_dist[metric_skip_steps:].max()

        # rotation distance via relative rotation magnitude
        next_eef_rot_diff = Rotation.from_quat(robot0_eef_quat[1:]) \
            * Rotation.from_quat(rollout_next_eef_quat[:-1]).inv()
        next_eef_rot_dist = next_eef_rot_diff.magnitude()
        max_next_eef_rot_dist = next_eef_rot_dist[metric_skip_steps:].max()

        info = {
            'state': max_next_state_diff,
            'pos': max_next_eef_pos_dist,
            'rot': max_next_eef_rot_dist
        }
        return info
|
| 179 |
+
|
| 180 |
+
class RobomimicObsConverter:
    """Re-render full observations (including camera images) for every
    state of a robomimic demo by resetting a data-processing env to each
    recorded simulator state.

    Fixes over the original implementation: the hdf5 handle is closable
    (``close()`` / ``with`` support) instead of leaking, the dropped
    observation keys are a named constant, and dead commented-out code is
    removed.
    """

    # cameras rendered by the data-processing env
    CAMERA_NAMES = ('birdview', 'agentview', 'sideview', 'robot0_eye_in_hand')
    # rendered streams that are discarded from the converted output
    DROPPED_OBS_KEYS = (
        'birdview_image',
        'birdview_depth',
        'agentview_depth',
        'sideview_image',
        'sideview_depth',
        'robot0_eye_in_hand_depth',
    )

    def __init__(self, dataset_path, algo_name='bc'):
        # NOTE: algo_name is kept for interface compatibility; the config
        # it used to select is not needed for observation re-rendering.
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)

        env = EnvUtils.create_env_for_data_processing(
            env_meta=env_meta,
            camera_names=list(self.CAMERA_NAMES),
            camera_height=84,
            camera_width=84,
            reward_shaping=False,
        )

        self.env = env
        self.file = h5py.File(dataset_path, 'r')

    def close(self):
        """Close the underlying hdf5 file (idempotent)."""
        if getattr(self, 'file', None) is not None:
            self.file.close()
            self.file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __len__(self):
        # number of demos in the dataset
        return len(self.file['data'])

    def convert_obs(self, initial_state, states):
        """Reset the env to ``initial_state`` then to every subsequent
        state, collecting the rendered observation at each step.

        Returns a dict of lists (one entry per key, one element per state).
        """
        obss = []
        self.env.reset()
        obs = self.env.reset_to(initial_state)
        obss.append(obs)
        for i in range(1, len(states)):
            obs = self.env.reset_to({'states': states[i]})
            obss.append(obs)
        return TensorUtils.list_of_flat_dict_to_dict_of_list(obss)

    def convert_idx(self, idx):
        """Return re-rendered observations for demo ``idx`` with the
        unused camera/depth streams removed."""
        demo = self.file[f'data/demo_{idx}']

        states = demo['states'][:]
        # the first reset also needs the model xml to rebuild the scene
        initial_state = dict(states=states[0])
        initial_state["model"] = demo.attrs["model_file"]

        obss = self.convert_obs(initial_state, states)
        for key in self.DROPPED_OBS_KEYS:
            # same semantics as the original explicit dels:
            # raises KeyError if a stream is unexpectedly missing
            del obss[key]
        return obss
|
equidiff/equi_diffpo/dataset/robomimic_replay_image_dataset.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import h5py
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import zarr
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
import copy
|
| 10 |
+
import json
|
| 11 |
+
import hashlib
|
| 12 |
+
from filelock import FileLock
|
| 13 |
+
from threadpoolctl import threadpool_limits
|
| 14 |
+
import concurrent.futures
|
| 15 |
+
import multiprocessing
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 18 |
+
from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
|
| 19 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
| 20 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 21 |
+
from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
|
| 22 |
+
from equi_diffpo.common.replay_buffer import ReplayBuffer
|
| 23 |
+
from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
|
| 24 |
+
from equi_diffpo.common.normalize_util import (
|
| 25 |
+
robomimic_abs_action_only_normalizer_from_stat,
|
| 26 |
+
robomimic_abs_action_only_dual_arm_normalizer_from_stat,
|
| 27 |
+
get_range_normalizer_from_stat,
|
| 28 |
+
get_image_range_normalizer,
|
| 29 |
+
get_identity_normalizer_from_stat,
|
| 30 |
+
array_to_stats
|
| 31 |
+
)
|
| 32 |
+
register_codecs()
|
| 33 |
+
|
| 34 |
+
class RobomimicReplayImageDataset(BaseImageDataset):
    """Image-based robomimic demonstration dataset backed by an in-memory
    zarr ReplayBuffer.

    Loads the first ``n_demo`` episodes from a robomimic hdf5 file —
    optionally converting delta actions to absolute actions with a changed
    rotation representation — and serves fixed-length
    observation/action sequences for training. With ``use_cache=True`` the
    converted replay buffer is persisted next to the dataset as a zarr zip
    archive and reloaded on subsequent runs.

    Fix over the original implementation: failed cache creation previously
    called ``shutil.rmtree`` on the cache *file* path, which raises and
    masks the original error; cleanup now handles both file and directory
    caches.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100
        ):
        self.n_demo = n_demo
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            # NOTE: the double dot in the name is kept for backward
            # compatibility with caches already on disk
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet: build it, then persist
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _convert_robomimic_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception as e:
                        # remove whatever partial cache was written so the
                        # next run starts fresh; the cache may be a zip file
                        # (current code) or a directory (legacy
                        # DirectoryStore caches)
                        if os.path.isdir(cache_zarr_path):
                            shutil.rmtree(cache_zarr_path)
                        elif os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise e
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_robomimic_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # split observation keys by modality
        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        key_first_k = dict()
        if n_obs_steps is not None:
            # only the first n_obs_steps observations of each sampled
            # sequence are used, so don't decode the rest
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples only the
        held-out validation episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a LinearNormalizer for actions and observations from the
        replay buffer's statistics."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            # NOTE: the legacy normalizer only applies to absolute actions
            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # low-dim obs, normalized by key suffix convention
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])

            if key.endswith('pos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            elif key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # images are scaled to [0,1] at load time; map to the model range
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # avoid thread oversubscription when used with DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # move channel last to channel first
            # T,H,W,C -> T,C,H,W
            # convert uint8 image to float32 in [0,1]
            obs_dict[key] = np.moveaxis(data[key][T_slice], -1, 1
                ).astype(np.float32) / 255.
            # free the uint8 copy as soon as possible
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _convert_actions(raw_actions, abs_action, rotation_transformer):
|
| 228 |
+
actions = raw_actions
|
| 229 |
+
if abs_action:
|
| 230 |
+
is_dual_arm = False
|
| 231 |
+
if raw_actions.shape[-1] == 14:
|
| 232 |
+
# dual arm
|
| 233 |
+
raw_actions = raw_actions.reshape(-1,2,7)
|
| 234 |
+
is_dual_arm = True
|
| 235 |
+
|
| 236 |
+
pos = raw_actions[...,:3]
|
| 237 |
+
rot = raw_actions[...,3:6]
|
| 238 |
+
gripper = raw_actions[...,6:]
|
| 239 |
+
rot = rotation_transformer.forward(rot)
|
| 240 |
+
raw_actions = np.concatenate([
|
| 241 |
+
pos, rot, gripper
|
| 242 |
+
], axis=-1).astype(np.float32)
|
| 243 |
+
|
| 244 |
+
if is_dual_arm:
|
| 245 |
+
raw_actions = raw_actions.reshape(-1,20)
|
| 246 |
+
actions = raw_actions
|
| 247 |
+
return actions
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _convert_robomimic_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert the first `n_demo` episodes of a robomimic hdf5 dataset into
    a zarr-backed ReplayBuffer stored in `store`.

    Low-dim observations and actions are copied (and actions optionally
    converted to the target rotation representation) as single
    uncompressed chunks; rgb observations are re-encoded one frame per
    chunk with Jpeg2k compression using a thread pool.

    NOTE(review): assumes demo keys are named `demo_0..demo_{n_demo-1}`
    and that rgb shape_meta entries are channel-first (c,h,w) — images are
    stored channel-last (h,w,c) in zarr.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        # bound the work queue so frames aren't all read into memory at once
        max_inflight_tasks = n_workers * 5

    # parse shape_meta
    rgb_keys = list()
    lowdim_keys = list()
    # construct compressors and chunks
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        shape = attr['shape']
        type = attr.get('type', 'low_dim')
        if type == 'rgb':
            rgb_keys.append(key)
        elif type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    with h5py.File(dataset_path) as file:
        # count total steps and record cumulative episode boundaries
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (actions are stored under the 'action' key,
        # read from the hdf5 'actions' dataset)
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                # optional delta -> rotation-representation conversion
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            # one chunk for the whole array, no compression
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            """Copy one frame into zarr; returns False on any failure so
            the caller can raise with the progress context intact."""
            try:
                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
                # make sure we can successfully decode
                _ = zarr_arr[zarr_idx]
                return True
            except Exception as e:
                return False

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    data_key = 'obs/' + key
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w = shape
                    this_compressor = Jpeg2k(level=50)
                    # one frame per chunk so worker threads never touch
                    # the same chunk concurrently
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,h,w,c),
                        chunks=(1,h,w,c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks: wait for
                                # at least one to finish before submitting
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            # map episode-local frame index to global index
                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(img_copy,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                # drain the remaining in-flight copies
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
|
| 368 |
+
|
| 369 |
+
def normalizer_from_stat(stat):
    """Build a symmetric, scale-only normalizer from data statistics.

    The scale is the reciprocal of the largest absolute value observed in
    either ``stat['max']`` or ``stat['min']`` (applied uniformly to every
    dimension) and the offset is zero, so the sign of the data is
    preserved and the output lies in [-1, 1].
    """
    template = stat['max']
    largest_magnitude = np.maximum(template.max(), np.abs(stat['min']).max())
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.full_like(template, fill_value=1 / largest_magnitude),
        offset=np.zeros_like(template),
        input_stats_dict=stat,
    )
|
equidiff/equi_diffpo/dataset/robomimic_replay_image_sym_dataset.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from equi_diffpo.dataset.base_dataset import LinearNormalizer
|
| 2 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer
|
| 3 |
+
from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat
|
| 4 |
+
from equi_diffpo.common.normalize_util import (
|
| 5 |
+
robomimic_abs_action_only_symmetric_normalizer_from_stat,
|
| 6 |
+
get_range_normalizer_from_stat,
|
| 7 |
+
get_range_symmetric_normalizer_from_stat,
|
| 8 |
+
get_image_range_normalizer,
|
| 9 |
+
get_identity_normalizer_from_stat,
|
| 10 |
+
array_to_stats
|
| 11 |
+
)
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
    """Image dataset variant whose normalizer keeps spatial quantities
    symmetric about zero (equivariance-friendly normalization)."""

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100
        ):
        # all loading/sampling behavior comes from the parent class;
        # only the normalizer differs
        super().__init__(
            shape_meta,
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            n_obs_steps,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            use_cache,
            seed,
            val_ratio,
            n_demo
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a LinearNormalizer covering actions, low-dim observations,
        images, and the auxiliary ``pos_vecs``/``crops`` fields."""
        normalizer = LinearNormalizer()

        # --- action ---
        action_stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if action_stat['mean'].shape[-1] > 10:
                # dual arm is not supported by the symmetric normalizer
                raise NotImplementedError
            action_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
            if self.use_legacy_normalizer:
                action_normalizer = normalizer_from_stat(action_stat)
        else:
            # relative actions are already normalized
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        normalizer['action'] = action_normalizer

        # --- low-dimensional observations ---
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])
            # NOTE: 'qpos' must be tested before the generic 'pos' suffix
            if key.endswith('qpos'):
                obs_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('pos'):
                obs_normalizer = get_range_symmetric_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion components already lie in [-1, 1]
                obs_normalizer = get_identity_normalizer_from_stat(stat)
            elif 'bbox' in key:
                obs_normalizer = get_identity_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = obs_normalizer

        # --- images ---
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()

        # auxiliary fields produced by the symmetric pipeline;
        # presumably already in [-1, 1] — hence identity/image-range scaling
        unit = np.ones([10, 2], np.float32)
        normalizer['pos_vecs'] = get_identity_normalizer_from_stat(
            {'min': -unit, 'max': unit})
        normalizer['crops'] = get_image_range_normalizer()

        return normalizer
|
| 90 |
+
|
equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_dataset.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import h5py
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import copy
|
| 7 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 8 |
+
from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
|
| 9 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
| 10 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 11 |
+
from equi_diffpo.common.replay_buffer import ReplayBuffer
|
| 12 |
+
from equi_diffpo.common.sampler import (
|
| 13 |
+
SequenceSampler, get_val_mask, downsample_mask)
|
| 14 |
+
from equi_diffpo.common.normalize_util import (
|
| 15 |
+
robomimic_abs_action_only_normalizer_from_stat,
|
| 16 |
+
robomimic_abs_action_only_dual_arm_normalizer_from_stat,
|
| 17 |
+
get_identity_normalizer_from_stat,
|
| 18 |
+
array_to_stats
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Low-dimensional robomimic demo dataset backed by an in-memory ReplayBuffer."""

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: List[str]=[
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'],
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        """
        Args:
            dataset_path: robomimic hdf5 file containing `data/demo_{i}` groups.
            horizon: sequence length returned per sample.
            pad_before, pad_after: episode-boundary padding for the sampler.
            obs_keys: observation keys concatenated (in order) into `obs`.
            abs_action: if True, convert action rotations to `rotation_rep`.
            use_legacy_normalizer: use max-abs scaling for actions.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: optional cap on number of training episodes.
            n_demo: number of demos to load from the file.
        """
        # copy so the shared default list is never aliased or mutated
        obs_keys = list(obs_keys)
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # open read-only: this loader never writes back to the demo file
        # (older h5py versions default to append mode otherwise)
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.abs_action = abs_action
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset sampling the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit action/obs normalizers from replay-buffer statistics."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # relative actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # aggregate obs stats: one max-abs normalizer over the whole obs vector
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
|
| 131 |
+
|
| 132 |
+
def normalizer_from_stat(stat):
    """Zero-offset max-abs normalizer: scales every dimension by the single
    largest absolute value found in the stats."""
    peak = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    scale = np.full_like(stat['max'], fill_value=1/peak)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale, offset=offset, input_stats_dict=stat)
|
| 141 |
+
|
| 142 |
+
def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
|
| 143 |
+
obs = np.concatenate([
|
| 144 |
+
raw_obs[key] for key in obs_keys
|
| 145 |
+
], axis=-1).astype(np.float32)
|
| 146 |
+
|
| 147 |
+
if abs_action:
|
| 148 |
+
is_dual_arm = False
|
| 149 |
+
if raw_actions.shape[-1] == 14:
|
| 150 |
+
# dual arm
|
| 151 |
+
raw_actions = raw_actions.reshape(-1,2,7)
|
| 152 |
+
is_dual_arm = True
|
| 153 |
+
|
| 154 |
+
pos = raw_actions[...,:3]
|
| 155 |
+
rot = raw_actions[...,3:6]
|
| 156 |
+
gripper = raw_actions[...,6:]
|
| 157 |
+
rot = rotation_transformer.forward(rot)
|
| 158 |
+
raw_actions = np.concatenate([
|
| 159 |
+
pos, rot, gripper
|
| 160 |
+
], axis=-1).astype(np.float32)
|
| 161 |
+
|
| 162 |
+
if is_dual_arm:
|
| 163 |
+
raw_actions = raw_actions.reshape(-1,20)
|
| 164 |
+
|
| 165 |
+
data = {
|
| 166 |
+
'obs': obs,
|
| 167 |
+
'action': raw_actions
|
| 168 |
+
}
|
| 169 |
+
return data
|
equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_sym_dataset.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 5 |
+
from equi_diffpo.dataset.base_dataset import LinearNormalizer
|
| 6 |
+
from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
|
| 7 |
+
from equi_diffpo.common.normalize_util import robomimic_abs_action_only_symmetric_normalizer_from_stat
|
| 8 |
+
from equi_diffpo.common.normalize_util import (
|
| 9 |
+
robomimic_abs_action_only_symmetric_normalizer_from_stat,
|
| 10 |
+
get_identity_normalizer_from_stat,
|
| 11 |
+
array_to_stats
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
    """Low-dim dataset variant using the symmetric (equivariance-friendly)
    absolute-action normalizer. Dual-arm setups are not supported."""

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: List[str]=[
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'],
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        # loading/sampling behavior is inherited unchanged
        super().__init__(
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            obs_keys,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            seed,
            val_ratio,
            max_train_episodes,
            n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Same contract as the parent, but absolute actions receive the
        *symmetric* normalizer."""
        normalizer = LinearNormalizer()

        action_stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if action_stat['mean'].shape[-1] > 10:
                # dual arm
                raise NotImplementedError
            action_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
            if self.use_legacy_normalizer:
                action_normalizer = normalizer_from_stat(action_stat)
        else:
            # relative actions are already normalized
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        normalizer['action'] = action_normalizer

        # one aggregate max-abs normalizer over the concatenated obs vector
        normalizer['obs'] = normalizer_from_stat(
            array_to_stats(self.replay_buffer['obs']))
        return normalizer
|
equidiff/equi_diffpo/dataset/robomimic_replay_point_cloud_dataset.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import h5py
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import zarr
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
import copy
|
| 10 |
+
import json
|
| 11 |
+
import hashlib
|
| 12 |
+
from filelock import FileLock
|
| 13 |
+
from threadpoolctl import threadpool_limits
|
| 14 |
+
import concurrent.futures
|
| 15 |
+
import multiprocessing
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 18 |
+
from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
|
| 19 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
| 20 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 21 |
+
from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
|
| 22 |
+
from equi_diffpo.common.replay_buffer import ReplayBuffer
|
| 23 |
+
from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
|
| 24 |
+
from equi_diffpo.common.normalize_util import (
|
| 25 |
+
robomimic_abs_action_only_normalizer_from_stat,
|
| 26 |
+
get_range_normalizer_from_stat,
|
| 27 |
+
get_voxel_identity_normalizer,
|
| 28 |
+
get_image_range_normalizer,
|
| 29 |
+
get_identity_normalizer_from_stat,
|
| 30 |
+
array_to_stats
|
| 31 |
+
)
|
| 32 |
+
register_codecs()
|
| 33 |
+
|
| 34 |
+
class RobomimicReplayPointCloudDataset(BaseImageDataset):
    """Robomimic dataset serving point-cloud (+ RGB / low-dim) observation
    sequences from a zarr-backed ReplayBuffer, with optional on-disk caching."""

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100,
        ):
        """
        Args:
            shape_meta: dict with 'obs' (per-key shape + type) and 'action' shape.
            dataset_path: robomimic hdf5 demo file.
            n_obs_steps: if set, only the first k obs steps are materialized.
            abs_action: whether actions are absolute (rotations get converted).
            use_cache: cache the converted ReplayBuffer next to the hdf5 file.
            n_demo: number of demos to load.
        """
        self.n_demo = n_demo
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet: convert, then save as a zip
                    try:
                        print('Cache does not exist. Creating!')
                        # store = zarr.DirectoryStore(cache_zarr_path)
                        replay_buffer = _convert_point_cloud_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception:
                        # the cache is a single zip *file* (and may not even
                        # exist if conversion failed early); shutil.rmtree
                        # would raise here and mask the original error
                        if os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_point_cloud_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # group observation keys by modality
        rgb_keys = list()
        pc_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            obs_type = attr.get('type', 'low_dim')  # avoid shadowing builtin `type`
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'point_cloud':  # consistent elif chain (was a bare if)
                pc_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        # for key in rgb_keys:
        #     replay_buffer[key].compressor.numthreads=1

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs from images/point clouds/low-dim
            for key in rgb_keys + pc_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.pc_keys = pc_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy sampling only the held-out validation episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs) -> LinearNormalizer:
        """Fit a LinearNormalizer on actions, proprioception and point clouds."""
        data = {
            'action': self.replay_buffer['action'],
            'robot0_eef_pos': self.replay_buffer['robot0_eef_pos'][...,:],
            'robot0_eef_quat': self.replay_buffer['robot0_eef_quat'][...,:],
            'robot0_gripper_qpos': self.replay_buffer['robot0_gripper_qpos'][...,:],
            'point_cloud': self.replay_buffer['point_cloud'],
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        # normalizer['point_cloud'] = SingleFieldLinearNormalizer.create_identity()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # keep BLAS single-threaded inside DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # T,H,W,C uint8 -> T,C,H,W float32 in [0, 1]
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            del data[key]
        for key in self.pc_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _convert_actions(raw_actions, abs_action, rotation_transformer):
|
| 209 |
+
actions = raw_actions
|
| 210 |
+
if abs_action:
|
| 211 |
+
is_dual_arm = False
|
| 212 |
+
if raw_actions.shape[-1] == 14:
|
| 213 |
+
# dual arm
|
| 214 |
+
raw_actions = raw_actions.reshape(-1,2,7)
|
| 215 |
+
is_dual_arm = True
|
| 216 |
+
|
| 217 |
+
pos = raw_actions[...,:3]
|
| 218 |
+
rot = raw_actions[...,3:6]
|
| 219 |
+
gripper = raw_actions[...,6:]
|
| 220 |
+
rot = rotation_transformer.forward(rot)
|
| 221 |
+
raw_actions = np.concatenate([
|
| 222 |
+
pos, rot, gripper
|
| 223 |
+
], axis=-1).astype(np.float32)
|
| 224 |
+
|
| 225 |
+
if is_dual_arm:
|
| 226 |
+
raw_actions = raw_actions.reshape(-1,20)
|
| 227 |
+
actions = raw_actions
|
| 228 |
+
return actions
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _convert_point_cloud_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert a robomimic hdf5 demo file into a zarr-backed ReplayBuffer.

    Low-dim observations and actions are loaded in one shot; RGB images and
    point clouds are copied step-by-step with a thread pool (one zarr chunk
    per step, so no cross-thread synchronization is needed).

    Args:
        store: zarr store that will hold the converted data.
        shape_meta: dict describing 'obs' and 'action' shapes/types.
        dataset_path: path to the robomimic hdf5 file.
        abs_action: whether actions are absolute (rotations get converted).
        rotation_transformer: converts axis-angle rotations when abs_action.
        n_workers: thread count (defaults to cpu_count()).
        max_inflight_tasks: cap on queued copy tasks (defaults to 5x workers).
        n_demo: number of demos to convert (clamped to what the file holds).

    Returns:
        ReplayBuffer wrapping the populated zarr group.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta into key groups by observation type
    pc_keys = list()
    rgb_keys = list()
    lowdim_keys = list()
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        obs_type = attr.get('type', 'low_dim')  # avoid shadowing builtin `type`
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'point_cloud':
            pc_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    with h5py.File(dataset_path) as file:
        # count total steps and episode boundaries
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        n_demo = min(n_demo, len(demos))
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (and actions): one contiguous uncompressed chunk per key
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def copy_step(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            # worker: copy one step, then read it back to make sure the chunk
            # round-trips (encode + decode); failure is reported via the
            # return value so the main thread can raise
            try:
                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
                _ = zarr_arr[zarr_idx]
                return True
            except Exception:
                return False

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w = shape
                    this_compressor = Jpeg2k(level=50)
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,h,w,c),
                        chunks=(1,h,w,c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(copy_step,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

        with tqdm(total=n_steps*len(pc_keys), desc="Loading point cloud data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in pc_keys:
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    n, c = shape
                    pc_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, n, c),
                        chunks=(1, n, c),
                        dtype=np.float32
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        # message corrected: this loop copies point clouds
                                        raise RuntimeError('Failed to copy point cloud!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(copy_step,
                                    pc_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to copy point cloud!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
|
| 398 |
+
|
| 399 |
+
def normalizer_from_stat(stat):
    """Legacy normalizer: divide every dimension by the single largest
    absolute value seen in the stats, with zero offset.

    Args:
        stat: dict with 'max' and 'min' arrays (plus other stats passed
            through to the normalizer as metadata).

    Returns:
        A SingleFieldLinearNormalizer applying uniform scale 1/max_abs.
    """
    # largest magnitude over both ends of the observed range
    peak = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    uniform_scale = np.full_like(stat['max'], fill_value=1 / peak)
    zero_offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=uniform_scale,
        offset=zero_offset,
        input_stats_dict=stat)
|
equidiff/equi_diffpo/dataset/robomimic_replay_voxel_sym_dataset.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import h5py
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import zarr
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
import copy
|
| 10 |
+
import json
|
| 11 |
+
import hashlib
|
| 12 |
+
from filelock import FileLock
|
| 13 |
+
from threadpoolctl import threadpool_limits
|
| 14 |
+
import concurrent.futures
|
| 15 |
+
import multiprocessing
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 18 |
+
from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
|
| 19 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
|
| 20 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 21 |
+
from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
|
| 22 |
+
from equi_diffpo.common.replay_buffer import ReplayBuffer
|
| 23 |
+
from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
|
| 24 |
+
from equi_diffpo.common.normalize_util import (
|
| 25 |
+
robomimic_abs_action_only_normalizer_from_stat,
|
| 26 |
+
get_range_normalizer_from_stat,
|
| 27 |
+
get_voxel_identity_normalizer,
|
| 28 |
+
get_image_range_normalizer,
|
| 29 |
+
get_identity_normalizer_from_stat,
|
| 30 |
+
array_to_stats
|
| 31 |
+
)
|
| 32 |
+
register_codecs()
|
| 33 |
+
|
| 34 |
+
class RobomimicReplayVoxelSymDataset(BaseImageDataset):
    """Robomimic demonstration dataset serving voxel / RGB / low-dim
    observations from an in-memory zarr ReplayBuffer.

    The HDF5 demos are converted once via ``_convert_voxel_to_replay``
    (optionally cached on disk as a ``.zarr.zip``) and then sampled as
    fixed-length sequences by a SequenceSampler.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100,
            ws_size=0.6,
            ws_x_center=0,
            ws_y_center=0,
        ):
        # Workspace geometry used to symmetrize position normalization
        # around the workspace center (see get_normalizer).
        self.n_demo = n_demo
        self.ws_size = ws_size
        self.ws_center = np.array([ws_x_center, ws_y_center])
        # converts axis-angle rotations in the raw actions to rotation_rep
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            # cache file name encodes n_demo; NOTE the f'.{n_demo}.' + '.zarr.zip'
            # concatenation yields a double dot ("..zarr.zip") — harmless but odd
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            # file lock prevents concurrent processes from building the
            # same cache simultaneously
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exists
                    try:
                        print('Cache does not exist. Creating!')
                        # store = zarr.DirectoryStore(cache_zarr_path)
                        replay_buffer = _convert_voxel_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception as e:
                        # best-effort cleanup of a partially written cache
                        # NOTE(review): rmtree targets a file path here; if
                        # the zip was never created this itself raises —
                        # confirm intended
                        shutil.rmtree(cache_zarr_path)
                        raise e
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            # no cache: convert straight into memory every time
            replay_buffer = _convert_voxel_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # partition observation keys by modality as declared in shape_meta
        rgb_keys = list()
        voxel_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            type = attr.get('type', 'low_dim')
            if type == 'rgb':
                rgb_keys.append(key)
            # NOTE: plain `if` (not elif) breaks the chain here; equivalent
            # in effect since types are mutually exclusive, but inconsistent
            # with the elif chain in _convert_voxel_to_replay
            if type == 'voxel':
                voxel_keys.append(key)
            elif type == 'low_dim':
                lowdim_keys.append(key)

        # for key in rgb_keys:
        #     replay_buffer[key].compressor.numthreads=1

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs from images
            for key in rgb_keys + voxel_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        # split episodes into train/val deterministically by seed
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.voxel_keys = voxel_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset sampling the held-out
        (validation) episodes instead of the training ones."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a LinearNormalizer for actions and all observation keys.

        Position-like stats (action xy, *_pos) are symmetrized around the
        workspace center and expanded to at least ws_size/2 so the
        normalization is invariant to where demos happen to lie.
        """
        normalizer = LinearNormalizer()
        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                raise NotImplementedError
            else:
                # symmetric range around ws_center, at least the workspace half-size
                magnitute = max(np.max([stat['max'][:2] - self.ws_center, self.ws_center - stat['min'][:2]]), self.ws_size/2)
                stat['min'][:2] = self.ws_center - magnitute
                stat['max'][:2] = self.ws_center + magnitute
                stat['mean'][:2] = self.ws_center
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # obs
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])

            # qpos is checked before the generic 'pos' suffix on purpose
            if key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('pos'):
                # same workspace symmetrization as for absolute actions
                magnitute = max(np.max([stat['max'][:2] - self.ws_center, self.ws_center - stat['min'][:2]]), self.ws_size/2)
                stat['min'][:2] = self.ws_center - magnitute
                stat['max'][:2] = self.ws_center + magnitute
                stat['mean'][:2] = self.ws_center
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # image
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        for key in self.voxel_keys:
            normalizer[key] = get_voxel_identity_normalizer()

        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        # number of sampleable (possibly padded) sequences
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # avoid BLAS thread storms inside DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # move channel last to channel first
            # T,H,W,C
            # convert uint8 image to float32
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            # T,C,H,W
            del data[key]
        for key in self.voxel_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            # channel 0 is occupancy (kept as-is); remaining channels are
            # uint8 color scaled to [0,1]
            obs_dict[key][:, 1:] /= 255.
            # # convert uint8 image to float32
            # voxels = np.moveaxis(data[key][T_slice].astype(np.float32), [0, 1, 2, 3, 4], [0, 1, 4, 3, 2])
            # voxels = np.flip(voxels, (2, 3))
            # voxels[:, 1:] /= 255.
            # obs_dict[key] = voxels.copy()
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _convert_actions(raw_actions, abs_action, rotation_transformer):
|
| 256 |
+
actions = raw_actions
|
| 257 |
+
if abs_action:
|
| 258 |
+
is_dual_arm = False
|
| 259 |
+
if raw_actions.shape[-1] == 14:
|
| 260 |
+
# dual arm
|
| 261 |
+
raw_actions = raw_actions.reshape(-1,2,7)
|
| 262 |
+
is_dual_arm = True
|
| 263 |
+
|
| 264 |
+
pos = raw_actions[...,:3]
|
| 265 |
+
rot = raw_actions[...,3:6]
|
| 266 |
+
gripper = raw_actions[...,6:]
|
| 267 |
+
rot = rotation_transformer.forward(rot)
|
| 268 |
+
raw_actions = np.concatenate([
|
| 269 |
+
pos, rot, gripper
|
| 270 |
+
], axis=-1).astype(np.float32)
|
| 271 |
+
|
| 272 |
+
if is_dual_arm:
|
| 273 |
+
raw_actions = raw_actions.reshape(-1,20)
|
| 274 |
+
actions = raw_actions
|
| 275 |
+
return actions
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _convert_voxel_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert a robomimic HDF5 demo file into a zarr-backed ReplayBuffer.

    Low-dim observations and actions are concatenated across episodes and
    stored uncompressed; RGB frames are JPEG2000-compressed and voxel grids
    stored as uint8, both copied episode-by-episode with a bounded thread
    pool (one zarr chunk per step, so writers never share a chunk).

    Args:
        store: zarr store to build the buffer in (e.g. zarr.MemoryStore()).
        shape_meta: dict describing 'action' and per-key 'obs' shapes/types.
        dataset_path: path to the robomimic HDF5 file.
        abs_action: forwarded to _convert_actions.
        rotation_transformer: forwarded to _convert_actions.
        n_workers: thread-pool size (default 24).
        max_inflight_tasks: cap on queued copy tasks (default n_workers*5).
        n_demo: number of demos to convert (clamped to what the file has).

    Returns:
        ReplayBuffer wrapping the populated zarr group.
    """
    if n_workers is None:
        n_workers = 24
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta
    voxel_keys = list()
    rgb_keys = list()
    lowdim_keys = list()
    # construct compressors and chunks
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        shape = attr['shape']
        type = attr.get('type', 'low_dim')
        if type == 'rgb':
            rgb_keys.append(key)
        elif type == 'voxel':
            voxel_keys.append(key)
        elif type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    with h5py.File(dataset_path) as file:
        # count total steps
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        n_demo = min(n_demo, len(demos))
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        # start offset of each episode in the flat step axis
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                # actions live at the demo root, not under obs/
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            # whole array as a single uncompressed chunk
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def copy_to_zarr(zarr_arr, hdf5_arr, start_idx, end_idx):
            # Worker: write one episode's frames into the zarr array and
            # verify round-trip; returns False on any failure so the caller
            # can raise.
            try:
                zarr_arr[start_idx:end_idx] = hdf5_arr
                # make sure we can successfully decode
                _ = zarr_arr[start_idx:end_idx]
                return True
            except Exception as e:
                return False

        with tqdm(total=n_demo*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    data_key = 'obs/' + key
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w = shape
                    this_compressor = Jpeg2k(level=50)
                    # stored channel-last (T,H,W,C), one frame per chunk
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,h,w,c),
                        chunks=(1,h,w,c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        # load whole episode into RAM before handing to worker
                        hdf5_arr = demo['obs'][key][:]
                        start_idx = episode_starts[episode_idx]
                        if episode_idx < n_demo - 1:
                            end_idx = episode_starts[episode_idx+1]
                        else:
                            end_idx = n_steps
                        if len(futures) >= max_inflight_tasks:
                            # limit number of inflight tasks
                            completed, futures = concurrent.futures.wait(futures,
                                return_when=concurrent.futures.FIRST_COMPLETED)
                            for f in completed:
                                if not f.result():
                                    raise RuntimeError('Failed to encode image!')
                            pbar.update(len(completed))

                        futures.add(
                            executor.submit(copy_to_zarr,
                                img_arr, hdf5_arr, start_idx, end_idx))
                # drain remaining tasks and surface any failure
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

        with tqdm(total=n_demo*len(voxel_keys), desc="Loading voxel data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in voxel_keys:
                    data_key = key
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w,l = shape
                    # voxels stored uncompressed uint8, one step per chunk
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,c,h,w,l),
                        chunks=(1,c,h,w,l),
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key][:]
                        start_idx = episode_starts[episode_idx]
                        if episode_idx < n_demo - 1:
                            end_idx = episode_starts[episode_idx+1]
                        else:
                            end_idx = n_steps
                        if len(futures) >= max_inflight_tasks:
                            # limit number of inflight tasks
                            completed, futures = concurrent.futures.wait(futures,
                                return_when=concurrent.futures.FIRST_COMPLETED)
                            for f in completed:
                                if not f.result():
                                    raise RuntimeError('Failed to encode image!')
                            pbar.update(len(completed))

                        futures.add(
                            executor.submit(copy_to_zarr,
                                img_arr, hdf5_arr, start_idx, end_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
|
| 443 |
+
|
| 444 |
+
def normalizer_from_stat(stat):
    """Legacy scaling-only normalizer.

    Scales every dimension by the single largest absolute value observed
    in the stats and applies no offset, so outputs stay within [-1, 1].
    """
    # biggest magnitude across both extremes of the data range
    max_magnitude = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    shared_scale = np.full_like(stat['max'], fill_value=1 / max_magnitude)
    no_offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=shared_scale,
        offset=no_offset,
        input_stats_dict=stat)
|
equidiff/equi_diffpo/env/robomimic/robomimic_image_wrapper.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from matplotlib.pyplot import fill
|
| 3 |
+
import numpy as np
|
| 4 |
+
import gym
|
| 5 |
+
from gym import spaces
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from robomimic.envs.env_robosuite import EnvRobosuite
|
| 8 |
+
|
| 9 |
+
class RobomimicImageWrapper(gym.Env):
    """Gym adapter around robomimic's EnvRobosuite exposing dict image
    observations described by ``shape_meta``.

    Supports deterministic resets either to a fixed ``init_state`` or to a
    per-seed cached simulator state, and renders from a cached observation
    frame (``render_obs_key``).
    """

    def __init__(self,
        env: EnvRobosuite,
        shape_meta: dict,
        init_state: Optional[np.ndarray]=None,
        render_obs_key='agentview_image',
        ):

        self.env = env
        self.render_obs_key = render_obs_key
        self.init_state = init_state
        # maps seed -> simulator state captured on first reset with that seed
        self.seed_state_map = dict()
        self._seed = None
        self.shape_meta = shape_meta
        # last frame seen for render_obs_key; filled by get_observation
        self.render_cache = None
        self.has_reset_before = False

        # setup spaces
        action_shape = shape_meta['action']['shape']
        action_space = spaces.Box(
            low=-1,
            high=1,
            shape=action_shape,
            dtype=np.float32
        )
        self.action_space = action_space

        # per-key Box bounds chosen by key-name suffix
        observation_space = spaces.Dict()
        for key, value in shape_meta['obs'].items():
            shape = value['shape']
            min_value, max_value = -1, 1
            if key.endswith('image'):
                min_value, max_value = 0, 1
            elif key.endswith('depth'):
                min_value, max_value = 0, 1
            elif key.endswith('voxels'):
                min_value, max_value = 0, 1
            elif key.endswith('point_cloud'):
                min_value, max_value = -10, 10
            elif key.endswith('quat'):
                min_value, max_value = -1, 1
            elif key.endswith('qpos'):
                min_value, max_value = -1, 1
            elif key.endswith('pos'):
                # better range?
                min_value, max_value = -1, 1
            else:
                raise RuntimeError(f"Unsupported type {key}")

            this_space = spaces.Box(
                low=min_value,
                high=max_value,
                shape=shape,
                dtype=np.float32
            )
            observation_space[key] = this_space
        self.observation_space = observation_space


    def get_observation(self, raw_obs=None):
        """Filter a raw robomimic observation down to the declared keys,
        caching the render frame as a side effect."""
        if raw_obs is None:
            raw_obs = self.env.get_observation()

        self.render_cache = raw_obs[self.render_obs_key]

        obs = dict()
        for key in self.observation_space.keys():
            obs[key] = raw_obs[key]
        return obs

    def seed(self, seed=None):
        # robosuite draws initialization randomness from numpy's global
        # state, so seed it here; the seed is consumed by the next reset()
        np.random.seed(seed=seed)
        self._seed = seed

    def reset(self):
        """Reset the environment, honoring init_state / seed caching."""
        if self.init_state is not None:
            if not self.has_reset_before:
                # the env must be fully reset at least once to ensure correct rendering
                self.env.reset()
                self.has_reset_before = True

            # always reset to the same state
            # to be compatible with gym
            raw_obs = self.env.reset_to({'states': self.init_state})
        elif self._seed is not None:
            # reset to a specific seed
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive, use cache
                raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]})
            else:
                # robosuite's initializes all use numpy global random state
                np.random.seed(seed=seed)
                raw_obs = self.env.reset()
                state = self.env.get_state()['states']
                self.seed_state_map[seed] = state
            # seed is one-shot: cleared after use
            self._seed = None
        else:
            # random reset
            raw_obs = self.env.reset()

        # return obs
        obs = self.get_observation(raw_obs)
        return obs

    def step(self, action):
        """Step the wrapped env and return the filtered observation."""
        raw_obs, reward, done, info = self.env.step(action)
        obs = self.get_observation(raw_obs)
        return obs, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the last cached frame as an HWC uint8 image."""
        if self.render_cache is None:
            raise RuntimeError('Must run reset or step before render.')
        # cached frame is channel-first; assumes float values in [0,1]
        # — TODO confirm against the env's image format
        img = np.moveaxis(self.render_cache, 0, -1)
        img = (img * 255).astype(np.uint8)
        return img
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test():
    """Smoke test: build a robomimic image env from a dataset's metadata,
    wrap it, reset with a fixed seed, and display one rendered frame."""
    import os
    from omegaconf import OmegaConf
    task_cfg = OmegaConf.load(os.path.expanduser(
        '~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml'))
    shape_meta = task_cfg['shape_meta']

    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    dataset_path = os.path.expanduser(
        '~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5')
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)

    base_env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=True,
    )

    wrapper = RobomimicImageWrapper(env=base_env, shape_meta=shape_meta)
    wrapper.seed(0)
    wrapper.reset()
    plt.imshow(wrapper.render())
|
equidiff/equi_diffpo/env/robomimic/robomimic_lowdim_wrapper.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import gym
|
| 4 |
+
from gym.spaces import Box
|
| 5 |
+
from robomimic.envs.env_robosuite import EnvRobosuite
|
| 6 |
+
|
| 7 |
+
class RobomimicLowdimWrapper(gym.Env):
|
| 8 |
+
def __init__(self,
|
| 9 |
+
env: EnvRobosuite,
|
| 10 |
+
obs_keys: List[str]=[
|
| 11 |
+
'object',
|
| 12 |
+
'robot0_eef_pos',
|
| 13 |
+
'robot0_eef_quat',
|
| 14 |
+
'robot0_gripper_qpos'],
|
| 15 |
+
init_state: Optional[np.ndarray]=None,
|
| 16 |
+
render_hw=(256,256),
|
| 17 |
+
render_camera_name='agentview'
|
| 18 |
+
):
|
| 19 |
+
|
| 20 |
+
self.env = env
|
| 21 |
+
self.obs_keys = obs_keys
|
| 22 |
+
self.init_state = init_state
|
| 23 |
+
self.render_hw = render_hw
|
| 24 |
+
self.render_camera_name = render_camera_name
|
| 25 |
+
self.seed_state_map = dict()
|
| 26 |
+
self._seed = None
|
| 27 |
+
|
| 28 |
+
# setup spaces
|
| 29 |
+
low = np.full(env.action_dimension, fill_value=-1)
|
| 30 |
+
high = np.full(env.action_dimension, fill_value=1)
|
| 31 |
+
self.action_space = Box(
|
| 32 |
+
low=low,
|
| 33 |
+
high=high,
|
| 34 |
+
shape=low.shape,
|
| 35 |
+
dtype=low.dtype
|
| 36 |
+
)
|
| 37 |
+
obs_example = self.get_observation()
|
| 38 |
+
low = np.full_like(obs_example, fill_value=-1)
|
| 39 |
+
high = np.full_like(obs_example, fill_value=1)
|
| 40 |
+
self.observation_space = Box(
|
| 41 |
+
low=low,
|
| 42 |
+
high=high,
|
| 43 |
+
shape=low.shape,
|
| 44 |
+
dtype=low.dtype
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def get_observation(self):
|
| 48 |
+
raw_obs = self.env.get_observation()
|
| 49 |
+
obs = np.concatenate([
|
| 50 |
+
raw_obs[key] for key in self.obs_keys
|
| 51 |
+
], axis=0)
|
| 52 |
+
return obs
|
| 53 |
+
|
| 54 |
+
def seed(self, seed=None):
|
| 55 |
+
np.random.seed(seed=seed)
|
| 56 |
+
self._seed = seed
|
| 57 |
+
|
| 58 |
+
def reset(self):
|
| 59 |
+
if self.init_state is not None:
|
| 60 |
+
# always reset to the same state
|
| 61 |
+
# to be compatible with gym
|
| 62 |
+
self.env.reset_to({'states': self.init_state})
|
| 63 |
+
elif self._seed is not None:
|
| 64 |
+
# reset to a specific seed
|
| 65 |
+
seed = self._seed
|
| 66 |
+
if seed in self.seed_state_map:
|
| 67 |
+
# env.reset is expensive, use cache
|
| 68 |
+
self.env.reset_to({'states': self.seed_state_map[seed]})
|
| 69 |
+
else:
|
| 70 |
+
# robosuite's initializes all use numpy global random state
|
| 71 |
+
np.random.seed(seed=seed)
|
| 72 |
+
self.env.reset()
|
| 73 |
+
state = self.env.get_state()['states']
|
| 74 |
+
self.seed_state_map[seed] = state
|
| 75 |
+
self._seed = None
|
| 76 |
+
else:
|
| 77 |
+
# random reset
|
| 78 |
+
self.env.reset()
|
| 79 |
+
|
| 80 |
+
# return obs
|
| 81 |
+
obs = self.get_observation()
|
| 82 |
+
return obs
|
| 83 |
+
|
| 84 |
+
def step(self, action):
    """Step the wrapped env and flatten the observation.

    Returns ``(flat_obs, reward, done, info)`` where ``flat_obs`` is the
    concatenation of the ``self.obs_keys`` entries of the raw obs dict.
    """
    raw_obs, reward, done, info = self.env.step(action)
    flat_obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)
    return flat_obs, reward, done, info
|
| 90 |
+
|
| 91 |
+
def render(self, mode='rgb_array'):
    """Render one offscreen frame.

    Uses ``self.render_hw`` as (height, width) and the configured
    ``self.render_camera_name``.
    """
    height, width = self.render_hw
    return self.env.render(
        mode=mode,
        height=height,
        width=width,
        camera_name=self.render_camera_name,
    )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test():
    """Smoke test: seeded resets are reproducible and rendering works.

    Requires a local robomimic dataset and a display-less mujoco setup;
    intended for manual/interactive use only.
    """
    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    dataset_path = '/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5'
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)

    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=False,
    )
    wrapper = RobomimicLowdimWrapper(
        env=env,
        obs_keys=[
            'object',
            'robot0_eef_pos',
            'robot0_eef_quat',
            'robot0_gripper_qpos'
        ]
    )

    # resetting twice with the same seed must reproduce the same sim state
    states = []
    for _ in range(2):
        wrapper.seed(0)
        wrapper.reset()
        states.append(wrapper.env.get_state()['states'])
    assert np.allclose(states[0], states[1])

    img = wrapper.render()
    plt.imshow(img)
    # wrapper.seed()
    # states.append(wrapper.env.get_state()['states'])
|
equidiff/equi_diffpo/env_runner/robomimic_image_runner.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import wandb
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import collections
|
| 6 |
+
import pathlib
|
| 7 |
+
import tqdm
|
| 8 |
+
import h5py
|
| 9 |
+
import math
|
| 10 |
+
import dill
|
| 11 |
+
import wandb.sdk.data_types.video as wv
|
| 12 |
+
from equi_diffpo.gym_util.async_vector_env import AsyncVectorEnv
|
| 13 |
+
from equi_diffpo.gym_util.sync_vector_env import SyncVectorEnv
|
| 14 |
+
from equi_diffpo.gym_util.multistep_wrapper import MultiStepWrapper
|
| 15 |
+
from equi_diffpo.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
| 16 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 17 |
+
|
| 18 |
+
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
|
| 19 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 20 |
+
from equi_diffpo.env_runner.base_image_runner import BaseImageRunner
|
| 21 |
+
from equi_diffpo.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper
|
| 22 |
+
import robomimic.utils.file_utils as FileUtils
|
| 23 |
+
import robomimic.utils.env_utils as EnvUtils
|
| 24 |
+
import robomimic.utils.obs_utils as ObsUtils
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_env(env_meta, shape_meta, enable_render=True):
    """Instantiate a robomimic env described by ``env_meta``.

    Each key in ``shape_meta['obs']`` is registered with robomimic's
    global ObsUtils modality mapping using its ``type`` attribute
    (defaulting to ``'low_dim'``) before construction. ``enable_render``
    toggles offscreen rendering and image observations together.
    """
    modalities = collections.defaultdict(list)
    for obs_key, attr in shape_meta['obs'].items():
        modalities[attr.get('type', 'low_dim')].append(obs_key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modalities)

    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RobomimicImageRunner(BaseImageRunner):
    """Parallel rollout runner for image-based policies.

    Robomimic envs already enforces number of steps.
    """

    def __init__(self,
            output_dir,
            dataset_path,
            shape_meta: dict,
            n_train=10,
            n_train_vis=3,
            train_start_idx=0,
            n_test=22,
            n_test_vis=6,
            test_start_seed=10000,
            max_steps=400,
            n_obs_steps=2,
            n_action_steps=8,
            render_obs_key='agentview_image',
            fps=10,
            crf=22,
            past_action=False,
            abs_action=False,
            tqdm_interval_sec=5.0,
            n_envs=None
        ):
        """Build n_envs vectorized envs plus one init fn per rollout.

        Train rollouts (n_train) reset to the first state of each demo in
        ``dataset_path``; test rollouts (n_test) use seeded random resets
        starting at ``test_start_seed``. Only the first n_*_vis rollouts
        of each kind record video. ``abs_action`` switches the controller
        to absolute poses and enables 6d-rotation conversion.
        """
        super().__init__(output_dir)

        if n_envs is None:
            n_envs = n_train + n_test

        # assert n_obs_steps <= n_action_steps
        dataset_path = os.path.expanduser(dataset_path)
        robosuite_fps = 20
        steps_per_render = max(robosuite_fps // fps, 1)

        # read env metadata from dataset
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
        # disable object state observation: image policies don't use it
        env_meta['env_kwargs']['use_object_obs'] = False

        rotation_transformer = None
        if abs_action:
            # absolute control: disable delta actions; the policy emits
            # rotation_6d which undo_transform_action converts back
            env_meta['env_kwargs']['controller_configs']['control_delta'] = False
            rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

        def make_env_fn(enable_render):
            # Single factory replacing the previously duplicated env_fn /
            # dummy_env_fn bodies. enable_render=False is used only by
            # AsyncVectorEnv's dummy env (space initialization): per
            # process the OpenGL context can only be initialized once,
            # and workers are forked from this parent process.
            def _env_fn():
                robomimic_env = create_env(
                    env_meta=env_meta,
                    shape_meta=shape_meta,
                    enable_render=enable_render
                )
                # Robosuite's hard reset causes excessive memory consumption.
                # Disabled to run more envs.
                # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
                robomimic_env.env.hard_reset = False
                return MultiStepWrapper(
                    VideoRecordingWrapper(
                        RobomimicImageWrapper(
                            env=robomimic_env,
                            shape_meta=shape_meta,
                            init_state=None,
                            render_obs_key=render_obs_key
                        ),
                        video_recoder=VideoRecorder.create_h264(
                            fps=fps,
                            codec='h264',
                            input_pix_fmt='rgb24',
                            crf=crf,
                            thread_type='FRAME',
                            thread_count=1
                        ),
                        file_path=None,
                        steps_per_render=steps_per_render
                    ),
                    n_obs_steps=n_obs_steps,
                    n_action_steps=n_action_steps,
                    max_episode_steps=max_steps
                )
            return _env_fn

        env_fn = make_env_fn(enable_render=True)
        dummy_env_fn = make_env_fn(enable_render=False)

        env_fns = [env_fn] * n_envs
        env_seeds = list()
        env_prefixs = list()
        env_init_fn_dills = list()

        # train: reset each env to the first state of its demo
        with h5py.File(dataset_path, 'r') as f:
            for i in range(n_train):
                train_idx = train_start_idx + i
                enable_render = i < n_train_vis
                init_state = f[f'data/demo_{train_idx}/states'][0]

                def init_fn(env, init_state=init_state,
                    enable_render=enable_render):
                    # setup rendering (video wrapper)
                    assert isinstance(env.env, VideoRecordingWrapper)
                    env.env.video_recoder.stop()
                    env.env.file_path = None
                    if enable_render:
                        filename = pathlib.Path(output_dir).joinpath(
                            'media', wv.util.generate_id() + ".mp4")
                        filename.parent.mkdir(parents=False, exist_ok=True)
                        filename = str(filename)
                        env.env.file_path = filename

                    # switch to init_state reset
                    assert isinstance(env.env.env, RobomimicImageWrapper)
                    env.env.env.init_state = init_state

                env_seeds.append(train_idx)
                env_prefixs.append('train/')
                env_init_fn_dills.append(dill.dumps(init_fn))

        # test: seeded random resets
        for i in range(n_test):
            seed = test_start_seed + i
            enable_render = i < n_test_vis

            def init_fn(env, seed=seed,
                enable_render=enable_render):
                # setup rendering (video wrapper)
                assert isinstance(env.env, VideoRecordingWrapper)
                env.env.video_recoder.stop()
                env.env.file_path = None
                if enable_render:
                    filename = pathlib.Path(output_dir).joinpath(
                        'media', wv.util.generate_id() + ".mp4")
                    filename.parent.mkdir(parents=False, exist_ok=True)
                    filename = str(filename)
                    env.env.file_path = filename

                # switch to seed reset
                assert isinstance(env.env.env, RobomimicImageWrapper)
                env.env.env.init_state = None
                env.seed(seed)

            env_seeds.append(seed)
            env_prefixs.append('test/')
            env_init_fn_dills.append(dill.dumps(init_fn))

        env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)

        self.env_meta = env_meta
        self.env = env
        self.env_fns = env_fns
        self.env_seeds = env_seeds
        self.env_prefixs = env_prefixs
        self.env_init_fn_dills = env_init_fn_dills
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.rotation_transformer = rotation_transformer
        self.abs_action = abs_action
        self.tqdm_interval_sec = tqdm_interval_sec
        # best mean score observed so far, per 'train/' / 'test/' prefix
        self.max_rewards = {prefix: 0 for prefix in env_prefixs}
|
| 238 |
+
|
| 239 |
+
def run(self, policy: BaseImagePolicy):
    """Evaluate *policy* over every registered initial condition.

    Initial conditions (train demo states and test seeds) are processed
    in chunks of len(self.env_fns) parallel envs. Returns a wandb-style
    log dict with per-episode max rewards, rollout videos, and
    per-prefix mean/max scores.
    """
    device = policy.device
    dtype = policy.dtype
    env = self.env

    # plan for rollout
    n_envs = len(self.env_fns)
    n_inits = len(self.env_init_fn_dills)
    n_chunks = math.ceil(n_inits / n_envs)

    # per-initial-condition results
    all_video_paths = [None] * n_inits
    all_rewards = [None] * n_inits

    for chunk_idx in range(n_chunks):
        lo = chunk_idx * n_envs
        hi = min(n_inits, lo + n_envs)
        global_slice = slice(lo, hi)
        local_slice = slice(0, hi - lo)

        # pad a partial final chunk with the first init fn so every env
        # in the vector gets initialized
        chunk_init_fns = self.env_init_fn_dills[global_slice]
        shortfall = n_envs - len(chunk_init_fns)
        if shortfall > 0:
            chunk_init_fns.extend([self.env_init_fn_dills[0]] * shortfall)
        assert len(chunk_init_fns) == n_envs

        # initialize each worker env via its dill-serialized init fn
        env.call_each('run_dill_function',
            args_list=[(fn,) for fn in chunk_init_fns])

        # start rollout
        obs = env.reset()
        past_action = None
        policy.reset()

        env_name = self.env_meta['env_name']
        pbar = tqdm.tqdm(total=self.max_steps,
            desc=f"Eval {env_name}Image {chunk_idx+1}/{n_chunks}",
            leave=False, mininterval=self.tqdm_interval_sec)

        done = False
        while not done:
            # assemble the observation dict for the policy
            np_obs_dict = dict(obs)
            if self.past_action and (past_action is not None):
                # TODO: not tested
                np_obs_dict['past_action'] = past_action[
                    :, -(self.n_obs_steps-1):].astype(np.float32)

            # numpy -> torch on the policy's device
            obs_dict = dict_apply(np_obs_dict,
                lambda x: torch.from_numpy(x).to(device=device))

            # inference
            with torch.no_grad():
                action_dict = policy.predict_action(obs_dict)

            # torch -> numpy
            np_action_dict = dict_apply(action_dict,
                lambda x: x.detach().to('cpu').numpy())

            action = np_action_dict['action']
            if not np.all(np.isfinite(action)):
                print(action)
                raise RuntimeError("Nan or Inf action")

            # convert absolute-pose actions back to env format if needed
            env_action = self.undo_transform_action(action) if self.abs_action else action

            obs, reward, done, info = env.step(env_action)
            done = np.all(done)
            past_action = action

            # update pbar by the number of action steps just executed
            pbar.update(action.shape[1])
        pbar.close()

        # collect data for this round
        all_video_paths[global_slice] = env.render()[local_slice]
        all_rewards[global_slice] = env.call('get_attr', 'reward')[local_slice]
        # clear out video buffer
        _ = env.reset()

    # log
    max_rewards = collections.defaultdict(list)
    log_data = dict()
    # results reported in the paper are generated using the commented out line below
    # which will only report and average metrics from first n_envs initial condition and seeds
    # fortunately this won't invalidate our conclusion since
    # 1. This bug only affects the variance of metrics, not their mean
    # 2. All baseline methods are evaluated using the same code
    # to completely reproduce reported numbers, uncomment this line:
    # for i in range(len(self.env_fns)):
    # and comment out this line
    for i in range(n_inits):
        seed = self.env_seeds[i]
        prefix = self.env_prefixs[i]
        max_reward = np.max(all_rewards[i])
        max_rewards[prefix].append(max_reward)
        log_data[prefix + f'sim_max_reward_{seed}'] = max_reward

        # visualize sim
        video_path = all_video_paths[i]
        if video_path is not None:
            log_data[prefix + f'sim_video_{seed}'] = wandb.Video(video_path)

    # log aggregate metrics
    for prefix, value in max_rewards.items():
        name = prefix + 'mean_score'
        value = np.mean(value)
        log_data[name] = value
        self.max_rewards[prefix] = max(self.max_rewards[prefix], value)
        log_data[prefix + 'max_score'] = self.max_rewards[prefix]

    return log_data
|
| 358 |
+
|
| 359 |
+
def undo_transform_action(self, action):
    """Convert a policy action back to the env's native format.

    Layout per arm is [pos(3), rotation(d_rot), gripper(1)]; the rotation
    block is mapped back through ``self.rotation_transformer.inverse``
    (e.g. rotation_6d -> axis_angle). A trailing dimension of 20 is
    treated as a dual-arm action (2 x 10) and re-flattened to 2 x 7 = 14.
    """
    raw_shape = action.shape
    is_dual_arm = raw_shape[-1] == 20
    if is_dual_arm:
        # split into the two arms' 10-dim per-arm actions
        action = action.reshape(-1, 2, 10)

    # everything between pos(3) and gripper(1) is the rotation block
    rot_dim = action.shape[-1] - 4
    pos = action[..., :3]
    rot = self.rotation_transformer.inverse(action[..., 3:3 + rot_dim])
    gripper = action[..., [-1]]
    uaction = np.concatenate([pos, rot, gripper], axis=-1)

    if is_dual_arm:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction
|
equidiff/equi_diffpo/env_runner/robomimic_lowdim_runner.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import wandb
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import collections
|
| 6 |
+
import pathlib
|
| 7 |
+
import tqdm
|
| 8 |
+
import h5py
|
| 9 |
+
import dill
|
| 10 |
+
import math
|
| 11 |
+
import wandb.sdk.data_types.video as wv
|
| 12 |
+
from equi_diffpo.gym_util.async_vector_env import AsyncVectorEnv
|
| 13 |
+
# from equi_diffpo.gym_util.sync_vector_env import SyncVectorEnv
|
| 14 |
+
from equi_diffpo.gym_util.multistep_wrapper import MultiStepWrapper
|
| 15 |
+
from equi_diffpo.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
| 16 |
+
from equi_diffpo.model.common.rotation_transformer import RotationTransformer
|
| 17 |
+
|
| 18 |
+
from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy
|
| 19 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 20 |
+
from equi_diffpo.env_runner.base_lowdim_runner import BaseLowdimRunner
|
| 21 |
+
from equi_diffpo.env.robomimic.robomimic_lowdim_wrapper import RobomimicLowdimWrapper
|
| 22 |
+
import robomimic.utils.file_utils as FileUtils
|
| 23 |
+
import robomimic.utils.env_utils as EnvUtils
|
| 24 |
+
import robomimic.utils.obs_utils as ObsUtils
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_env(env_meta, obs_keys, enable_render=True):
    """Instantiate a low-dim robomimic env described by ``env_meta``.

    Registers every entry of ``obs_keys`` as a 'low_dim' modality with
    robomimic's global ObsUtils state before construction. Enabling
    offscreen rendering is the only way to not show collision geometry,
    but it uses a lot of RAM, hence the ``enable_render`` switch.
    """
    ObsUtils.initialize_obs_modality_mapping_from_dict(
        {'low_dim': obs_keys})
    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RobomimicLowdimRunner(BaseLowdimRunner):
|
| 43 |
+
"""
|
| 44 |
+
Robomimic envs already enforces number of steps.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self,
|
| 48 |
+
output_dir,
|
| 49 |
+
dataset_path,
|
| 50 |
+
obs_keys,
|
| 51 |
+
n_train=10,
|
| 52 |
+
n_train_vis=3,
|
| 53 |
+
train_start_idx=0,
|
| 54 |
+
n_test=22,
|
| 55 |
+
n_test_vis=6,
|
| 56 |
+
test_start_seed=10000,
|
| 57 |
+
max_steps=400,
|
| 58 |
+
n_obs_steps=2,
|
| 59 |
+
n_action_steps=8,
|
| 60 |
+
n_latency_steps=0,
|
| 61 |
+
render_hw=(256,256),
|
| 62 |
+
render_camera_name='agentview',
|
| 63 |
+
fps=10,
|
| 64 |
+
crf=22,
|
| 65 |
+
past_action=False,
|
| 66 |
+
abs_action=False,
|
| 67 |
+
tqdm_interval_sec=5.0,
|
| 68 |
+
n_envs=None
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Assuming:
|
| 72 |
+
n_obs_steps=2
|
| 73 |
+
n_latency_steps=3
|
| 74 |
+
n_action_steps=4
|
| 75 |
+
o: obs
|
| 76 |
+
i: inference
|
| 77 |
+
a: action
|
| 78 |
+
Batch t:
|
| 79 |
+
|o|o| | | | | | |
|
| 80 |
+
| |i|i|i| | | | |
|
| 81 |
+
| | | | |a|a|a|a|
|
| 82 |
+
Batch t+1
|
| 83 |
+
| | | | |o|o| | | | | | |
|
| 84 |
+
| | | | | |i|i|i| | | | |
|
| 85 |
+
| | | | | | | | |a|a|a|a|
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
super().__init__(output_dir)
|
| 89 |
+
|
| 90 |
+
if n_envs is None:
|
| 91 |
+
n_envs = n_train + n_test
|
| 92 |
+
|
| 93 |
+
# handle latency step
|
| 94 |
+
# to mimic latency, we request n_latency_steps additional steps
|
| 95 |
+
# of past observations, and the discard the last n_latency_steps
|
| 96 |
+
env_n_obs_steps = n_obs_steps + n_latency_steps
|
| 97 |
+
env_n_action_steps = n_action_steps
|
| 98 |
+
|
| 99 |
+
# assert n_obs_steps <= n_action_steps
|
| 100 |
+
dataset_path = os.path.expanduser(dataset_path)
|
| 101 |
+
robosuite_fps = 20
|
| 102 |
+
steps_per_render = max(robosuite_fps // fps, 1)
|
| 103 |
+
|
| 104 |
+
# read from dataset
|
| 105 |
+
env_meta = FileUtils.get_env_metadata_from_dataset(
|
| 106 |
+
dataset_path)
|
| 107 |
+
rotation_transformer = None
|
| 108 |
+
if abs_action:
|
| 109 |
+
env_meta['env_kwargs']['controller_configs']['control_delta'] = False
|
| 110 |
+
rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')
|
| 111 |
+
|
| 112 |
+
def env_fn():
|
| 113 |
+
robomimic_env = create_env(
|
| 114 |
+
env_meta=env_meta,
|
| 115 |
+
obs_keys=obs_keys
|
| 116 |
+
)
|
| 117 |
+
# Robosuite's hard reset causes excessive memory consumption.
|
| 118 |
+
# Disabled to run more envs.
|
| 119 |
+
# https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
|
| 120 |
+
robomimic_env.env.hard_reset = False
|
| 121 |
+
return MultiStepWrapper(
|
| 122 |
+
VideoRecordingWrapper(
|
| 123 |
+
RobomimicLowdimWrapper(
|
| 124 |
+
env=robomimic_env,
|
| 125 |
+
obs_keys=obs_keys,
|
| 126 |
+
init_state=None,
|
| 127 |
+
render_hw=render_hw,
|
| 128 |
+
render_camera_name=render_camera_name
|
| 129 |
+
),
|
| 130 |
+
video_recoder=VideoRecorder.create_h264(
|
| 131 |
+
fps=fps,
|
| 132 |
+
codec='h264',
|
| 133 |
+
input_pix_fmt='rgb24',
|
| 134 |
+
crf=crf,
|
| 135 |
+
thread_type='FRAME',
|
| 136 |
+
thread_count=1
|
| 137 |
+
),
|
| 138 |
+
file_path=None,
|
| 139 |
+
steps_per_render=steps_per_render
|
| 140 |
+
),
|
| 141 |
+
n_obs_steps=n_obs_steps,
|
| 142 |
+
n_action_steps=n_action_steps,
|
| 143 |
+
max_episode_steps=max_steps
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# For each process the OpenGL context can only be initialized once
|
| 147 |
+
# Since AsyncVectorEnv uses fork to create worker process,
|
| 148 |
+
# a separate env_fn that does not create OpenGL context (enable_render=False)
|
| 149 |
+
# is needed to initialize spaces.
|
| 150 |
+
def dummy_env_fn():
|
| 151 |
+
robomimic_env = create_env(
|
| 152 |
+
env_meta=env_meta,
|
| 153 |
+
obs_keys=obs_keys,
|
| 154 |
+
enable_render=False
|
| 155 |
+
)
|
| 156 |
+
return MultiStepWrapper(
|
| 157 |
+
VideoRecordingWrapper(
|
| 158 |
+
RobomimicLowdimWrapper(
|
| 159 |
+
env=robomimic_env,
|
| 160 |
+
obs_keys=obs_keys,
|
| 161 |
+
init_state=None,
|
| 162 |
+
render_hw=render_hw,
|
| 163 |
+
render_camera_name=render_camera_name
|
| 164 |
+
),
|
| 165 |
+
video_recoder=VideoRecorder.create_h264(
|
| 166 |
+
fps=fps,
|
| 167 |
+
codec='h264',
|
| 168 |
+
input_pix_fmt='rgb24',
|
| 169 |
+
crf=crf,
|
| 170 |
+
thread_type='FRAME',
|
| 171 |
+
thread_count=1
|
| 172 |
+
),
|
| 173 |
+
file_path=None,
|
| 174 |
+
steps_per_render=steps_per_render
|
| 175 |
+
),
|
| 176 |
+
n_obs_steps=n_obs_steps,
|
| 177 |
+
n_action_steps=n_action_steps,
|
| 178 |
+
max_episode_steps=max_steps
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
env_fns = [env_fn] * n_envs
|
| 182 |
+
env_seeds = list()
|
| 183 |
+
env_prefixs = list()
|
| 184 |
+
env_init_fn_dills = list()
|
| 185 |
+
|
| 186 |
+
# train
|
| 187 |
+
with h5py.File(dataset_path, 'r') as f:
|
| 188 |
+
for i in range(n_train):
|
| 189 |
+
train_idx = train_start_idx + i
|
| 190 |
+
enable_render = i < n_train_vis
|
| 191 |
+
init_state = f[f'data/demo_{train_idx}/states'][0]
|
| 192 |
+
|
| 193 |
+
def init_fn(env, init_state=init_state,
|
| 194 |
+
enable_render=enable_render):
|
| 195 |
+
# setup rendering
|
| 196 |
+
# video_wrapper
|
| 197 |
+
assert isinstance(env.env, VideoRecordingWrapper)
|
| 198 |
+
env.env.video_recoder.stop()
|
| 199 |
+
env.env.file_path = None
|
| 200 |
+
if enable_render:
|
| 201 |
+
filename = pathlib.Path(output_dir).joinpath(
|
| 202 |
+
'media', wv.util.generate_id() + ".mp4")
|
| 203 |
+
filename.parent.mkdir(parents=False, exist_ok=True)
|
| 204 |
+
filename = str(filename)
|
| 205 |
+
env.env.file_path = filename
|
| 206 |
+
|
| 207 |
+
# switch to init_state reset
|
| 208 |
+
assert isinstance(env.env.env, RobomimicLowdimWrapper)
|
| 209 |
+
env.env.env.init_state = init_state
|
| 210 |
+
|
| 211 |
+
env_seeds.append(train_idx)
|
| 212 |
+
env_prefixs.append('train/')
|
| 213 |
+
env_init_fn_dills.append(dill.dumps(init_fn))
|
| 214 |
+
|
| 215 |
+
# test
|
| 216 |
+
for i in range(n_test):
|
| 217 |
+
seed = test_start_seed + i
|
| 218 |
+
enable_render = i < n_test_vis
|
| 219 |
+
|
| 220 |
+
def init_fn(env, seed=seed,
|
| 221 |
+
enable_render=enable_render):
|
| 222 |
+
# setup rendering
|
| 223 |
+
# video_wrapper
|
| 224 |
+
assert isinstance(env.env, VideoRecordingWrapper)
|
| 225 |
+
env.env.video_recoder.stop()
|
| 226 |
+
env.env.file_path = None
|
| 227 |
+
if enable_render:
|
| 228 |
+
filename = pathlib.Path(output_dir).joinpath(
|
| 229 |
+
'media', wv.util.generate_id() + ".mp4")
|
| 230 |
+
filename.parent.mkdir(parents=False, exist_ok=True)
|
| 231 |
+
filename = str(filename)
|
| 232 |
+
env.env.file_path = filename
|
| 233 |
+
|
| 234 |
+
# switch to seed reset
|
| 235 |
+
assert isinstance(env.env.env, RobomimicLowdimWrapper)
|
| 236 |
+
env.env.env.init_state = None
|
| 237 |
+
env.seed(seed)
|
| 238 |
+
|
| 239 |
+
env_seeds.append(seed)
|
| 240 |
+
env_prefixs.append('test/')
|
| 241 |
+
env_init_fn_dills.append(dill.dumps(init_fn))
|
| 242 |
+
|
| 243 |
+
env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)
|
| 244 |
+
# env = SyncVectorEnv(env_fns)
|
| 245 |
+
|
| 246 |
+
self.env_meta = env_meta
|
| 247 |
+
self.env = env
|
| 248 |
+
self.env_fns = env_fns
|
| 249 |
+
self.env_seeds = env_seeds
|
| 250 |
+
self.env_prefixs = env_prefixs
|
| 251 |
+
self.env_init_fn_dills = env_init_fn_dills
|
| 252 |
+
self.fps = fps
|
| 253 |
+
self.crf = crf
|
| 254 |
+
self.n_obs_steps = n_obs_steps
|
| 255 |
+
self.n_action_steps = n_action_steps
|
| 256 |
+
self.n_latency_steps = n_latency_steps
|
| 257 |
+
self.env_n_obs_steps = env_n_obs_steps
|
| 258 |
+
self.env_n_action_steps = env_n_action_steps
|
| 259 |
+
self.past_action = past_action
|
| 260 |
+
self.max_steps = max_steps
|
| 261 |
+
self.rotation_transformer = rotation_transformer
|
| 262 |
+
self.abs_action = abs_action
|
| 263 |
+
self.tqdm_interval_sec = tqdm_interval_sec
|
| 264 |
+
|
| 265 |
+
def run(self, policy: BaseLowdimPolicy):
    """Roll out `policy` over every configured environment initialization.

    Initial conditions (train replays + test seeds, prepared in __init__ as
    dill-serialized init functions) are evaluated in chunks of
    len(self.env_fns) parallel environments. For each chunk the vector env
    is re-initialized, the policy is rolled out until all envs report done
    or max_steps is reached, and per-episode rewards/videos are collected.

    Returns:
        dict mapping log keys to values:
        - '<prefix>sim_max_reward_<seed>': max reward of that episode
        - '<prefix>sim_video_<seed>': wandb.Video (only for rendered envs)
        - '<prefix>mean_score': mean of max rewards per prefix
    """
    device = policy.device
    dtype = policy.dtype
    env = self.env

    # plan for rollout
    n_envs = len(self.env_fns)
    n_inits = len(self.env_init_fn_dills)
    n_chunks = math.ceil(n_inits / n_envs)

    # allocate data
    all_video_paths = [None] * n_inits
    all_rewards = [None] * n_inits

    for chunk_idx in range(n_chunks):
        start = chunk_idx * n_envs
        end = min(n_inits, start + n_envs)
        this_global_slice = slice(start, end)
        this_n_active_envs = end - start
        this_local_slice = slice(0,this_n_active_envs)

        # pad the last (possibly short) chunk with the first init fn so the
        # vector env always receives exactly n_envs init functions; the
        # padded envs' results are dropped via this_local_slice below
        this_init_fns = self.env_init_fn_dills[this_global_slice]
        n_diff = n_envs - len(this_init_fns)
        if n_diff > 0:
            this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
        assert len(this_init_fns) == n_envs

        # init envs
        env.call_each('run_dill_function',
            args_list=[(x,) for x in this_init_fns])

        # start rollout
        obs = env.reset()
        past_action = None
        policy.reset()

        env_name = self.env_meta['env_name']
        pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Lowdim {chunk_idx+1}/{n_chunks}",
            leave=False, mininterval=self.tqdm_interval_sec)

        done = False
        while not done:
            # create obs dict
            np_obs_dict = {
                # handle n_latency_steps by discarding the last n_latency_steps
                'obs': obs[:,:self.n_obs_steps].astype(np.float32)
            }
            if self.past_action and (past_action is not None):
                # TODO: not tested
                np_obs_dict['past_action'] = past_action[
                    :,-(self.n_obs_steps-1):].astype(np.float32)

            # device transfer
            obs_dict = dict_apply(np_obs_dict,
                lambda x: torch.from_numpy(x).to(
                    device=device))

            # run policy
            with torch.no_grad():
                action_dict = policy.predict_action(obs_dict)

            # device_transfer
            np_action_dict = dict_apply(action_dict,
                lambda x: x.detach().to('cpu').numpy())

            # handle latency_steps, we discard the first n_latency_steps actions
            # to simulate latency
            action = np_action_dict['action'][:,self.n_latency_steps:]
            if not np.all(np.isfinite(action)):
                print(action)
                raise RuntimeError("Nan or Inf action")

            # step env
            env_action = action
            if self.abs_action:
                # convert rotation representation back to the env's native
                # axis-angle form before stepping (see undo_transform_action)
                env_action = self.undo_transform_action(action)

            obs, reward, done, info = env.step(env_action)
            # episode ends only when every env in the chunk is done
            done = np.all(done)
            past_action = action

            # update pbar
            pbar.update(action.shape[1])
        pbar.close()

        # collect data for this round
        all_video_paths[this_global_slice] = env.render()[this_local_slice]
        all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]

    # log
    max_rewards = collections.defaultdict(list)
    log_data = dict()
    # results reported in the paper are generated using the commented out line below
    # which will only report and average metrics from first n_envs initial condition and seeds
    # fortunately this won't invalidate our conclusion since
    # 1. This bug only affects the variance of metrics, not their mean
    # 2. All baseline methods are evaluated using the same code
    # to completely reproduce reported numbers, uncomment this line:
    # for i in range(len(self.env_fns)):
    # and comment out this line
    for i in range(n_inits):
        seed = self.env_seeds[i]
        prefix = self.env_prefixs[i]
        max_reward = np.max(all_rewards[i])
        max_rewards[prefix].append(max_reward)
        log_data[prefix+f'sim_max_reward_{seed}'] = max_reward

        # visualize sim
        video_path = all_video_paths[i]
        if video_path is not None:
            sim_video = wandb.Video(video_path)
            log_data[prefix+f'sim_video_{seed}'] = sim_video

    # log aggregate metrics
    for prefix, value in max_rewards.items():
        name = prefix+'mean_score'
        value = np.mean(value)
        log_data[name] = value

    return log_data
|
| 386 |
+
def undo_transform_action(self, action):
|
| 387 |
+
raw_shape = action.shape
|
| 388 |
+
if raw_shape[-1] == 20:
|
| 389 |
+
# dual arm
|
| 390 |
+
action = action.reshape(-1,2,10)
|
| 391 |
+
|
| 392 |
+
d_rot = action.shape[-1] - 4
|
| 393 |
+
pos = action[...,:3]
|
| 394 |
+
rot = action[...,3:3+d_rot]
|
| 395 |
+
gripper = action[...,[-1]]
|
| 396 |
+
rot = self.rotation_transformer.inverse(rot)
|
| 397 |
+
uaction = np.concatenate([
|
| 398 |
+
pos, rot, gripper
|
| 399 |
+
], axis=-1)
|
| 400 |
+
|
| 401 |
+
if raw_shape[-1] == 20:
|
| 402 |
+
# dual arm
|
| 403 |
+
uaction = uaction.reshape(*raw_shape[:-1], 14)
|
| 404 |
+
|
| 405 |
+
return uaction
|
equidiff/equi_diffpo/policy/robomimic_image_policy.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
import torch
|
| 3 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer
|
| 4 |
+
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
|
| 5 |
+
from equi_diffpo.common.pytorch_util import dict_apply
|
| 6 |
+
|
| 7 |
+
from robomimic.algo import algo_factory
|
| 8 |
+
from robomimic.algo.algo import PolicyAlgo
|
| 9 |
+
import robomimic.utils.obs_utils as ObsUtils
|
| 10 |
+
from equi_diffpo.common.robomimic_config_util import get_robomimic_config
|
| 11 |
+
|
| 12 |
+
class RobomimicImagePolicy(BaseImagePolicy):
    """Adapter exposing a robomimic PolicyAlgo (e.g. BC-RNN) through the
    BaseImagePolicy interface used by this codebase.

    Observation/action normalization is handled by a LinearNormalizer set
    via set_normalizer(); the wrapped robomimic model only ever sees
    normalized data.
    """

    def __init__(self,
            shape_meta: dict,
            algo_name='bc_rnn',
            obs_type='image',
            task_name='square',
            dataset_type='ph',
            crop_shape=(76,76)
        ):
        """Build the robomimic model from shape_meta.

        Args:
            shape_meta: dict with 'action' (shape) and 'obs' (per-key shape
                and optional 'type' of 'rgb' or 'low_dim').
            algo_name: robomimic algorithm name used for the base config.
            obs_type: hdf5 observation type passed to the config factory.
            task_name: robomimic task name used for the base config.
            dataset_type: robomimic dataset type (e.g. 'ph', 'mh').
            crop_shape: (H, W) for CropRandomizer; None disables random crop.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        # robomimic groups obs keys by modality; only rgb/low_dim are filled
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            # NOTE(review): shadows the builtin `type`; kept as-is
            type = attr.get('type', 'low_dim')
            if type == 'rgb':
                obs_config['rgb'].append(key)
            elif type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {type}")

        # get raw robomimic config
        config = get_robomimic_config(
            algo_name=algo_name,
            hdf5_type=obs_type,
            task_name=task_name,
            dataset_type=dataset_type)


        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping entirely
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state (robomimic keeps modality registry globally)
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model
        model: PolicyAlgo = algo_factory(
            algo_name=config.algo_name,
            config=config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=action_dim,
            device='cpu',
        )

        self.model = model
        # expose robomimic's nets so nn.Module.to()/state_dict() see them
        self.nets = model.nets
        self.normalizer = LinearNormalizer()
        self.config = config

    def to(self,*args,**kwargs):
        """Move module to a device and keep the robomimic model's device
        attribute in sync.

        NOTE(review): unlike nn.Module.to this does not return self, so
        calls cannot be chained — confirm callers don't rely on the return.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.model.device = device
        super().to(*args,**kwargs)

    # =========== inference =============
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict one action step from an observation window.

        Only the first observation step (index 0 along the time axis) is fed
        to the robomimic model; the returned action gets a singleton time
        dimension: {'action': (B, 1, Da)}.
        """
        nobs_dict = self.normalizer(obs_dict)
        # robomimic expects (B, ...) obs without a time dimension
        robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:,0,...])
        naction = self.model.get_action(robomimic_obs_dict)
        action = self.normalizer['action'].unnormalize(naction)
        # (B, Da)
        result = {
            'action': action[:,None,:] # (B, 1, Da)
        }
        return result

    def reset(self):
        """Reset the wrapped model's internal state (e.g. RNN hidden state)."""
        self.model.reset()

    # =========== training ==============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy normalization parameters from a fitted normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def train_on_batch(self, batch, epoch, validate=False):
        """Normalize a batch and run one robomimic train/validation step.

        Returns robomimic's info dict (keys: losses, predictions).
        """
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        robomimic_batch = {
            'obs': nobs,
            'actions': nactions
        }
        input_batch = self.model.process_batch_for_training(
            robomimic_batch)
        info = self.model.train_on_batch(
            batch=input_batch, epoch=epoch, validate=validate)
        # keys: losses, predictions
        return info

    def on_epoch_end(self, epoch):
        """Forward per-epoch hooks (e.g. LR scheduling) to robomimic."""
        self.model.on_epoch_end(epoch)

    def get_optimizer(self):
        """Return robomimic's policy optimizer."""
        return self.model.optimizers['policy']
|
| 133 |
+
|
| 134 |
+
def test():
    """Smoke test: construct a RobomimicImagePolicy from a task config."""
    import os
    from omegaconf import OmegaConf
    config_path = os.path.expanduser(
        '~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
    task_cfg = OmegaConf.load(config_path)

    policy = RobomimicImagePolicy(shape_meta=task_cfg.shape_meta)
| 142 |
+
|
equidiff/equi_diffpo/scripts/robomimic_dataset_action_comparison.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import click
|
| 11 |
+
import pathlib
|
| 12 |
+
import h5py
|
| 13 |
+
import numpy as np
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from scipy.spatial.transform import Rotation
|
| 16 |
+
|
| 17 |
+
def read_all_actions(hdf5_file, metric_skip_steps=1):
    """Concatenate the action arrays of every demo in a robomimic hdf5 file.

    The first `metric_skip_steps` steps of each demo are dropped before
    concatenation. Returns a single (N, action_dim) numpy array.
    """
    num_demos = len(hdf5_file['data'])
    per_demo = [
        hdf5_file[f'data/demo_{i}/actions'][:][metric_skip_steps:]
        for i in tqdm(range(num_demos))
    ]
    return np.concatenate(per_demo, axis=0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@click.command()
@click.option('-i', '--input', required=True, help='input hdf5 path')
@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
def main(input, output):
    """Compare actions of two robomimic hdf5 datasets.

    Reads every demo's actions from both files and prints the maximum
    positional distance and the maximum rotational distance (actions are
    assumed to be [pos(3), rotvec(3), ...]).
    """
    # process inputs
    input = pathlib.Path(input).expanduser()
    assert input.is_file()
    output = pathlib.Path(output).expanduser()
    assert output.is_file()

    # fix: close both files when done (original leaked the handles)
    with h5py.File(str(input), 'r') as input_file, \
            h5py.File(str(output), 'r') as output_file:
        input_all_actions = read_all_actions(input_file)
        output_all_actions = read_all_actions(output_file)

    # per-step euclidean distance between positions
    pos_dist = np.linalg.norm(input_all_actions[:,:3] - output_all_actions[:,:3], axis=-1)
    # per-step geodesic angle between rotation-vector actions
    rot_dist = (Rotation.from_rotvec(input_all_actions[:,3:6]
        ) * Rotation.from_rotvec(output_all_actions[:,3:6]).inv()
        ).magnitude()

    print(f'max pos dist: {pos_dist.max()}')
    print(f'max rot dist: {rot_dist.max()}')

if __name__ == "__main__":
    main()
|
equidiff/equi_diffpo/scripts/robomimic_dataset_conversion.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
|
| 9 |
+
import multiprocessing
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import click
|
| 13 |
+
import pathlib
|
| 14 |
+
import h5py
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import collections
|
| 17 |
+
import pickle
|
| 18 |
+
from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter
|
| 19 |
+
|
| 20 |
+
def worker(x):
    """Convert one demo's actions to absolute form in a subprocess.

    Args:
        x: tuple (hdf5 path, demo index, whether to also evaluate).

    Returns:
        (abs_actions, info) where info is empty when evaluation is skipped.
    """
    dataset_path, demo_idx, evaluate = x
    conv = RobomimicAbsoluteActionConverter(dataset_path)
    if evaluate:
        return conv.convert_and_eval_idx(demo_idx)
    return conv.convert_idx(demo_idx), dict()
|
| 29 |
+
|
| 30 |
+
@click.command()
@click.option('-i', '--input', required=True, help='input hdf5 path')
@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
@click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
@click.option('-n', '--num_workers', default=None, type=int)
def main(input, output, eval_dir, num_workers):
    """Convert a robomimic dataset's relative actions to absolute actions.

    Demos are converted in parallel, the input hdf5 is copied to the output
    path, and the copy's 'actions' datasets are overwritten in place. When
    eval_dir is given, per-demo conversion error stats are pickled and
    plotted there.
    """
    # process inputs
    input = pathlib.Path(input).expanduser()
    assert input.is_file()
    output = pathlib.Path(output).expanduser()
    assert output.parent.is_dir()
    assert not output.is_dir()

    do_eval = False
    if eval_dir is not None:
        eval_dir = pathlib.Path(eval_dir).expanduser()
        assert eval_dir.parent.exists()
        do_eval = True

    # only used for its length (number of demos)
    converter = RobomimicAbsoluteActionConverter(input)

    # run
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))])

    # save output
    print('Copying hdf5')
    shutil.copy(str(input), str(output))

    # modify action
    with h5py.File(output, 'r+') as out_file:
        for i in tqdm(range(len(converter)), desc="Writing to output"):
            abs_actions, info = results[i]
            demo = out_file[f'data/demo_{i}']
            demo['actions'][:] = abs_actions

    # save eval
    if do_eval:
        eval_dir.mkdir(parents=False, exist_ok=True)

        print("Writing error_stats.pkl")
        infos = [info for _, info in results]
        # fix: close the pickle file (original leaked the handle)
        with eval_dir.joinpath('error_stats.pkl').open('wb') as f:
            pickle.dump(infos, f)

        print("Generating visualization")
        metrics = ['pos', 'rot']
        # metric name -> (error key -> list of per-demo values)
        metrics_dicts = dict()
        for m in metrics:
            metrics_dicts[m] = collections.defaultdict(list)

        for i in range(len(infos)):
            info = infos[i]
            for k, v in info.items():
                for m in metrics:
                    metrics_dicts[m][k].append(v[m])

        from matplotlib import pyplot as plt
        # non-interactive backend: we only save figures
        plt.switch_backend('PDF')

        fig, ax = plt.subplots(1, len(metrics))
        for i in range(len(metrics)):
            axis = ax[i]
            data = metrics_dicts[metrics[i]]
            for key, value in data.items():
                axis.plot(value, label=key)
            axis.legend()
            axis.set_title(metrics[i])
        fig.set_size_inches(10,4)
        fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
        fig.savefig(str(eval_dir.joinpath('error_stats.png')))


if __name__ == "__main__":
    main()
|
equidiff/equi_diffpo/scripts/robomimic_dataset_obs_conversion.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if __name__ == "__main__":
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
|
| 6 |
+
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
|
| 7 |
+
sys.path.append(ROOT_DIR)
|
| 8 |
+
|
| 9 |
+
import multiprocessing
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import click
|
| 13 |
+
import pathlib
|
| 14 |
+
import h5py
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import numpy as np
|
| 17 |
+
import collections
|
| 18 |
+
import pickle
|
| 19 |
+
from equi_diffpo.common.robomimic_util import RobomimicObsConverter
|
| 20 |
+
|
| 21 |
+
multiprocessing.set_start_method('spawn', force=True)
|
| 22 |
+
|
| 23 |
+
def worker(x):
    """Convert one demo's observations in a subprocess.

    Args:
        x: tuple (hdf5 path, demo index).

    Returns:
        dict of converted observation arrays for that demo.
    """
    dataset_path, demo_idx = x
    return RobomimicObsConverter(dataset_path).convert_idx(demo_idx)
|
| 28 |
+
|
| 29 |
+
@click.command()
@click.option('-i', '--input', required=True, help='input hdf5 path')
@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
@click.option('-n', '--num_workers', default=None, type=int)
def main(input, output, num_workers):
    """Re-render observations of a robomimic dataset into a new hdf5 file.

    The input file is copied to the output path, then each demo's 'obs'
    group is replaced with converted observations. Demos are processed in
    batches of num_workers so at most num_workers results are held in
    memory at once.
    """
    # process inputs
    input = pathlib.Path(input).expanduser()
    assert input.is_file()
    output = pathlib.Path(output).expanduser()
    assert output.parent.is_dir()
    assert not output.is_dir()

    # fix: the original crashed with TypeError (idx + None) when -n was
    # omitted; resolve None to the CPU count, matching Pool's own default.
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    converter = RobomimicObsConverter(input)

    # save output
    print('Copying hdf5')
    shutil.copy(str(input), str(output))

    # run
    idx = 0
    while idx < len(converter):
        with multiprocessing.Pool(num_workers) as pool:
            end = min(idx + num_workers, len(converter))
            results = pool.map(worker, [(input, i) for i in range(idx, end)])

        # replace the obs group of each demo in this batch
        print('Writing {} to {}'.format(idx, end))
        with h5py.File(output, 'r+') as out_file:
            for i in tqdm(range(idx, end), desc="Writing to output"):
                obss = results[i - idx]
                demo = out_file[f'data/demo_{i}']
                del demo['obs']
                for k in obss:
                    demo.create_dataset("obs/{}".format(k), data=np.array(obss[k]), compression="gzip")

        idx = end
        # free this batch's observations before converting the next one
        del results


if __name__ == "__main__":
    main()
|