Lillianwei commited on
Commit
1501ed7
·
1 Parent(s): 342fd2c
equidiff/equi_diffpo/common/robomimic_config_util.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ from robomimic.config import config_factory
3
+ import robomimic.scripts.generate_paper_configs as gpc
4
+ from robomimic.scripts.generate_paper_configs import (
5
+ modify_config_for_default_image_exp,
6
+ modify_config_for_default_low_dim_exp,
7
+ modify_config_for_dataset,
8
+ )
9
+
10
def get_robomimic_config(
        algo_name='bc_rnn',
        hdf5_type='low_dim',
        task_name='square',
        dataset_type='ph'
    ):
    """Build a robomimic config for the given algorithm/task/dataset combo.

    Reuses robomimic's paper-config generators to produce a fully
    populated config without needing a real dataset on disk.

    Args:
        algo_name: robomimic algorithm name (e.g. 'bc_rnn').
        hdf5_type: observation variant, e.g. 'low_dim' or 'image'.
        task_name: robomimic task (e.g. 'square').
        dataset_type: dataset quality tag (e.g. 'ph', 'mh').

    Returns:
        A robomimic Config object with obs/dataset/algo modifiers applied.
    """
    # placeholder path: the modifiers require a dataset dir but never read it
    dummy_dataset_dir = '/tmp/null'

    # pick low-dim vs image defaults based on the hdf5 observation variant
    low_dim_variants = ("low_dim", "low_dim_sparse", "low_dim_dense")
    obs_modifier = (
        modify_config_for_default_low_dim_exp
        if hdf5_type in low_dim_variants
        else modify_config_for_default_image_exp
    )

    # robomimic registers 'bc_rnn' under the 'bc' config factory name
    factory_name = "bc" if algo_name == "bc_rnn" else algo_name
    config = config_factory(algo_name=factory_name)

    # apply observation-modality defaults (low-dim or rgb)
    config = obs_modifier(config)

    # apply dataset-specific settings
    config = modify_config_for_dataset(
        config=config,
        task_name=task_name,
        dataset_type=dataset_type,
        hdf5_type=hdf5_type,
        base_dataset_dir=dummy_dataset_dir,
        filter_key=None,
    )

    # apply per-algorithm hyperparameters tuned for this dataset
    algo_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset')
    config = algo_modifier(
        config=config,
        task_name=task_name,
        dataset_type=dataset_type,
        hdf5_type=hdf5_type,
    )
    return config
46
+
47
+
equidiff/equi_diffpo/common/robomimic_util.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import copy
3
+
4
+ import h5py
5
+ import robomimic.utils.obs_utils as ObsUtils
6
+ import robomimic.utils.file_utils as FileUtils
7
+ import robomimic.utils.env_utils as EnvUtils
8
+ import robomimic.utils.tensor_utils as TensorUtils
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from robomimic.config import config_factory
12
+
13
+
14
class RobomimicAbsoluteActionConverter:
    """Convert delta (relative) actions in a robomimic hdf5 dataset into
    absolute goal-pose actions by replaying recorded sim states.

    Maintains two environments built from the same dataset metadata: one
    with the original delta-pose controller (used to extract each step's
    controller goal) and one configured with ``control_delta=False``
    (used to verify the converted absolute actions by rollout).
    """

    def __init__(self, dataset_path, algo_name='bc'):
        # default BC config — only used to initialize observation utils
        config = config_factory(algo_name=algo_name)

        # read config to set up metadata for observation modalities
        # (e.g. detecting rgb observations); must run before creating envs
        ObsUtils.initialize_obs_utils_with_config(config)

        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
        abs_env_meta = copy.deepcopy(env_meta)
        # second env uses an absolute-pose controller instead of deltas
        abs_env_meta['env_kwargs']['controller_configs']['control_delta'] = False

        env = EnvUtils.create_env_from_metadata(
            env_meta=env_meta,
            render=False,
            render_offscreen=False,
            use_image_obs=False,
        )
        # single- or dual-arm setups only
        assert len(env.env.robots) in (1, 2)

        abs_env = EnvUtils.create_env_from_metadata(
            env_meta=abs_env_meta,
            render=False,
            render_offscreen=False,
            use_image_obs=False,
        )
        # sanity check: the absolute env really is in absolute mode
        assert not abs_env.env.robots[0].controller.use_delta

        self.env = env
        self.abs_env = abs_env
        self.file = h5py.File(dataset_path, 'r')

    def __len__(self):
        # number of demos in the hdf5 'data' group
        return len(self.file['data'])

    def convert_actions(self,
            states: np.ndarray,
            actions: np.ndarray) -> np.ndarray:
        """
        Given state and delta action sequence
        generate equivalent goal position and orientation for each step
        keep the original gripper action intact.

        Args:
            states: (N, state_dim) mujoco sim states, one per step.
            actions: (N, 7) single-arm or (N, 14) dual-arm delta actions.

        Returns:
            Absolute actions with the same shape as ``actions``; each
            per-arm action is [goal_pos(3), goal_rotvec(3), gripper(1)].
        """
        # in case of multi robot
        # reshape (N,14) to (N,2,7)
        # or (N,7) to (N,1,7)
        stacked_actions = actions.reshape(*actions.shape[:-1],-1,7)

        env = self.env
        # generate abs actions
        action_goal_pos = np.zeros(
            stacked_actions.shape[:-1]+(3,),
            dtype=stacked_actions.dtype)
        action_goal_ori = np.zeros(
            stacked_actions.shape[:-1]+(3,),
            dtype=stacked_actions.dtype)
        # gripper command is controller-independent; pass through unchanged
        action_gripper = stacked_actions[...,[-1]]
        for i in range(len(states)):
            # restore the exact recorded sim state before applying the action
            _ = env.reset_to({'states': states[i]})

            # taken from robot_env.py L#454
            for idx, robot in enumerate(env.env.robots):
                # run controller goal generator
                robot.control(stacked_actions[i,idx], policy_step=True)

                # read the resulting absolute goal pose from the controller
                controller = robot.controller
                action_goal_pos[i,idx] = controller.goal_pos
                action_goal_ori[i,idx] = Rotation.from_matrix(
                    controller.goal_ori).as_rotvec()

        stacked_abs_actions = np.concatenate([
            action_goal_pos,
            action_goal_ori,
            action_gripper
        ], axis=-1)
        # collapse the per-arm axis back to the original layout
        abs_actions = stacked_abs_actions.reshape(actions.shape)
        return abs_actions

    def convert_idx(self, idx):
        """Return absolute actions for demo ``idx`` (no verification)."""
        file = self.file
        demo = file[f'data/demo_{idx}']
        # input
        states = demo['states'][:]
        actions = demo['actions'][:]

        # generate abs actions
        abs_actions = self.convert_actions(states, actions)
        return abs_actions

    def convert_and_eval_idx(self, idx):
        """Convert demo ``idx`` and report rollout error for both the
        original delta actions and the converted absolute actions.

        Returns:
            (abs_actions, info) where info holds per-variant max rollout
            errors under keys 'delta_max_error' / 'abs_max_error'.
        """
        env = self.env
        abs_env = self.abs_env
        file = self.file
        # first step has high error for some reason, not representative
        eval_skip_steps = 1

        demo = file[f'data/demo_{idx}']
        # input
        states = demo['states'][:]
        actions = demo['actions'][:]

        # generate abs actions
        abs_actions = self.convert_actions(states, actions)

        # verify both action variants against the recorded eef trajectory
        robot0_eef_pos = demo['obs']['robot0_eef_pos'][:]
        robot0_eef_quat = demo['obs']['robot0_eef_quat'][:]

        delta_error_info = self.evaluate_rollout_error(
            env, states, actions, robot0_eef_pos, robot0_eef_quat,
            metric_skip_steps=eval_skip_steps)
        abs_error_info = self.evaluate_rollout_error(
            abs_env, states, abs_actions, robot0_eef_pos, robot0_eef_quat,
            metric_skip_steps=eval_skip_steps)

        info = {
            'delta_max_error': delta_error_info,
            'abs_max_error': abs_error_info
        }
        return abs_actions, info

    @staticmethod
    def evaluate_rollout_error(env,
            states, actions,
            robot0_eef_pos,
            robot0_eef_quat,
            metric_skip_steps=1):
        """Step each recorded state with its action and measure how far the
        resulting next state / eef pose drifts from the recorded one.

        Returns:
            dict with max 'state', eef 'pos' (Euclidean), and 'rot'
            (rotation magnitude) errors, skipping the first
            ``metric_skip_steps`` steps.
        """
        # first step has high error for some reason, not representative

        # roll out one step from every recorded state
        rollout_next_states = list()
        rollout_next_eef_pos = list()
        rollout_next_eef_quat = list()
        obs = env.reset_to({'states': states[0]})
        for i in range(len(states)):
            # reset to the recorded state, then apply the recorded action
            obs = env.reset_to({'states': states[i]})
            obs, reward, done, info = env.step(actions[i])
            obs = env.get_observation()
            rollout_next_states.append(env.get_state()['states'])
            rollout_next_eef_pos.append(obs['robot0_eef_pos'])
            rollout_next_eef_quat.append(obs['robot0_eef_quat'])
        rollout_next_states = np.array(rollout_next_states)
        rollout_next_eef_pos = np.array(rollout_next_eef_pos)
        rollout_next_eef_quat = np.array(rollout_next_eef_quat)

        # compare rollout step i against recorded step i+1
        next_state_diff = states[1:] - rollout_next_states[:-1]
        max_next_state_diff = np.max(np.abs(next_state_diff[metric_skip_steps:]))

        next_eef_pos_diff = robot0_eef_pos[1:] - rollout_next_eef_pos[:-1]
        next_eef_pos_dist = np.linalg.norm(next_eef_pos_diff, axis=-1)
        max_next_eef_pos_dist = next_eef_pos_dist[metric_skip_steps:].max()

        # relative rotation between recorded and rolled-out orientations
        next_eef_rot_diff = Rotation.from_quat(robot0_eef_quat[1:]) \
            * Rotation.from_quat(rollout_next_eef_quat[:-1]).inv()
        next_eef_rot_dist = next_eef_rot_diff.magnitude()
        max_next_eef_rot_dist = next_eef_rot_dist[metric_skip_steps:].max()

        info = {
            'state': max_next_state_diff,
            'pos': max_next_eef_pos_dist,
            'rot': max_next_eef_rot_dist
        }
        return info
179
+
180
class RobomimicObsConverter:
    """Re-render observations for a robomimic dataset by resetting a
    data-processing environment to every recorded sim state.

    NOTE(review): ``algo_name`` is currently unused because the config /
    obs-utils initialization below is commented out.
    """

    def __init__(self, dataset_path, algo_name='bc'):
        # default BC config
        # config = config_factory(algo_name=algo_name)

        # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
        # must ran before create dataset
        # ObsUtils.initialize_obs_utils_with_config(config)

        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
        # env_meta['env_kwargs']['camera_names'] = ['birdview', 'agentview', 'sideview', 'robot0_eye_in_hand']

        # offscreen-rendering env with a fixed camera set at 84x84
        env = EnvUtils.create_env_for_data_processing(
            env_meta=env_meta,
            # camera_names=['frontview', 'birdview', 'agentview', 'sideview', 'agentview_full', 'robot0_robotview', 'robot0_eye_in_hand'],
            camera_names=['birdview', 'agentview', 'sideview', 'robot0_eye_in_hand'],
            camera_height=84,
            camera_width=84,
            reward_shaping=False,
        )
        # env = EnvUtils.create_env_from_metadata(
        #     env_meta=env_meta,
        #     render=True,
        #     render_offscreen=True,
        #     use_image_obs=True,
        # )

        self.env = env
        self.file = h5py.File(dataset_path, 'r')

    def __len__(self):
        # number of demos in the hdf5 'data' group
        return len(self.file['data'])

    def convert_obs(self, initial_state, states):
        """Render an observation for every state in the episode.

        Args:
            initial_state: dict with 'states' (and optionally 'model' xml)
                used for the first reset.
            states: (N, state_dim) sim states; states[0] is assumed to
                match initial_state.

        Returns:
            dict of stacked observation lists, keyed by obs name.
        """
        obss = []
        self.env.reset()
        # first reset may also load the model xml from initial_state
        obs = self.env.reset_to(initial_state)
        obss.append(obs)
        for i in range(1, len(states)):
            obs = self.env.reset_to({'states': states[i]})
            obss.append(obs)
        return TensorUtils.list_of_flat_dict_to_dict_of_list(obss)

    def convert_idx(self, idx):
        """Re-render observations for demo ``idx``, dropping unused views."""
        file = self.file
        demo = file[f'data/demo_{idx}']
        # input

        states = demo['states'][:]
        initial_state = dict(states=states[0])
        # the demo's model xml is needed to rebuild the exact scene
        initial_state["model"] = demo.attrs["model_file"]

        # render observations for the whole episode
        obss = self.convert_obs(initial_state, states)
        # drop cameras/modalities that are rendered but not consumed
        del obss['birdview_image']
        del obss['birdview_depth']
        del obss['agentview_depth']
        del obss['sideview_image']
        del obss['sideview_depth']
        del obss['robot0_eye_in_hand_depth']
        return obss
equidiff/equi_diffpo/dataset/robomimic_replay_image_dataset.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ from tqdm import tqdm
6
+ import zarr
7
+ import os
8
+ import shutil
9
+ import copy
10
+ import json
11
+ import hashlib
12
+ from filelock import FileLock
13
+ from threadpoolctl import threadpool_limits
14
+ import concurrent.futures
15
+ import multiprocessing
16
+ from omegaconf import OmegaConf
17
+ from equi_diffpo.common.pytorch_util import dict_apply
18
+ from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
19
+ from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
20
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
21
+ from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
22
+ from equi_diffpo.common.replay_buffer import ReplayBuffer
23
+ from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
24
+ from equi_diffpo.common.normalize_util import (
25
+ robomimic_abs_action_only_normalizer_from_stat,
26
+ robomimic_abs_action_only_dual_arm_normalizer_from_stat,
27
+ get_range_normalizer_from_stat,
28
+ get_image_range_normalizer,
29
+ get_identity_normalizer_from_stat,
30
+ array_to_stats
31
+ )
32
+ register_codecs()
33
+
34
class RobomimicReplayImageDataset(BaseImageDataset):
    """Image-based robomimic dataset backed by an in-memory zarr ReplayBuffer.

    Converts the first ``n_demo`` episodes of a robomimic hdf5 dataset into
    a zarr store (optionally cached on disk as a ``.zarr.zip``) and samples
    fixed-length action/observation sequences from it.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100
        ):
        self.n_demo = n_demo
        # raw robomimic actions store rotations as axis-angle
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            # NOTE: cache name intentionally kept as-is (contains a double
            # dot) so previously generated caches remain valid
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist: convert in memory, then persist
                    try:
                        print('Cache does not exist. Creating!')
                        # store = zarr.DirectoryStore(cache_zarr_path)
                        replay_buffer = _convert_robomimic_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception:
                        # BUGFIX: the cache is a zip *file*, so
                        # shutil.rmtree would raise NotADirectoryError; it
                        # would also raise FileNotFoundError (masking the
                        # real error) when conversion failed before the
                        # file was created. Remove only if present, and
                        # re-raise with the original traceback.
                        if os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_robomimic_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # split observation keys by modality
        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            # renamed from 'type' to avoid shadowing the builtin
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        # for key in rgb_keys:
        #     replay_buffer[key].compressor.numthreads=1

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs steps from each sampled sequence
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy sampling only the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit a LinearNormalizer for actions, low-dim obs, and images."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # low-dim obs; note 'qpos' must not fall through to the 'pos' branch
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])

            if key.endswith('pos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            elif key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # image: fixed [0,1] range normalizer
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # prevent BLAS oversubscription inside DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of OBS
        # since the rest will be discarded anyway.
        # when self.n_obs_steps is None
        # this slice does nothing (takes all)
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # move channel last to channel first: T,H,W,C -> T,C,H,W
            # and convert uint8 image to float32 in [0,1]
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            # free the raw uint8 copy immediately
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
225
+
226
+
227
+ def _convert_actions(raw_actions, abs_action, rotation_transformer):
228
+ actions = raw_actions
229
+ if abs_action:
230
+ is_dual_arm = False
231
+ if raw_actions.shape[-1] == 14:
232
+ # dual arm
233
+ raw_actions = raw_actions.reshape(-1,2,7)
234
+ is_dual_arm = True
235
+
236
+ pos = raw_actions[...,:3]
237
+ rot = raw_actions[...,3:6]
238
+ gripper = raw_actions[...,6:]
239
+ rot = rotation_transformer.forward(rot)
240
+ raw_actions = np.concatenate([
241
+ pos, rot, gripper
242
+ ], axis=-1).astype(np.float32)
243
+
244
+ if is_dual_arm:
245
+ raw_actions = raw_actions.reshape(-1,20)
246
+ actions = raw_actions
247
+ return actions
248
+
249
+
250
def _convert_robomimic_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert the first ``n_demo`` episodes of a robomimic hdf5 dataset
    into a zarr-backed ReplayBuffer.

    Low-dim arrays and actions are copied uncompressed in one pass;
    images are JPEG2000-compressed one frame per chunk by a thread pool.

    Args:
        store: zarr store to build the buffer in (e.g. MemoryStore).
        shape_meta: dict with 'obs' (per-key shape/type) and 'action' shape.
        dataset_path: path to the robomimic hdf5 file.
        abs_action: forwarded to _convert_actions.
        rotation_transformer: forwarded to _convert_actions.
        n_workers: thread count (defaults to cpu_count).
        max_inflight_tasks: backpressure limit for queued image copies.
        n_demo: number of demos to convert.

    Returns:
        ReplayBuffer rooted at the populated zarr group.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta: split keys by modality
    rgb_keys = list()
    lowdim_keys = list()
    # construct compressors and chunks
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        shape = attr['shape']
        type = attr.get('type', 'low_dim')
        if type == 'rgb':
            rgb_keys.append(key)
        elif type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    with h5py.File(dataset_path) as file:
        # count total steps and cumulative episode boundaries
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (actions included), uncompressed, single chunk
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                # actions live at the demo root, not under 'obs'
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            # copy one frame; returns False (instead of raising) so the
            # submitting loop can surface the failure from the future
            try:
                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
                # make sure we can successfully decode
                _ = zarr_arr[zarr_idx]
                return True
            except Exception as e:
                return False

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    data_key = 'obs/' + key
                    shape = tuple(shape_meta['obs'][key]['shape'])
                    c,h,w = shape
                    this_compressor = Jpeg2k(level=50)
                    # stored channel-last, one frame per chunk
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps,h,w,c),
                        chunks=(1,h,w,c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks; drain at
                                # least one completed future before queueing
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(img_copy,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                # drain the remaining futures
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
368
+
369
def normalizer_from_stat(stat):
    """Create a symmetric, scale-only normalizer from data statistics.

    Scales by 1/max(|min|, |max|) with zero offset, mapping data into
    [-1, 1] while preserving the sign and zero point.

    Args:
        stat: dict with 'min' and 'max' numpy arrays (per-dimension).

    Returns:
        SingleFieldLinearNormalizer with uniform scale and zero offset.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    if max_abs == 0:
        # ROBUSTNESS: all-zero data would divide by zero and produce an
        # inf scale; fall back to identity scaling instead
        max_abs = 1.0
    scale = np.full_like(stat['max'], fill_value=1/max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
equidiff/equi_diffpo/dataset/robomimic_replay_image_sym_dataset.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from equi_diffpo.dataset.base_dataset import LinearNormalizer
2
+ from equi_diffpo.model.common.normalizer import LinearNormalizer
3
+ from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat
4
+ from equi_diffpo.common.normalize_util import (
5
+ robomimic_abs_action_only_symmetric_normalizer_from_stat,
6
+ get_range_normalizer_from_stat,
7
+ get_range_symmetric_normalizer_from_stat,
8
+ get_image_range_normalizer,
9
+ get_identity_normalizer_from_stat,
10
+ array_to_stats
11
+ )
12
+ import numpy as np
13
+
14
class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
    """Image replay dataset variant that uses symmetric normalizers
    (suitable for equivariant policies) instead of the default ones."""

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100
        ):
        # forward everything to the base dataset unchanged
        super().__init__(
            shape_meta=shape_meta,
            dataset_path=dataset_path,
            horizon=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            n_obs_steps=n_obs_steps,
            abs_action=abs_action,
            rotation_rep=rotation_rep,
            use_legacy_normalizer=use_legacy_normalizer,
            use_cache=use_cache,
            seed=seed,
            val_ratio=val_ratio,
            n_demo=n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit a LinearNormalizer using symmetric ranges for positions."""
        normalizer = LinearNormalizer()

        # --- action ---
        action_stat = array_to_stats(self.replay_buffer['action'])
        if not self.abs_action:
            # delta actions are already normalized
            action_normalizer = get_identity_normalizer_from_stat(action_stat)
        elif action_stat['mean'].shape[-1] > 10:
            # dual arm is not supported by the symmetric normalizer
            raise NotImplementedError
        elif self.use_legacy_normalizer:
            action_normalizer = normalizer_from_stat(action_stat)
        else:
            action_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
        normalizer['action'] = action_normalizer

        # --- low-dim observations ---
        # note: 'qpos' must be tested before the more general 'pos' suffix
        for key in self.lowdim_keys:
            key_stat = array_to_stats(self.replay_buffer[key])
            if key.endswith('qpos'):
                normalizer[key] = get_range_normalizer_from_stat(key_stat)
            elif key.endswith('pos'):
                normalizer[key] = get_range_symmetric_normalizer_from_stat(key_stat)
            elif key.endswith('quat'):
                # quaternion components already lie in [-1, 1]
                normalizer[key] = get_identity_normalizer_from_stat(key_stat)
            elif 'bbox' in key:
                normalizer[key] = get_identity_normalizer_from_stat(key_stat)
            else:
                raise RuntimeError('unsupported')

        # --- images ---
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()

        # fixed-range auxiliary fields used by the symmetric model
        unit = np.ones([10, 2], np.float32)
        normalizer['pos_vecs'] = get_identity_normalizer_from_stat(
            {'min': -unit, 'max': unit})
        normalizer['crops'] = get_image_range_normalizer()

        return normalizer
90
+
equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_dataset.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ from tqdm import tqdm
6
+ import copy
7
+ from equi_diffpo.common.pytorch_util import dict_apply
8
+ from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
9
+ from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
10
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
11
+ from equi_diffpo.common.replay_buffer import ReplayBuffer
12
+ from equi_diffpo.common.sampler import (
13
+ SequenceSampler, get_val_mask, downsample_mask)
14
+ from equi_diffpo.common.normalize_util import (
15
+ robomimic_abs_action_only_normalizer_from_stat,
16
+ robomimic_abs_action_only_dual_arm_normalizer_from_stat,
17
+ get_identity_normalizer_from_stat,
18
+ array_to_stats
19
+ )
20
+
21
class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Low-dimensional robomimic dataset backed by an in-memory ReplayBuffer.

    Loads the first ``n_demo`` episodes from a robomimic hdf5 file,
    concatenates the selected observation keys into a flat vector per
    step, and samples fixed-length sequences.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            # BUGFIX: tuple instead of a mutable-list default argument;
            # converted to a fresh list below, so callers are unaffected
            obs_keys=(
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'),
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
        ):
        # always work on a private copy of the key list
        obs_keys = list(obs_keys)
        # raw robomimic actions store rotations as axis-angle
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # explicit read-only mode
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)

        # train/val episode split, then optional training-set downsampling
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.abs_action = abs_action
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Return a shallow copy sampling only the held-out episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Fit a LinearNormalizer for the action and the flat obs vector."""
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # aggregate obs stats: one symmetric scale over the whole flat vector
        obs_stat = array_to_stats(self.replay_buffer['obs'])

        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data
131
+
132
def normalizer_from_stat(stat):
    """Create a symmetric, scale-only normalizer from data statistics.

    Scales by 1/max(|min|, |max|) with zero offset, mapping data into
    [-1, 1] while preserving the sign and zero point.

    Args:
        stat: dict with 'min' and 'max' numpy arrays (per-dimension).

    Returns:
        SingleFieldLinearNormalizer with uniform scale and zero offset.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    if max_abs == 0:
        # ROBUSTNESS: all-zero data would divide by zero and produce an
        # inf scale; fall back to identity scaling instead
        max_abs = 1.0
    scale = np.full_like(stat['max'], fill_value=1/max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
141
+
142
+ def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
143
+ obs = np.concatenate([
144
+ raw_obs[key] for key in obs_keys
145
+ ], axis=-1).astype(np.float32)
146
+
147
+ if abs_action:
148
+ is_dual_arm = False
149
+ if raw_actions.shape[-1] == 14:
150
+ # dual arm
151
+ raw_actions = raw_actions.reshape(-1,2,7)
152
+ is_dual_arm = True
153
+
154
+ pos = raw_actions[...,:3]
155
+ rot = raw_actions[...,3:6]
156
+ gripper = raw_actions[...,6:]
157
+ rot = rotation_transformer.forward(rot)
158
+ raw_actions = np.concatenate([
159
+ pos, rot, gripper
160
+ ], axis=-1).astype(np.float32)
161
+
162
+ if is_dual_arm:
163
+ raw_actions = raw_actions.reshape(-1,20)
164
+
165
+ data = {
166
+ 'obs': obs,
167
+ 'action': raw_actions
168
+ }
169
+ return data
equidiff/equi_diffpo/dataset/robomimic_replay_lowdim_sym_dataset.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ import numpy as np
4
+ from equi_diffpo.common.pytorch_util import dict_apply
5
+ from equi_diffpo.dataset.base_dataset import LinearNormalizer
6
+ from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
7
+ from equi_diffpo.common.normalize_util import robomimic_abs_action_only_symmetric_normalizer_from_stat
8
+ from equi_diffpo.common.normalize_util import (
9
+ robomimic_abs_action_only_symmetric_normalizer_from_stat,
10
+ get_identity_normalizer_from_stat,
11
+ array_to_stats
12
+ )
13
+
14
+
15
class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
    """Low-dim robomimic dataset whose action normalizer uses the
    symmetric variant from ``normalize_util`` (presumably to preserve
    spatial symmetry for equivariant policies — see that module).

    Construction is delegated entirely to the parent class; only
    ``get_normalizer`` differs.
    """

    def __init__(self,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            obs_keys: List[str]=[
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos'],
            abs_action=False,
            rotation_rep='rotation_6d',
            use_legacy_normalizer=False,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            n_demo=100
            ):
        # All heavy lifting (replay buffer construction, sampler setup)
        # happens in the parent constructor.
        super().__init__(
            dataset_path,
            horizon,
            pad_before,
            pad_after,
            obs_keys,
            abs_action,
            rotation_rep,
            use_legacy_normalizer,
            seed,
            val_ratio,
            max_train_episodes,
            n_demo,
        )

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a LinearNormalizer for the 'action' and 'obs' fields.

        Absolute actions get the symmetric robomimic normalizer (or the
        legacy max-abs one when ``use_legacy_normalizer``); delta actions
        are assumed already normalized. Observations always use a single
        max-abs normalizer over the whole vector.
        """
        normalizer = LinearNormalizer()

        # --- action ---
        action_stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if action_stat['mean'].shape[-1] > 10:
                # dual-arm action spaces are not supported by this variant
                raise NotImplementedError
            this_normalizer = \
                robomimic_abs_action_only_symmetric_normalizer_from_stat(action_stat)
            if self.use_legacy_normalizer:
                # legacy behaviour: plain max-abs scaling of the raw stats
                this_normalizer = normalizer_from_stat(action_stat)
        else:
            # delta actions are assumed to be pre-normalized
            this_normalizer = get_identity_normalizer_from_stat(action_stat)
        normalizer['action'] = this_normalizer

        # --- observation: one max-abs normalizer over the full obs vector ---
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer
equidiff/equi_diffpo/dataset/robomimic_replay_point_cloud_dataset.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ from tqdm import tqdm
6
+ import zarr
7
+ import os
8
+ import shutil
9
+ import copy
10
+ import json
11
+ import hashlib
12
+ from filelock import FileLock
13
+ from threadpoolctl import threadpool_limits
14
+ import concurrent.futures
15
+ import multiprocessing
16
+ from omegaconf import OmegaConf
17
+ from equi_diffpo.common.pytorch_util import dict_apply
18
+ from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
19
+ from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
20
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
21
+ from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
22
+ from equi_diffpo.common.replay_buffer import ReplayBuffer
23
+ from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
24
+ from equi_diffpo.common.normalize_util import (
25
+ robomimic_abs_action_only_normalizer_from_stat,
26
+ get_range_normalizer_from_stat,
27
+ get_voxel_identity_normalizer,
28
+ get_image_range_normalizer,
29
+ get_identity_normalizer_from_stat,
30
+ array_to_stats
31
+ )
32
+ register_codecs()
33
+
34
class RobomimicReplayPointCloudDataset(BaseImageDataset):
    """Sequence dataset over robomimic demos with point-cloud observations.

    Demos from a robomimic hdf5 file are converted into an in-memory zarr
    ReplayBuffer (optionally cached as a ``.zarr.zip`` next to the dataset
    file) and served as fixed-length windows.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100,
        ):
        """
        Args:
            shape_meta: dict with per-key 'obs' shapes/types and 'action' shape.
            dataset_path: path to the robomimic hdf5 file.
            horizon: sequence length returned by __getitem__.
            pad_before, pad_after: episode-boundary padding for the sampler.
            n_obs_steps: if set, only the first k obs steps of each window are
                kept (saves RAM; the rest would be discarded anyway).
            abs_action: demos store absolute actions; rotation is converted
                from axis-angle to ``rotation_rep``.
            use_cache: persist/reuse the converted ReplayBuffer as .zarr.zip.
            seed, val_ratio: validation episode split.
            n_demo: maximum number of demos to load.
        """
        self.n_demo = n_demo
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            # NOTE: the double dot in the name is historical; kept so
            # previously written caches remain discoverable.
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet: build it, then persist
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _convert_point_cloud_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception as e:
                        # BUGFIX: the cache is a single .zip *file*, so the
                        # original shutil.rmtree(cache_zarr_path) would itself
                        # raise NotADirectoryError and mask the real error.
                        # Remove whichever form of partial cache exists.
                        if os.path.isdir(cache_zarr_path):
                            shutil.rmtree(cache_zarr_path)
                        elif os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise e
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_point_cloud_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # classify observation keys by modality
        rgb_keys = list()
        pc_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            # renamed from `type` (shadowed the builtin); the elif chain
            # replaces the original if/if/elif and behaves identically
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'point_cloud':
                pc_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs steps of every sampled window
            for key in rgb_keys + pc_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.pc_keys = pc_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Shallow copy of self that samples the held-out (validation) episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs) -> LinearNormalizer:
        """Fit a LinearNormalizer over actions, eef state and point clouds."""
        data = {
            'action': self.replay_buffer['action'],
            # [..., :] materializes the backing arrays into numpy
            'robot0_eef_pos': self.replay_buffer['robot0_eef_pos'][...,:],
            'robot0_eef_quat': self.replay_buffer['robot0_eef_quat'][...,:],
            'robot0_gripper_qpos': self.replay_buffer['robot0_gripper_qpos'][...,:],
            'point_cloud': self.replay_buffer['point_cloud'],
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # avoid thread oversubscription inside DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of obs; when
        # self.n_obs_steps is None this slice takes everything
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # T,H,W,C uint8 -> T,C,H,W float32 in [0, 1]
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            del data[key]
        for key in self.pc_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
206
+
207
+
208
def _convert_actions(raw_actions, abs_action, rotation_transformer):
    """Convert absolute actions' axis-angle rotation via ``rotation_transformer``.

    Delta actions (``abs_action=False``) pass through untouched. A 14-dim
    action is treated as two 7-dof arms and flattened back to 20 dims after
    conversion (pos 3 + rot 6 + gripper 1 per arm, for a 6d rotation rep).
    """
    if not abs_action:
        return raw_actions

    dual_arm = raw_actions.shape[-1] == 14
    if dual_arm:
        raw_actions = raw_actions.reshape(-1, 2, 7)

    pos = raw_actions[..., :3]
    axis_angle = raw_actions[..., 3:6]
    gripper = raw_actions[..., 6:]
    converted = np.concatenate(
        [pos, rotation_transformer.forward(axis_angle), gripper],
        axis=-1).astype(np.float32)

    if dual_arm:
        converted = converted.reshape(-1, 20)
    return converted
229
+
230
+
231
def _convert_point_cloud_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert a robomimic hdf5 dataset into a zarr-backed ReplayBuffer.

    Low-dim obs and actions are loaded in one pass; rgb frames (Jpeg2k
    compressed, one chunk per step) and point clouds are copied per step by
    a thread pool with bounded in-flight tasks.

    Args:
        store: zarr store to build the buffer in (e.g. ``zarr.MemoryStore()``).
        shape_meta: per-key shapes/types for 'obs' plus the 'action' shape.
        dataset_path: path to the robomimic hdf5 file.
        abs_action: forwarded to ``_convert_actions``.
        rotation_transformer: forwarded to ``_convert_actions``.
        n_workers: copy-thread count; defaults to ``cpu_count()``.
        max_inflight_tasks: backpressure bound on queued copies.
        n_demo: maximum number of demos to load.

    Returns:
        ReplayBuffer over the populated zarr group.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta (obs_type renamed from `type`: don't shadow builtin)
    pc_keys = list()
    rgb_keys = list()
    lowdim_keys = list()
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        obs_type = attr.get('type', 'low_dim')
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'point_cloud':
            pc_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    def copy_sample(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
        # shared by the rgb and point-cloud loops (the two original helpers
        # were byte-identical); returns False so the main thread can raise
        try:
            zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
            # read back to make sure the chunk decodes
            _ = zarr_arr[zarr_idx]
            return True
        except Exception:
            return False

    with h5py.File(dataset_path) as file:
        # count total steps
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        n_demo = min(n_demo, len(demos))
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (plus actions) in one uncompressed chunk each
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    c, h, w = tuple(shape_meta['obs'][key]['shape'])
                    this_compressor = Jpeg2k(level=50)
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, h, w, c),
                        chunks=(1, h, w, c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(copy_sample,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

        with tqdm(total=n_steps*len(pc_keys), desc="Loading point cloud data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in pc_keys:
                    n, c = tuple(shape_meta['obs'][key]['shape'])
                    pc_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, n, c),
                        chunks=(1, n, c),
                        dtype=np.float32
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        # BUGFIX: message said 'encode image'
                                        raise RuntimeError('Failed to copy point cloud data!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(copy_sample,
                                    pc_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to copy point cloud data!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
398
+
399
def normalizer_from_stat(stat):
    """Symmetric max-abs normalizer from stats: x -> x / max(|min|, |max|).

    Zero offset, so signs are preserved and outputs land in [-1, 1].
    """
    largest_magnitude = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    unit_scale = np.full_like(stat['max'], fill_value=1 / largest_magnitude)
    zero_offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=unit_scale,
        offset=zero_offset,
        input_stats_dict=stat,
    )
equidiff/equi_diffpo/dataset/robomimic_replay_voxel_sym_dataset.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ from tqdm import tqdm
6
+ import zarr
7
+ import os
8
+ import shutil
9
+ import copy
10
+ import json
11
+ import hashlib
12
+ from filelock import FileLock
13
+ from threadpoolctl import threadpool_limits
14
+ import concurrent.futures
15
+ import multiprocessing
16
+ from omegaconf import OmegaConf
17
+ from equi_diffpo.common.pytorch_util import dict_apply
18
+ from equi_diffpo.dataset.base_dataset import BaseImageDataset, LinearNormalizer
19
+ from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
20
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
21
+ from equi_diffpo.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
22
+ from equi_diffpo.common.replay_buffer import ReplayBuffer
23
+ from equi_diffpo.common.sampler import SequenceSampler, get_val_mask
24
+ from equi_diffpo.common.normalize_util import (
25
+ robomimic_abs_action_only_normalizer_from_stat,
26
+ get_range_normalizer_from_stat,
27
+ get_voxel_identity_normalizer,
28
+ get_image_range_normalizer,
29
+ get_identity_normalizer_from_stat,
30
+ array_to_stats
31
+ )
32
+ register_codecs()
33
+
34
class RobomimicReplayVoxelSymDataset(BaseImageDataset):
    """Sequence dataset over robomimic demos with voxel-grid observations,
    with x/y normalization made symmetric about a workspace center
    (for equivariant policies).

    Demos from a robomimic hdf5 file are converted into an in-memory zarr
    ReplayBuffer (optionally cached as a ``.zarr.zip`` next to the dataset
    file) and served as fixed-length windows.
    """

    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            n_demo=100,
            ws_size=0.6,
            ws_x_center=0,
            ws_y_center=0,
        ):
        """
        Args:
            shape_meta: dict with per-key 'obs' shapes/types and 'action' shape.
            dataset_path: path to the robomimic hdf5 file.
            horizon: sequence length returned by __getitem__.
            pad_before, pad_after: episode-boundary padding for the sampler.
            n_obs_steps: if set, keep only the first k obs steps per window.
            abs_action: demos store absolute actions; rotation is converted
                from axis-angle to ``rotation_rep``.
            use_cache: persist/reuse the converted ReplayBuffer as .zarr.zip.
            seed, val_ratio: validation episode split.
            n_demo: maximum number of demos to load.
            ws_size: workspace edge length; x/y normalization spans at least
                [-ws_size/2, ws_size/2] around the workspace center.
            ws_x_center, ws_y_center: workspace center in world x/y.
        """
        self.n_demo = n_demo
        self.ws_size = ws_size
        self.ws_center = np.array([ws_x_center, ws_y_center])
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            # NOTE: the double dot in the name is historical; kept so
            # previously written caches remain discoverable.
            cache_zarr_path = dataset_path + f'.{n_demo}.' + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet: build it, then persist
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _convert_voxel_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            n_demo=n_demo)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception as e:
                        # BUGFIX: the cache is a single .zip *file*, so the
                        # original shutil.rmtree(cache_zarr_path) would itself
                        # raise NotADirectoryError and mask the real error.
                        if os.path.isdir(cache_zarr_path):
                            shutil.rmtree(cache_zarr_path)
                        elif os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise e
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            replay_buffer = _convert_voxel_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                n_demo=n_demo)

        # classify observation keys by modality
        rgb_keys = list()
        voxel_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            # renamed from `type` (shadowed the builtin); the elif chain
            # replaces the original if/if/elif and behaves identically
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'voxel':
                voxel_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs steps of every sampled window
            for key in rgb_keys + voxel_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.voxel_keys = voxel_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer

    def get_validation_dataset(self):
        """Shallow copy of self that samples the held-out (validation) episodes."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def _symmetrize_xy_stat(self, stat):
        """Make the x/y min/max symmetric about the workspace center, covering
        at least half the workspace; mutates and returns `stat`.

        (Extracted from the two identical inline copies in get_normalizer;
        also fixes the `magnitute` typo.)
        """
        reach = np.max([stat['max'][:2] - self.ws_center,
                        self.ws_center - stat['min'][:2]])
        magnitude = max(reach, self.ws_size / 2)
        stat['min'][:2] = self.ws_center - magnitude
        stat['max'][:2] = self.ws_center + magnitude
        stat['mean'][:2] = self.ws_center
        return stat

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build normalizers for actions, low-dim obs, rgb and voxel keys.

        x/y components of absolute actions and eef positions are normalized
        symmetrically about the workspace center so that spatial symmetry is
        preserved.
        """
        normalizer = LinearNormalizer()

        # --- action ---
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual-arm action spaces are not supported here
                raise NotImplementedError
            stat = self._symmetrize_xy_stat(stat)
            this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)
            if self.use_legacy_normalizer:
                # legacy behaviour: plain max-abs scaling of the raw stats
                this_normalizer = normalizer_from_stat(stat)
        else:
            # delta actions are assumed to be pre-normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # --- low-dim obs (note: 'qpos' must be tested before 'pos') ---
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])
            if key.endswith('qpos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('pos'):
                stat = self._symmetrize_xy_stat(stat)
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1, 1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # --- image / voxel ---
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        for key in self.voxel_keys:
            normalizer[key] = get_voxel_identity_normalizer()

        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action in the buffer as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # avoid thread oversubscription inside DataLoader workers
        threadpool_limits(1)
        data = self.sampler.sample_sequence(idx)

        # to save RAM, only return first n_obs_steps of obs; when
        # self.n_obs_steps is None this slice takes everything
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # T,H,W,C uint8 -> T,C,H,W float32 in [0, 1]
            obs_dict[key] = np.moveaxis(data[key][T_slice],-1,1
                ).astype(np.float32) / 255.
            del data[key]
        for key in self.voxel_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            # channels after the first are rescaled from 0-255 to [0, 1]
            # (presumably occupancy + RGB channels — TODO confirm)
            obs_dict[key][:, 1:] /= 255.
            del data[key]
        for key in self.lowdim_keys:
            obs_dict[key] = data[key][T_slice].astype(np.float32)
            del data[key]

        torch_data = {
            'obs': dict_apply(obs_dict, torch.from_numpy),
            'action': torch.from_numpy(data['action'].astype(np.float32))
        }
        return torch_data
+ return torch_data
253
+
254
+
255
def _convert_actions(raw_actions, abs_action, rotation_transformer):
    """Convert absolute actions' axis-angle rotation via ``rotation_transformer``.

    Delta actions (``abs_action=False``) pass through untouched. A 14-dim
    action is treated as two 7-dof arms and flattened back to 20 dims after
    conversion (pos 3 + rot 6 + gripper 1 per arm, for a 6d rotation rep).
    """
    if not abs_action:
        return raw_actions

    dual_arm = raw_actions.shape[-1] == 14
    if dual_arm:
        raw_actions = raw_actions.reshape(-1, 2, 7)

    pos = raw_actions[..., :3]
    axis_angle = raw_actions[..., 3:6]
    gripper = raw_actions[..., 6:]
    converted = np.concatenate(
        [pos, rotation_transformer.forward(axis_angle), gripper],
        axis=-1).astype(np.float32)

    if dual_arm:
        converted = converted.reshape(-1, 20)
    return converted
276
+
277
+
278
def _convert_voxel_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, n_demo=100):
    """Convert a robomimic hdf5 dataset (with voxel obs) into a zarr-backed
    ReplayBuffer.

    Low-dim obs and actions are loaded in one pass; rgb frames (Jpeg2k
    compressed) and voxel grids are copied one whole episode per thread-pool
    task with bounded in-flight tasks.

    Args:
        store: zarr store to build the buffer in (e.g. ``zarr.MemoryStore()``).
        shape_meta: per-key shapes/types for 'obs' plus the 'action' shape.
        dataset_path: path to the robomimic hdf5 file.
        abs_action: forwarded to ``_convert_actions``.
        rotation_transformer: forwarded to ``_convert_actions``.
        n_workers: copy-thread count; defaults to ``cpu_count()``
            (CONSISTENCY FIX: was hard-coded to 24, unlike the sibling
            point-cloud converter).
        max_inflight_tasks: backpressure bound on queued copies.
        n_demo: maximum number of demos to load.

    Returns:
        ReplayBuffer over the populated zarr group.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta (obs_type renamed from `type`: don't shadow builtin)
    voxel_keys = list()
    rgb_keys = list()
    lowdim_keys = list()
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        obs_type = attr.get('type', 'low_dim')
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'voxel':
            voxel_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    def copy_to_zarr(zarr_arr, hdf5_arr, start_idx, end_idx):
        # copies one whole episode; returns False on failure so the main
        # thread can raise with context
        try:
            zarr_arr[start_idx:end_idx] = hdf5_arr
            # read back to make sure the chunks decode
            _ = zarr_arr[start_idx:end_idx]
            return True
        except Exception:
            return False

    with h5py.File(dataset_path) as file:
        # count total steps
        demos = file['data']
        episode_ends = list()
        prev_end = 0
        n_demo = min(n_demo, len(demos))
        for i in range(n_demo):
            demo = demos[f'demo_{i}']
            episode_length = demo['actions'].shape[0]
            episode_end = prev_end + episode_length
            prev_end = episode_end
            episode_ends.append(episode_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data (plus actions) in one uncompressed chunk each
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'obs/' + key
            if key == 'action':
                data_key = 'actions'
            this_data = list()
            for i in range(n_demo):
                demo = demos[f'demo_{i}']
                this_data.append(demo[data_key][:].astype(np.float32))
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        with tqdm(total=n_demo*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    c, h, w = tuple(shape_meta['obs'][key]['shape'])
                    this_compressor = Jpeg2k(level=50)
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, h, w, c),
                        chunks=(1, h, w, c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key][:]
                        start_idx = episode_starts[episode_idx]
                        if episode_idx < n_demo - 1:
                            end_idx = episode_starts[episode_idx+1]
                        else:
                            end_idx = n_steps
                        if len(futures) >= max_inflight_tasks:
                            # limit number of inflight tasks
                            completed, futures = concurrent.futures.wait(futures,
                                return_when=concurrent.futures.FIRST_COMPLETED)
                            for f in completed:
                                if not f.result():
                                    raise RuntimeError('Failed to encode image!')
                            pbar.update(len(completed))

                        futures.add(
                            executor.submit(copy_to_zarr,
                                img_arr, hdf5_arr, start_idx, end_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

        with tqdm(total=n_demo*len(voxel_keys), desc="Loading voxel data", mininterval=1.0) as pbar:
            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in voxel_keys:
                    c, h, w, l = tuple(shape_meta['obs'][key]['shape'])
                    voxel_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, c, h, w, l),
                        chunks=(1, c, h, w, l),
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demo):
                        demo = demos[f'demo_{episode_idx}']
                        hdf5_arr = demo['obs'][key][:]
                        start_idx = episode_starts[episode_idx]
                        if episode_idx < n_demo - 1:
                            end_idx = episode_starts[episode_idx+1]
                        else:
                            end_idx = n_steps
                        if len(futures) >= max_inflight_tasks:
                            # limit number of inflight tasks
                            completed, futures = concurrent.futures.wait(futures,
                                return_when=concurrent.futures.FIRST_COMPLETED)
                            for f in completed:
                                if not f.result():
                                    # BUGFIX: message said 'encode image'
                                    raise RuntimeError('Failed to copy voxel data!')
                            pbar.update(len(completed))

                        futures.add(
                            executor.submit(copy_to_zarr,
                                voxel_arr, hdf5_arr, start_idx, end_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to copy voxel data!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer
443
+
444
def normalizer_from_stat(stat):
    """Symmetric max-abs normalizer from stats: x -> x / max(|min|, |max|).

    Zero offset, so signs are preserved and outputs land in [-1, 1].
    """
    largest_magnitude = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    unit_scale = np.full_like(stat['max'], fill_value=1 / largest_magnitude)
    zero_offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=unit_scale,
        offset=zero_offset,
        input_stats_dict=stat,
    )
equidiff/equi_diffpo/env/robomimic/robomimic_image_wrapper.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from matplotlib.pyplot import fill
3
+ import numpy as np
4
+ import gym
5
+ from gym import spaces
6
+ from omegaconf import OmegaConf
7
+ from robomimic.envs.env_robosuite import EnvRobosuite
8
+
9
class RobomimicImageWrapper(gym.Env):
    """Gym adapter for a robomimic ``EnvRobosuite`` env with image observations.

    Builds gym action/observation spaces from ``shape_meta``, caches the
    last rendered observation for ``render()``, and supports deterministic
    resets either to a fixed simulator state (``init_state``) or to a
    simulator state cached per seed.
    """

    def __init__(self,
            env: EnvRobosuite,
            shape_meta: dict,
            init_state: Optional[np.ndarray]=None,
            render_obs_key='agentview_image',
        ):
        # env: underlying robomimic robosuite environment
        # shape_meta: dict with 'action' and 'obs' entries describing shapes
        # init_state: if not None, every reset() restores this exact sim state
        # render_obs_key: observation key whose image render() returns

        self.env = env
        self.render_obs_key = render_obs_key
        self.init_state = init_state
        # maps seed -> simulator state so repeated seeded resets can use the
        # cheaper reset_to() instead of a full reset()
        self.seed_state_map = dict()
        self._seed = None
        self.shape_meta = shape_meta
        # last image observation; reused by render()
        self.render_cache = None
        self.has_reset_before = False

        # setup spaces
        action_shape = shape_meta['action']['shape']
        action_space = spaces.Box(
            low=-1,
            high=1,
            shape=action_shape,
            dtype=np.float32
        )
        self.action_space = action_space

        observation_space = spaces.Dict()
        for key, value in shape_meta['obs'].items():
            shape = value['shape']
            # bounds are chosen purely from the key suffix; image/depth/voxels
            # are assumed already normalized to [0, 1] -- TODO confirm upstream
            min_value, max_value = -1, 1
            if key.endswith('image'):
                min_value, max_value = 0, 1
            elif key.endswith('depth'):
                min_value, max_value = 0, 1
            elif key.endswith('voxels'):
                min_value, max_value = 0, 1
            elif key.endswith('point_cloud'):
                min_value, max_value = -10, 10
            elif key.endswith('quat'):
                min_value, max_value = -1, 1
            elif key.endswith('qpos'):
                min_value, max_value = -1, 1
            elif key.endswith('pos'):
                # better range?
                min_value, max_value = -1, 1
            else:
                raise RuntimeError(f"Unsupported type {key}")

            this_space = spaces.Box(
                low=min_value,
                high=max_value,
                shape=shape,
                dtype=np.float32
            )
            observation_space[key] = this_space
        self.observation_space = observation_space


    def get_observation(self, raw_obs=None):
        """Return a dict of observations restricted to the declared obs keys.

        Fetches a fresh observation from the env when ``raw_obs`` is None,
        and caches the image under ``render_obs_key`` for later render().
        """
        if raw_obs is None:
            raw_obs = self.env.get_observation()

        self.render_cache = raw_obs[self.render_obs_key]

        obs = dict()
        for key in self.observation_space.keys():
            obs[key] = raw_obs[key]
        return obs

    def seed(self, seed=None):
        """Seed numpy's global RNG (robosuite draws from it) and remember the
        seed so the next reset() is deterministic."""
        np.random.seed(seed=seed)
        self._seed = seed

    def reset(self):
        """Reset the env; deterministic when init_state or a seed is set."""
        if self.init_state is not None:
            if not self.has_reset_before:
                # the env must be fully reset at least once to ensure correct rendering
                self.env.reset()
                self.has_reset_before = True

            # always reset to the same state
            # to be compatible with gym
            raw_obs = self.env.reset_to({'states': self.init_state})
        elif self._seed is not None:
            # reset to a specific seed
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive, use cache
                raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]})
            else:
                # robosuite's initializes all use numpy global random state
                np.random.seed(seed=seed)
                raw_obs = self.env.reset()
                state = self.env.get_state()['states']
                self.seed_state_map[seed] = state
            # seed is consumed: the next reset() is random unless re-seeded
            self._seed = None
        else:
            # random reset
            raw_obs = self.env.reset()

        # return obs
        obs = self.get_observation(raw_obs)
        return obs

    def step(self, action):
        """Step the underlying env and repackage its observation dict."""
        raw_obs, reward, done, info = self.env.step(action)
        obs = self.get_observation(raw_obs)
        return obs, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the cached observation image as HWC uint8.

        Assumes the cached image is channel-first float in [0, 1]
        -- TODO confirm against the env's image format.
        """
        if self.render_cache is None:
            raise RuntimeError('Must run reset or step before render.')
        img = np.moveaxis(self.render_cache, 0, -1)
        img = (img * 255).astype(np.uint8)
        return img
125
+
126
+
127
def test():
    """Manual smoke test: build a wrapped image env, reset it and render."""
    import os
    from omegaconf import OmegaConf
    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    # the task config supplies the shape_meta used to build the spaces
    cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
    cfg = OmegaConf.load(cfg_path)
    shape_meta = cfg['shape_meta']

    # env metadata is embedded in the demonstration dataset
    dataset_path = os.path.expanduser('~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5')
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)

    robomimic_env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=True,
    )

    wrapper = RobomimicImageWrapper(
        env=robomimic_env,
        shape_meta=shape_meta
    )
    wrapper.seed(0)
    obs = wrapper.reset()
    img = wrapper.render()
    plt.imshow(img)
equidiff/equi_diffpo/env/robomimic/robomimic_lowdim_wrapper.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ import numpy as np
3
+ import gym
4
+ from gym.spaces import Box
5
+ from robomimic.envs.env_robosuite import EnvRobosuite
6
+
7
class RobomimicLowdimWrapper(gym.Env):
    """Gym adapter for a robomimic ``EnvRobosuite`` env with low-dim obs.

    Concatenates the selected observation keys into one flat vector and
    supports deterministic resets either to a fixed simulator state
    (``init_state``) or to a simulator state cached per seed.
    """

    def __init__(self,
            env: EnvRobosuite,
            obs_keys: Optional[List[str]]=None,
            init_state: Optional[np.ndarray]=None,
            render_hw=(256,256),
            render_camera_name='agentview'
        ):
        """
        Args:
            env: underlying robomimic robosuite environment.
            obs_keys: observation keys concatenated (in order) into the flat
                observation vector. Defaults to object state + end-effector
                pose + gripper joint positions.
            init_state: if not None, every reset() restores this sim state.
            render_hw: (height, width) of rendered frames.
            render_camera_name: camera used by render().
        """
        # Fix: build the default inside the body instead of using a mutable
        # default argument, which would be a single list shared (and
        # potentially mutated) across all instances.
        if obs_keys is None:
            obs_keys = [
                'object',
                'robot0_eef_pos',
                'robot0_eef_quat',
                'robot0_gripper_qpos']

        self.env = env
        self.obs_keys = obs_keys
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        # maps seed -> simulator state so repeated seeded resets can use the
        # cheaper reset_to() instead of a full reset()
        self.seed_state_map = dict()
        self._seed = None

        # setup spaces
        low = np.full(env.action_dimension, fill_value=-1)
        high = np.full(env.action_dimension, fill_value=1)
        self.action_space = Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=low.dtype
        )
        obs_example = self.get_observation()
        low = np.full_like(obs_example, fill_value=-1)
        high = np.full_like(obs_example, fill_value=1)
        self.observation_space = Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=low.dtype
        )

    def get_observation(self):
        """Return the current observation as one flat vector (obs_keys order)."""
        raw_obs = self.env.get_observation()
        obs = np.concatenate([
            raw_obs[key] for key in self.obs_keys
        ], axis=0)
        return obs

    def seed(self, seed=None):
        """Seed numpy's global RNG (robosuite draws from it) and remember the
        seed so the next reset() is deterministic."""
        np.random.seed(seed=seed)
        self._seed = seed

    def reset(self):
        """Reset the env; deterministic when init_state or a seed is set."""
        if self.init_state is not None:
            # always reset to the same state
            # to be compatible with gym
            self.env.reset_to({'states': self.init_state})
        elif self._seed is not None:
            # reset to a specific seed
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive, use cache
                self.env.reset_to({'states': self.seed_state_map[seed]})
            else:
                # robosuite's initializes all use numpy global random state
                np.random.seed(seed=seed)
                self.env.reset()
                state = self.env.get_state()['states']
                self.seed_state_map[seed] = state
            # seed is consumed: the next reset() is random unless re-seeded
            self._seed = None
        else:
            # random reset
            self.env.reset()

        # return obs
        obs = self.get_observation()
        return obs

    def step(self, action):
        """Step the env; returns (flat_obs, reward, done, info)."""
        raw_obs, reward, done, info = self.env.step(action)
        obs = np.concatenate([
            raw_obs[key] for key in self.obs_keys
        ], axis=0)
        return obs, reward, done, info

    def render(self, mode='rgb_array'):
        """Render an offscreen frame from the configured camera."""
        h, w = self.render_hw
        return self.env.render(mode=mode,
            height=h, width=w,
            camera_name=self.render_camera_name)
96
+
97
+
98
def test():
    """Manual smoke test: seeded resets must reproduce the same sim state."""
    import robomimic.utils.file_utils as FileUtils
    import robomimic.utils.env_utils as EnvUtils
    from matplotlib import pyplot as plt

    dataset_path = '/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5'
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)

    robomimic_env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=False,
    )
    wrapper = RobomimicLowdimWrapper(
        env=robomimic_env,
        obs_keys=[
            'object',
            'robot0_eef_pos',
            'robot0_eef_quat',
            'robot0_gripper_qpos'
        ]
    )

    # two resets under the same seed must land in identical states
    states = list()
    for _ in range(2):
        wrapper.seed(0)
        wrapper.reset()
        states.append(wrapper.env.get_state()['states'])
    assert np.allclose(states[0], states[1])

    img = wrapper.render()
    plt.imshow(img)
equidiff/equi_diffpo/env_runner/robomimic_image_runner.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
+ import numpy as np
4
+ import torch
5
+ import collections
6
+ import pathlib
7
+ import tqdm
8
+ import h5py
9
+ import math
10
+ import dill
11
+ import wandb.sdk.data_types.video as wv
12
+ from equi_diffpo.gym_util.async_vector_env import AsyncVectorEnv
13
+ from equi_diffpo.gym_util.sync_vector_env import SyncVectorEnv
14
+ from equi_diffpo.gym_util.multistep_wrapper import MultiStepWrapper
15
+ from equi_diffpo.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
16
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
17
+
18
+ from equi_diffpo.policy.base_image_policy import BaseImagePolicy
19
+ from equi_diffpo.common.pytorch_util import dict_apply
20
+ from equi_diffpo.env_runner.base_image_runner import BaseImageRunner
21
+ from equi_diffpo.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper
22
+ import robomimic.utils.file_utils as FileUtils
23
+ import robomimic.utils.env_utils as EnvUtils
24
+ import robomimic.utils.obs_utils as ObsUtils
25
+
26
+
27
def create_env(env_meta, shape_meta, enable_render=True):
    """Create a robomimic env from dataset metadata.

    Registers every observation key under its modality (defaulting to
    'low_dim') before construction. ``enable_render`` toggles both
    offscreen rendering and image observations.
    """
    obs_modalities = collections.defaultdict(list)
    for obs_key, attr in shape_meta['obs'].items():
        modality = attr.get('type', 'low_dim')
        obs_modalities[modality].append(obs_key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(obs_modalities)

    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
40
+
41
+
42
class RobomimicImageRunner(BaseImageRunner):
    """
    Robomimic envs already enforces number of steps.

    Evaluates an image-based policy in parallel robomimic envs: the first
    n_train episodes reset to initial states taken from the training
    dataset, the remaining n_test episodes reset to fixed seeds.
    """

    def __init__(self,
            output_dir,
            dataset_path,
            shape_meta:dict,
            n_train=10,
            n_train_vis=3,
            train_start_idx=0,
            n_test=22,
            n_test_vis=6,
            test_start_seed=10000,
            max_steps=400,
            n_obs_steps=2,
            n_action_steps=8,
            render_obs_key='agentview_image',
            fps=10,
            crf=22,
            past_action=False,
            abs_action=False,
            tqdm_interval_sec=5.0,
            n_envs=None
        ):
        """
        Args:
            output_dir: run directory; rollout videos are written to 'media'.
            dataset_path: hdf5 dataset holding env metadata and demo states.
            shape_meta: obs/action shape description used to build the envs.
            n_train, n_train_vis: train-init episodes and how many get video.
            train_start_idx: first demo index used for train initial states.
            n_test, n_test_vis: seeded test episodes and how many get video.
            test_start_seed: seed of the first test episode.
            max_steps: max env steps per episode.
            n_obs_steps, n_action_steps: obs stacking / action chunking.
            render_obs_key: observation key recorded into videos.
            fps, crf: video framerate and H.264 quality factor.
            past_action: if True, feed past actions to the policy (untested).
            abs_action: if True, policy outputs absolute poses with a 6d
                rotation that must be converted back to axis-angle.
            tqdm_interval_sec: progress-bar refresh interval.
            n_envs: number of parallel envs; defaults to n_train + n_test.
        """
        super().__init__(output_dir)

        if n_envs is None:
            n_envs = n_train + n_test

        # assert n_obs_steps <= n_action_steps
        dataset_path = os.path.expanduser(dataset_path)
        robosuite_fps = 20
        # subsample sim frames so the recorded video plays at ~fps
        steps_per_render = max(robosuite_fps // fps, 1)

        # read from dataset
        env_meta = FileUtils.get_env_metadata_from_dataset(
            dataset_path)
        # disable object state observation
        env_meta['env_kwargs']['use_object_obs'] = False

        rotation_transformer = None
        if abs_action:
            env_meta['env_kwargs']['controller_configs']['control_delta'] = False
            rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

        def env_fn():
            # builds one fully wrapped env: robomimic -> image wrapper
            # -> video recording -> multi-step obs/action stacking
            robomimic_env = create_env(
                env_meta=env_meta,
                shape_meta=shape_meta
            )
            # Robosuite's hard reset causes excessive memory consumption.
            # Disabled to run more envs.
            # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
            robomimic_env.env.hard_reset = False
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    RobomimicImageWrapper(
                        env=robomimic_env,
                        shape_meta=shape_meta,
                        init_state=None,
                        render_obs_key=render_obs_key
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=steps_per_render
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        # For each process the OpenGL context can only be initialized once
        # Since AsyncVectorEnv uses fork to create worker process,
        # a separate env_fn that does not create OpenGL context (enable_render=False)
        # is needed to initialize spaces.
        def dummy_env_fn():
            robomimic_env = create_env(
                    env_meta=env_meta,
                    shape_meta=shape_meta,
                    enable_render=False
                )
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    RobomimicImageWrapper(
                        env=robomimic_env,
                        shape_meta=shape_meta,
                        init_state=None,
                        render_obs_key=render_obs_key
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=steps_per_render
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        env_fns = [env_fn] * n_envs
        env_seeds = list()
        env_prefixs = list()
        # per-episode init closures, dill-serialized so they can be shipped
        # to the async worker processes
        env_init_fn_dills = list()

        # train
        with h5py.File(dataset_path, 'r') as f:
            for i in range(n_train):
                train_idx = train_start_idx + i
                # only the first n_train_vis episodes record video
                enable_render = i < n_train_vis
                init_state = f[f'data/demo_{train_idx}/states'][0]

                def init_fn(env, init_state=init_state,
                    enable_render=enable_render):
                    # setup rendering
                    # video_wrapper
                    assert isinstance(env.env, VideoRecordingWrapper)
                    env.env.video_recoder.stop()
                    env.env.file_path = None
                    if enable_render:
                        filename = pathlib.Path(output_dir).joinpath(
                            'media', wv.util.generate_id() + ".mp4")
                        filename.parent.mkdir(parents=False, exist_ok=True)
                        filename = str(filename)
                        env.env.file_path = filename

                    # switch to init_state reset
                    assert isinstance(env.env.env, RobomimicImageWrapper)
                    env.env.env.init_state = init_state

                env_seeds.append(train_idx)
                env_prefixs.append('train/')
                env_init_fn_dills.append(dill.dumps(init_fn))

        # test
        for i in range(n_test):
            seed = test_start_seed + i
            # only the first n_test_vis episodes record video
            enable_render = i < n_test_vis

            def init_fn(env, seed=seed,
                enable_render=enable_render):
                # setup rendering
                # video_wrapper
                assert isinstance(env.env, VideoRecordingWrapper)
                env.env.video_recoder.stop()
                env.env.file_path = None
                if enable_render:
                    filename = pathlib.Path(output_dir).joinpath(
                        'media', wv.util.generate_id() + ".mp4")
                    filename.parent.mkdir(parents=False, exist_ok=True)
                    filename = str(filename)
                    env.env.file_path = filename

                # switch to seed reset
                assert isinstance(env.env.env, RobomimicImageWrapper)
                env.env.env.init_state = None
                env.seed(seed)

            env_seeds.append(seed)
            env_prefixs.append('test/')
            env_init_fn_dills.append(dill.dumps(init_fn))

        env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)

        self.env_meta = env_meta
        self.env = env
        self.env_fns = env_fns
        self.env_seeds = env_seeds
        self.env_prefixs = env_prefixs
        self.env_init_fn_dills = env_init_fn_dills
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.rotation_transformer = rotation_transformer
        self.abs_action = abs_action
        self.tqdm_interval_sec = tqdm_interval_sec
        # best mean score seen so far, tracked per prefix across run() calls
        self.max_rewards = {}
        for prefix in self.env_prefixs:
            self.max_rewards[prefix] = 0

    def run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` over all episode inits and return a dict of
        wandb-ready metrics (per-episode max reward, videos, mean/max score).

        Episodes are processed in chunks of n_envs parallel envs; when the
        last chunk is short it is padded with the first init fn and the
        padded results are discarded via the local slice.
        """
        device = policy.device
        dtype = policy.dtype
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data
        all_video_paths = [None] * n_inits
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_n_active_envs = end - start
            this_local_slice = slice(0,this_n_active_envs)

            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # pad the short last chunk; padded envs' results are dropped
                this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            env_name = self.env_meta['env_name']
            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx+1}/{n_chunks}",
                leave=False, mininterval=self.tqdm_interval_sec)

            done = False
            while not done:
                # create obs dict
                np_obs_dict = dict(obs)
                if self.past_action and (past_action is not None):
                    # TODO: not tested
                    np_obs_dict['past_action'] = past_action[
                        :,-(self.n_obs_steps-1):].astype(np.float32)

                # device transfer
                obs_dict = dict_apply(np_obs_dict,
                    lambda x: torch.from_numpy(x).to(
                        device=device))

                # run policy
                with torch.no_grad():
                    action_dict = policy.predict_action(obs_dict)

                # device_transfer
                np_action_dict = dict_apply(action_dict,
                    lambda x: x.detach().to('cpu').numpy())

                action = np_action_dict['action']
                if not np.all(np.isfinite(action)):
                    print(action)
                    raise RuntimeError("Nan or Inf action")

                # step env
                env_action = action
                if self.abs_action:
                    # convert 6d-rotation absolute actions back to env format
                    env_action = self.undo_transform_action(action)

                obs, reward, done, info = env.step(env_action)
                done = np.all(done)
                past_action = action

                # update pbar (one policy call advances a whole action chunk)
                pbar.update(action.shape[1])
            pbar.close()

            # collect data for this round
            all_video_paths[this_global_slice] = env.render()[this_local_slice]
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
            # clear out video buffer
            _ = env.reset()

        # log
        max_rewards = collections.defaultdict(list)
        log_data = dict()
        # results reported in the paper are generated using the commented out line below
        # which will only report and average metrics from first n_envs initial condition and seeds
        # fortunately this won't invalidate our conclusion since
        # 1. This bug only affects the variance of metrics, not their mean
        # 2. All baseline methods are evaluated using the same code
        # to completely reproduce reported numbers, uncomment this line:
        # for i in range(len(self.env_fns)):
        # and comment out this line
        for i in range(n_inits):
            seed = self.env_seeds[i]
            prefix = self.env_prefixs[i]
            max_reward = np.max(all_rewards[i])
            max_rewards[prefix].append(max_reward)
            log_data[prefix+f'sim_max_reward_{seed}'] = max_reward

            # visualize sim
            video_path = all_video_paths[i]
            if video_path is not None:
                sim_video = wandb.Video(video_path)
                log_data[prefix+f'sim_video_{seed}'] = sim_video

        # log aggregate metrics
        for prefix, value in max_rewards.items():
            name = prefix+'mean_score'
            value = np.mean(value)
            log_data[name] = value
            # also track the best mean score seen across run() calls
            self.max_rewards[prefix] = max(self.max_rewards[prefix], value)
            log_data[prefix+'max_score'] = self.max_rewards[prefix]

        return log_data

    def undo_transform_action(self, action):
        """Convert policy actions (pos + 6d rotation + gripper) back to env
        actions (pos + axis-angle rotation + gripper).

        A trailing dim of 20 is treated as a dual-arm action (2 x 10) and is
        reshaped back to (..., 14) after conversion.
        """
        raw_shape = action.shape
        if raw_shape[-1] == 20:
            # dual arm
            action = action.reshape(-1,2,10)

        d_rot = action.shape[-1] - 4
        pos = action[...,:3]
        rot = action[...,3:3+d_rot]
        gripper = action[...,[-1]]
        rot = self.rotation_transformer.inverse(rot)
        uaction = np.concatenate([
            pos, rot, gripper
        ], axis=-1)

        if raw_shape[-1] == 20:
            # dual arm
            uaction = uaction.reshape(*raw_shape[:-1], 14)

        return uaction
+ return uaction
equidiff/equi_diffpo/env_runner/robomimic_lowdim_runner.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
+ import numpy as np
4
+ import torch
5
+ import collections
6
+ import pathlib
7
+ import tqdm
8
+ import h5py
9
+ import dill
10
+ import math
11
+ import wandb.sdk.data_types.video as wv
12
+ from equi_diffpo.gym_util.async_vector_env import AsyncVectorEnv
13
+ # from equi_diffpo.gym_util.sync_vector_env import SyncVectorEnv
14
+ from equi_diffpo.gym_util.multistep_wrapper import MultiStepWrapper
15
+ from equi_diffpo.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
16
+ from equi_diffpo.model.common.rotation_transformer import RotationTransformer
17
+
18
+ from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy
19
+ from equi_diffpo.common.pytorch_util import dict_apply
20
+ from equi_diffpo.env_runner.base_lowdim_runner import BaseLowdimRunner
21
+ from equi_diffpo.env.robomimic.robomimic_lowdim_wrapper import RobomimicLowdimWrapper
22
+ import robomimic.utils.file_utils as FileUtils
23
+ import robomimic.utils.env_utils as EnvUtils
24
+ import robomimic.utils.obs_utils as ObsUtils
25
+
26
+
27
def create_env(env_meta, obs_keys, enable_render=True):
    """Create a robomimic env whose low-dim observations are ``obs_keys``.

    ``enable_render`` turns on offscreen rendering. Offscreen rendering is
    the only way to avoid showing collision geometry, but it uses a lot of
    RAM, so it can be disabled for dummy/space-probing envs.
    """
    ObsUtils.initialize_obs_modality_mapping_from_dict({'low_dim': obs_keys})
    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
40
+
41
+
42
+ class RobomimicLowdimRunner(BaseLowdimRunner):
43
+ """
44
+ Robomimic envs already enforces number of steps.
45
+ """
46
+
47
+ def __init__(self,
48
+ output_dir,
49
+ dataset_path,
50
+ obs_keys,
51
+ n_train=10,
52
+ n_train_vis=3,
53
+ train_start_idx=0,
54
+ n_test=22,
55
+ n_test_vis=6,
56
+ test_start_seed=10000,
57
+ max_steps=400,
58
+ n_obs_steps=2,
59
+ n_action_steps=8,
60
+ n_latency_steps=0,
61
+ render_hw=(256,256),
62
+ render_camera_name='agentview',
63
+ fps=10,
64
+ crf=22,
65
+ past_action=False,
66
+ abs_action=False,
67
+ tqdm_interval_sec=5.0,
68
+ n_envs=None
69
+ ):
70
+ """
71
+ Assuming:
72
+ n_obs_steps=2
73
+ n_latency_steps=3
74
+ n_action_steps=4
75
+ o: obs
76
+ i: inference
77
+ a: action
78
+ Batch t:
79
+ |o|o| | | | | | |
80
+ | |i|i|i| | | | |
81
+ | | | | |a|a|a|a|
82
+ Batch t+1
83
+ | | | | |o|o| | | | | | |
84
+ | | | | | |i|i|i| | | | |
85
+ | | | | | | | | |a|a|a|a|
86
+ """
87
+
88
+ super().__init__(output_dir)
89
+
90
+ if n_envs is None:
91
+ n_envs = n_train + n_test
92
+
93
+ # handle latency step
94
+ # to mimic latency, we request n_latency_steps additional steps
95
+ # of past observations, and the discard the last n_latency_steps
96
+ env_n_obs_steps = n_obs_steps + n_latency_steps
97
+ env_n_action_steps = n_action_steps
98
+
99
+ # assert n_obs_steps <= n_action_steps
100
+ dataset_path = os.path.expanduser(dataset_path)
101
+ robosuite_fps = 20
102
+ steps_per_render = max(robosuite_fps // fps, 1)
103
+
104
+ # read from dataset
105
+ env_meta = FileUtils.get_env_metadata_from_dataset(
106
+ dataset_path)
107
+ rotation_transformer = None
108
+ if abs_action:
109
+ env_meta['env_kwargs']['controller_configs']['control_delta'] = False
110
+ rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')
111
+
112
+ def env_fn():
113
+ robomimic_env = create_env(
114
+ env_meta=env_meta,
115
+ obs_keys=obs_keys
116
+ )
117
+ # Robosuite's hard reset causes excessive memory consumption.
118
+ # Disabled to run more envs.
119
+ # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
120
+ robomimic_env.env.hard_reset = False
121
+ return MultiStepWrapper(
122
+ VideoRecordingWrapper(
123
+ RobomimicLowdimWrapper(
124
+ env=robomimic_env,
125
+ obs_keys=obs_keys,
126
+ init_state=None,
127
+ render_hw=render_hw,
128
+ render_camera_name=render_camera_name
129
+ ),
130
+ video_recoder=VideoRecorder.create_h264(
131
+ fps=fps,
132
+ codec='h264',
133
+ input_pix_fmt='rgb24',
134
+ crf=crf,
135
+ thread_type='FRAME',
136
+ thread_count=1
137
+ ),
138
+ file_path=None,
139
+ steps_per_render=steps_per_render
140
+ ),
141
+ n_obs_steps=n_obs_steps,
142
+ n_action_steps=n_action_steps,
143
+ max_episode_steps=max_steps
144
+ )
145
+
146
+ # For each process the OpenGL context can only be initialized once
147
+ # Since AsyncVectorEnv uses fork to create worker process,
148
+ # a separate env_fn that does not create OpenGL context (enable_render=False)
149
+ # is needed to initialize spaces.
150
+ def dummy_env_fn():
151
+ robomimic_env = create_env(
152
+ env_meta=env_meta,
153
+ obs_keys=obs_keys,
154
+ enable_render=False
155
+ )
156
+ return MultiStepWrapper(
157
+ VideoRecordingWrapper(
158
+ RobomimicLowdimWrapper(
159
+ env=robomimic_env,
160
+ obs_keys=obs_keys,
161
+ init_state=None,
162
+ render_hw=render_hw,
163
+ render_camera_name=render_camera_name
164
+ ),
165
+ video_recoder=VideoRecorder.create_h264(
166
+ fps=fps,
167
+ codec='h264',
168
+ input_pix_fmt='rgb24',
169
+ crf=crf,
170
+ thread_type='FRAME',
171
+ thread_count=1
172
+ ),
173
+ file_path=None,
174
+ steps_per_render=steps_per_render
175
+ ),
176
+ n_obs_steps=n_obs_steps,
177
+ n_action_steps=n_action_steps,
178
+ max_episode_steps=max_steps
179
+ )
180
+
181
+ env_fns = [env_fn] * n_envs
182
+ env_seeds = list()
183
+ env_prefixs = list()
184
+ env_init_fn_dills = list()
185
+
186
+ # train
187
+ with h5py.File(dataset_path, 'r') as f:
188
+ for i in range(n_train):
189
+ train_idx = train_start_idx + i
190
+ enable_render = i < n_train_vis
191
+ init_state = f[f'data/demo_{train_idx}/states'][0]
192
+
193
+ def init_fn(env, init_state=init_state,
194
+ enable_render=enable_render):
195
+ # setup rendering
196
+ # video_wrapper
197
+ assert isinstance(env.env, VideoRecordingWrapper)
198
+ env.env.video_recoder.stop()
199
+ env.env.file_path = None
200
+ if enable_render:
201
+ filename = pathlib.Path(output_dir).joinpath(
202
+ 'media', wv.util.generate_id() + ".mp4")
203
+ filename.parent.mkdir(parents=False, exist_ok=True)
204
+ filename = str(filename)
205
+ env.env.file_path = filename
206
+
207
+ # switch to init_state reset
208
+ assert isinstance(env.env.env, RobomimicLowdimWrapper)
209
+ env.env.env.init_state = init_state
210
+
211
+ env_seeds.append(train_idx)
212
+ env_prefixs.append('train/')
213
+ env_init_fn_dills.append(dill.dumps(init_fn))
214
+
215
+ # test
216
+ for i in range(n_test):
217
+ seed = test_start_seed + i
218
+ enable_render = i < n_test_vis
219
+
220
+ def init_fn(env, seed=seed,
221
+ enable_render=enable_render):
222
+ # setup rendering
223
+ # video_wrapper
224
+ assert isinstance(env.env, VideoRecordingWrapper)
225
+ env.env.video_recoder.stop()
226
+ env.env.file_path = None
227
+ if enable_render:
228
+ filename = pathlib.Path(output_dir).joinpath(
229
+ 'media', wv.util.generate_id() + ".mp4")
230
+ filename.parent.mkdir(parents=False, exist_ok=True)
231
+ filename = str(filename)
232
+ env.env.file_path = filename
233
+
234
+ # switch to seed reset
235
+ assert isinstance(env.env.env, RobomimicLowdimWrapper)
236
+ env.env.env.init_state = None
237
+ env.seed(seed)
238
+
239
+ env_seeds.append(seed)
240
+ env_prefixs.append('test/')
241
+ env_init_fn_dills.append(dill.dumps(init_fn))
242
+
243
+ env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)
244
+ # env = SyncVectorEnv(env_fns)
245
+
246
+ self.env_meta = env_meta
247
+ self.env = env
248
+ self.env_fns = env_fns
249
+ self.env_seeds = env_seeds
250
+ self.env_prefixs = env_prefixs
251
+ self.env_init_fn_dills = env_init_fn_dills
252
+ self.fps = fps
253
+ self.crf = crf
254
+ self.n_obs_steps = n_obs_steps
255
+ self.n_action_steps = n_action_steps
256
+ self.n_latency_steps = n_latency_steps
257
+ self.env_n_obs_steps = env_n_obs_steps
258
+ self.env_n_action_steps = env_n_action_steps
259
+ self.past_action = past_action
260
+ self.max_steps = max_steps
261
+ self.rotation_transformer = rotation_transformer
262
+ self.abs_action = abs_action
263
+ self.tqdm_interval_sec = tqdm_interval_sec
264
+
265
    def run(self, policy: BaseLowdimPolicy):
        """Roll out `policy` over every registered initial condition and return metrics.

        Initial conditions (self.env_init_fn_dills) are processed in chunks of
        n_envs using the async vector env; surplus env slots in the last chunk
        are padded with the first init fn and their results discarded via
        this_local_slice. Returns a dict of per-episode max rewards, rendered
        videos (wandb.Video), and per-prefix mean scores.
        """
        device = policy.device
        dtype = policy.dtype  # NOTE(review): unused here — kept for parity with sibling runners
        env = self.env

        # plan for rollout: ceil-divide the init conditions over the env slots
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data (indexed by init condition, filled chunk by chunk)
        all_video_paths = [None] * n_inits
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_n_active_envs = end - start
            this_local_slice = slice(0,this_n_active_envs)

            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # pad the last chunk so every env slot has an init fn;
                # padded slots are dropped later via this_local_slice
                this_init_fns.extend([self.env_init_fn_dills[0]]*n_diff)
            assert len(this_init_fns) == n_envs

            # init envs: each worker unpickles and executes its init fn
            env.call_each('run_dill_function',
                args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            env_name = self.env_meta['env_name']
            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Lowdim {chunk_idx+1}/{n_chunks}",
                leave=False, mininterval=self.tqdm_interval_sec)

            done = False
            while not done:
                # create obs dict
                np_obs_dict = {
                    # handle n_latency_steps by discarding the last n_latency_steps
                    'obs': obs[:,:self.n_obs_steps].astype(np.float32)
                }
                if self.past_action and (past_action is not None):
                    # TODO: not tested
                    np_obs_dict['past_action'] = past_action[
                        :,-(self.n_obs_steps-1):].astype(np.float32)

                # device transfer (numpy -> torch on policy device)
                obs_dict = dict_apply(np_obs_dict,
                    lambda x: torch.from_numpy(x).to(
                        device=device))

                # run policy
                with torch.no_grad():
                    action_dict = policy.predict_action(obs_dict)

                # device_transfer (torch -> numpy on cpu)
                np_action_dict = dict_apply(action_dict,
                    lambda x: x.detach().to('cpu').numpy())

                # handle latency_steps, we discard the first n_latency_steps actions
                # to simulate latency
                action = np_action_dict['action'][:,self.n_latency_steps:]
                if not np.all(np.isfinite(action)):
                    print(action)
                    raise RuntimeError("Nan or Inf action")

                # step env: convert to the env's native action space if needed
                env_action = action
                if self.abs_action:
                    env_action = self.undo_transform_action(action)

                obs, reward, done, info = env.step(env_action)
                # episode ends only when every env in the chunk is done
                done = np.all(done)
                past_action = action

                # update pbar by the number of env steps taken this iteration
                pbar.update(action.shape[1])
            pbar.close()

            # collect data for this round (drop results from padded env slots)
            all_video_paths[this_global_slice] = env.render()[this_local_slice]
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]

        # log
        max_rewards = collections.defaultdict(list)
        log_data = dict()
        # results reported in the paper are generated using the commented out line below
        # which will only report and average metrics from first n_envs initial condition and seeds
        # fortunately this won't invalidate our conclusion since
        # 1. This bug only affects the variance of metrics, not their mean
        # 2. All baseline methods are evaluated using the same code
        # to completely reproduce reported numbers, uncomment this line:
        # for i in range(len(self.env_fns)):
        # and comment out this line
        for i in range(n_inits):
            seed = self.env_seeds[i]
            prefix = self.env_prefixs[i]
            max_reward = np.max(all_rewards[i])
            max_rewards[prefix].append(max_reward)
            log_data[prefix+f'sim_max_reward_{seed}'] = max_reward

            # visualize sim
            video_path = all_video_paths[i]
            if video_path is not None:
                sim_video = wandb.Video(video_path)
                log_data[prefix+f'sim_video_{seed}'] = sim_video

        # log aggregate metrics (mean of per-episode max rewards per prefix)
        for prefix, value in max_rewards.items():
            name = prefix+'mean_score'
            value = np.mean(value)
            log_data[name] = value

        return log_data
385
+
386
+ def undo_transform_action(self, action):
387
+ raw_shape = action.shape
388
+ if raw_shape[-1] == 20:
389
+ # dual arm
390
+ action = action.reshape(-1,2,10)
391
+
392
+ d_rot = action.shape[-1] - 4
393
+ pos = action[...,:3]
394
+ rot = action[...,3:3+d_rot]
395
+ gripper = action[...,[-1]]
396
+ rot = self.rotation_transformer.inverse(rot)
397
+ uaction = np.concatenate([
398
+ pos, rot, gripper
399
+ ], axis=-1)
400
+
401
+ if raw_shape[-1] == 20:
402
+ # dual arm
403
+ uaction = uaction.reshape(*raw_shape[:-1], 14)
404
+
405
+ return uaction
equidiff/equi_diffpo/policy/robomimic_image_policy.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import torch
3
+ from equi_diffpo.model.common.normalizer import LinearNormalizer
4
+ from equi_diffpo.policy.base_image_policy import BaseImagePolicy
5
+ from equi_diffpo.common.pytorch_util import dict_apply
6
+
7
+ from robomimic.algo import algo_factory
8
+ from robomimic.algo.algo import PolicyAlgo
9
+ import robomimic.utils.obs_utils as ObsUtils
10
+ from equi_diffpo.common.robomimic_config_util import get_robomimic_config
11
+
12
class RobomimicImagePolicy(BaseImagePolicy):
    """Adapter exposing a robomimic image-based algorithm (e.g. BC-RNN)
    through the BaseImagePolicy interface.

    Builds a robomimic config from `shape_meta` (task observation/action
    shapes), instantiates the algorithm via `algo_factory`, and normalizes
    inputs/outputs with a LinearNormalizer.
    """
    def __init__(self,
            shape_meta: dict,
            algo_name='bc_rnn',
            obs_type='image',
            task_name='square',
            dataset_type='ph',
            crop_shape=(76,76)
        ):
        super().__init__()

        # parse shape_meta: action dim plus per-key observation shapes/modalities
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            # renamed from `type` to avoid shadowing the builtin
            modality_type = attr.get('type', 'low_dim')
            if modality_type == 'rgb':
                obs_config['rgb'].append(key)
            elif modality_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {modality_type}")

        # get raw robomimic config
        config = get_robomimic_config(
            algo_name=algo_name,
            hdf5_type=obs_type,
            task_name=task_name,
            dataset_type=dataset_type)

        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping entirely
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state (robomimic keeps modality mappings in module globals)
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model
        model: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        self.model = model
        # expose the inner nets so nn.Module bookkeeping (params, .to) sees them
        self.nets = model.nets
        self.normalizer = LinearNormalizer()
        self.config = config

    def to(self,*args,**kwargs):
        """Move the policy; also keeps `self.model.device` in sync.

        NOTE(review): uses the private torch._C._nn._parse_to to parse the
        arguments the same way nn.Module.to does — may break across torch
        versions.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.model.device = device
        super().to(*args,**kwargs)

    # =========== inference =============
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict a single action step from a (B, To, ...) observation dict.

        Only the first observation step is fed to the robomimic model; the
        returned action is unnormalized and given a singleton time axis.
        """
        nobs_dict = self.normalizer(obs_dict)
        # robomimic consumes a single (B, ...) observation, not a sequence
        robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:,0,...])
        naction = self.model.get_action(robomimic_obs_dict)
        action = self.normalizer['action'].unnormalize(naction)
        # (B, Da)
        result = {
            'action': action[:,None,:] # (B, 1, Da)
        }
        return result

    def reset(self):
        """Reset the model's internal state (e.g. RNN hidden state)."""
        self.model.reset()

    # =========== training ==============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy normalization statistics from a fitted normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def train_on_batch(self, batch, epoch, validate=False):
        """Run one robomimic training (or validation) step on a batch.

        Returns the robomimic info dict (keys: losses, predictions).
        """
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        robomimic_batch = {
            'obs': nobs,
            'actions': nactions
        }
        input_batch = self.model.process_batch_for_training(
            robomimic_batch)
        info = self.model.train_on_batch(
            batch=input_batch, epoch=epoch, validate=validate)
        # keys: losses, predictions
        return info

    def on_epoch_end(self, epoch):
        """Forward epoch-end hooks (e.g. LR scheduling) to the robomimic model."""
        self.model.on_epoch_end(epoch)

    def get_optimizer(self):
        """Return the robomimic policy optimizer."""
        return self.model.optimizers['policy']
132
+
133
+
134
def test():
    """Smoke test: construct a RobomimicImagePolicy from the lift_image task config."""
    import os
    from omegaconf import OmegaConf
    config_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
    task_cfg = OmegaConf.load(config_path)

    policy = RobomimicImagePolicy(shape_meta=task_cfg.shape_meta)
142
+
equidiff/equi_diffpo/scripts/robomimic_dataset_action_comparison.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
if __name__ == "__main__":
    import sys
    import pathlib

    # make the package importable when run directly as a script
    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)

    import click
    import h5py
    import numpy as np
    from tqdm import tqdm
    from scipy.spatial.transform import Rotation

    def read_all_actions(hdf5_file, metric_skip_steps=1):
        """Concatenate the action arrays of every demo in `hdf5_file`,
        dropping the first `metric_skip_steps` steps of each demo."""
        n_demos = len(hdf5_file['data'])
        all_actions = list()
        for i in tqdm(range(n_demos)):
            actions = hdf5_file[f'data/demo_{i}/actions'][:]
            all_actions.append(actions[metric_skip_steps:])
        all_actions = np.concatenate(all_actions, axis=0)
        return all_actions

    @click.command()
    @click.option('-i', '--input', required=True, help='input hdf5 path')
    # fixed help text: this script READS the second file for comparison,
    # it does not write to it
    @click.option('-o', '--output', required=True, help='hdf5 path to compare against; must exist')
    def main(input, output):
        """Print the max position/rotation distance between the actions of
        two robomimic hdf5 datasets (e.g. original vs converted)."""
        # process inputs
        input = pathlib.Path(input).expanduser()
        assert input.is_file()
        output = pathlib.Path(output).expanduser()
        assert output.is_file()

        # use context managers so the files are closed even on failure
        with h5py.File(str(input), 'r') as input_file, \
                h5py.File(str(output), 'r') as output_file:
            input_all_actions = read_all_actions(input_file)
            output_all_actions = read_all_actions(output_file)

        # Euclidean distance between positions
        pos_dist = np.linalg.norm(input_all_actions[:,:3] - output_all_actions[:,:3], axis=-1)
        # geodesic distance between the axis-angle rotations:
        # magnitude of the relative rotation R_in * R_out^-1
        rot_dist = (Rotation.from_rotvec(input_all_actions[:,3:6]
            ) * Rotation.from_rotvec(output_all_actions[:,3:6]).inv()
            ).magnitude()

        print(f'max pos dist: {pos_dist.max()}')
        print(f'max rot dist: {rot_dist.max()}')

    main()
equidiff/equi_diffpo/scripts/robomimic_dataset_conversion.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
if __name__ == "__main__":
    # Convert a robomimic hdf5 dataset's relative actions to absolute actions,
    # optionally evaluating the conversion error and plotting it.
    import sys
    import os
    import pathlib

    # make the equi_diffpo package importable when run directly as a script
    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)

    import multiprocessing
    import os
    import shutil
    import click
    import pathlib
    import h5py
    from tqdm import tqdm
    import collections
    import pickle
    from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter

    def worker(x):
        # Convert one demo; runs in a subprocess, so each worker builds its
        # own converter from the dataset path.
        path, idx, do_eval = x
        converter = RobomimicAbsoluteActionConverter(path)
        if do_eval:
            abs_actions, info = converter.convert_and_eval_idx(idx)
        else:
            abs_actions = converter.convert_idx(idx)
            info = dict()
        return abs_actions, info

    @click.command()
    @click.option('-i', '--input', required=True, help='input hdf5 path')
    @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
    @click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
    @click.option('-n', '--num_workers', default=None, type=int)
    def main(input, output, eval_dir, num_workers):
        # process inputs
        input = pathlib.Path(input).expanduser()
        assert input.is_file()
        output = pathlib.Path(output).expanduser()
        assert output.parent.is_dir()
        assert not output.is_dir()

        do_eval = False
        if eval_dir is not None:
            eval_dir = pathlib.Path(eval_dir).expanduser()
            assert eval_dir.parent.exists()
            do_eval = True

        # built only to query the number of demos (len)
        converter = RobomimicAbsoluteActionConverter(input)

        # run: convert all demos in parallel
        with multiprocessing.Pool(num_workers) as pool:
            results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))])

        # save output: copy the whole dataset, then overwrite actions in place
        print('Copying hdf5')
        shutil.copy(str(input), str(output))

        # modify action
        with h5py.File(output, 'r+') as out_file:
            for i in tqdm(range(len(converter)), desc="Writing to output"):
                abs_actions, info = results[i]
                demo = out_file[f'data/demo_{i}']
                demo['actions'][:] = abs_actions

        # save eval
        if do_eval:
            eval_dir.mkdir(parents=False, exist_ok=True)

            print("Writing error_stats.pkl")
            infos = [info for _, info in results]
            pickle.dump(infos, eval_dir.joinpath('error_stats.pkl').open('wb'))

            print("Generating visualization")
            # regroup per-demo error dicts into {metric: {key: [per-demo values]}}
            metrics = ['pos', 'rot']
            metrics_dicts = dict()
            for m in metrics:
                metrics_dicts[m] = collections.defaultdict(list)

            for i in range(len(infos)):
                info = infos[i]
                for k, v in info.items():
                    for m in metrics:
                        metrics_dicts[m][k].append(v[m])

            # import here so matplotlib is only required when -e is used
            from matplotlib import pyplot as plt
            plt.switch_backend('PDF')

            # one subplot per metric, one line per error key
            fig, ax = plt.subplots(1, len(metrics))
            for i in range(len(metrics)):
                axis = ax[i]
                data = metrics_dicts[metrics[i]]
                for key, value in data.items():
                    axis.plot(value, label=key)
                axis.legend()
                axis.set_title(metrics[i])
            fig.set_size_inches(10,4)
            fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
            fig.savefig(str(eval_dir.joinpath('error_stats.png')))


    if __name__ == "__main__":
        main()
equidiff/equi_diffpo/scripts/robomimic_dataset_obs_conversion.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
if __name__ == "__main__":
    # Re-render the observations of a robomimic hdf5 dataset, replacing each
    # demo's `obs` group with converted observations, processed in chunks of
    # worker processes to bound peak memory.
    import sys
    import os
    import pathlib

    # make the equi_diffpo package importable when run directly as a script
    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)

    import multiprocessing
    import os
    import shutil
    import click
    import pathlib
    import h5py
    from tqdm import tqdm
    import numpy as np
    import collections
    import pickle
    from equi_diffpo.common.robomimic_util import RobomimicObsConverter

    # simulator state is not fork-safe; force spawned workers
    multiprocessing.set_start_method('spawn', force=True)

    def worker(x):
        # Convert one demo's observations; runs in a subprocess, so each
        # worker builds its own converter from the dataset path.
        path, idx = x
        converter = RobomimicObsConverter(path)
        obss = converter.convert_idx(idx)
        return obss

    @click.command()
    @click.option('-i', '--input', required=True, help='input hdf5 path')
    @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
    @click.option('-n', '--num_workers', default=None, type=int)
    def main(input, output, num_workers):
        # process inputs
        input = pathlib.Path(input).expanduser()
        assert input.is_file()
        output = pathlib.Path(output).expanduser()
        assert output.parent.is_dir()
        assert not output.is_dir()

        # BUG FIX: num_workers defaults to None, which previously crashed
        # `idx + num_workers` below with a TypeError whenever -n was omitted.
        if num_workers is None:
            num_workers = multiprocessing.cpu_count()

        # built only to query the number of demos (len)
        converter = RobomimicObsConverter(input)

        # save output: copy the whole dataset, then rewrite obs in place
        print('Copying hdf5')
        shutil.copy(str(input), str(output))

        # run: convert demos in chunks of num_workers; the pool is re-created
        # per chunk so worker memory is released between chunks
        idx = 0
        while idx < len(converter):
            with multiprocessing.Pool(num_workers) as pool:
                end = min(idx + num_workers, len(converter))
                results = pool.map(worker, [(input, i) for i in range(idx, end)])

            # replace each converted demo's obs group
            print('Writing {} to {}'.format(idx, end))
            with h5py.File(output, 'r+') as out_file:
                for i in tqdm(range(idx, end), desc="Writing to output"):
                    obss = results[i - idx]
                    demo = out_file[f'data/demo_{i}']
                    del demo['obs']
                    for k in obss:
                        demo.create_dataset("obs/{}".format(k), data=np.array(obss[k]), compression="gzip")

            idx = end
            # drop converted observations before the next chunk to cap memory
            del results


    if __name__ == "__main__":
        main()