"""
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.

Usage:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data

If you want to push your dataset to the Hugging Face Hub, use the following command:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub

The resulting dataset will be saved to the $HF_LEROBOT_HOME directory.
"""
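
# Expected input layout under --data_dir (inferred from the file lookups below):
#
#   data_dir/
#     aggregated-annotations-030724.json      # maps episode id -> language annotations
#     .../<episode>/
#       trajectory.h5                         # low-dimensional robot states and actions
#       metadata_<episode_id>.json
#       recordings/MP4/<camera_serial>.mp4    # one recording per camera
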
from collections import defaultdict
import copy
import glob
import json
from pathlib import Path
import shutil

import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import tyro

# Name of the output dataset; also used as the repo id when pushing to the Hugging Face Hub.
REPO_NAME = "your_hf_username/my_droid_dataset"


def resize_image(image, size):
    """Resize an (H, W, C) uint8 image to `size` = (width, height) with bicubic resampling."""
    image = Image.fromarray(image)
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory.
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

    # Create an empty LeRobot dataset and declare the features we want to store.
    # Images are stored at 320x180; the action is 7 joint velocities plus 1 gripper position.
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,
        features={
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Load the language annotations (they map an episode id to natural-language instructions).
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Discover all episodes: every episode directory contains a trajectory.h5 file.
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load the trajectory from HDF5, joined with its MP4 camera recordings.
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # Look up the language instruction via the episode id encoded in the metadata filename.
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        for step in trajectory:
            # Camera type 0 is the wrist-mounted camera; all others are exterior cameras.
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # OpenCV decodes frames as BGR; [..., ::-1] converts them to RGB before resizing.
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(
                        step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)
                    ),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Action: 7 joint velocities concatenated with the 1-D gripper action.
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push the dataset to the Hugging Face Hub.
    if push_to_hub:
        dataset.push_to_hub(
            tags=["droid", "panda"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


# Maps camera name to camera type: 0 = wrist-mounted ("hand"), 1 = external ("varied").
camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}


def get_camera_type(cam_id):
    """Return the camera type string for `cam_id`, or None if the id is unknown."""
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]


class MP4Reader:
    """Sequential frame reader for a single camera's MP4 recording."""

    def __init__(self, filepath, serial_number):
        self.serial_number = serial_number
        self._index = 0

        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 File")

    def set_reading_parameters(
        self,
        image=True,
        concatenate_images=False,
        resolution=(0, 0),
        resize_func=None,
    ):
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = resize_func if resize_func is not None else cv2.resize
        self.skip_reading = not image

    def get_frame_resolution(self):
        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        # Seeking backwards resets the decoder position; forward seeks read and
        # discard intermediate frames so decoding stays exact.
        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):
        if self.skip_reading:
            return {}

        # Read the next frame; `success` turns False once the stream is exhausted.
        success, frame = self._mp4_reader.read()
        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        data_dict = {}
        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            # Stereo recordings store the left and right views side by side in one frame.
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()
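
# A minimal usage sketch for MP4Reader (the path and serial number are illustrative):
#
#   reader = MP4Reader("recordings/MP4/12345678.mp4", serial_number="12345678")
#   reader.set_reading_parameters(image=True, resolution=(320, 180))
#   frame_dict = reader.read_camera()  # {"image": {"12345678": <180x320x3 BGR array>}}
#   reader.disable_camera()

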
class RecordedMultiCameraWrapper:
    """Wraps one MP4Reader per recorded camera and reads them in lockstep."""

    def __init__(self, recording_folderpath, camera_kwargs={}):
        self.camera_kwargs = camera_kwargs

        # Open one reader per recording; the filename encodes the camera serial number.
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")

        self.camera_dict = {}
        for f in mp4_filepaths:
            serial_number = Path(f).stem
            self.camera_dict[serial_number] = MP4Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):
        full_obs_dict = defaultdict(dict)

        for cam_id in list(self.camera_dict.keys()):
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type for {cam_id} not found in camera_type_dict")
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # If any camera fails to produce a frame, treat the whole timestep as failed.
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict
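
# A minimal usage sketch (the path and serial number are illustrative):
#
#   cameras = RecordedMultiCameraWrapper("/path/to/episode/recordings/MP4")
#   obs = cameras.read_cameras(index=0, camera_type_dict={"12345678": "varied_camera"})
#   # obs["image"] then maps each camera serial number to its frame at timestep 0.

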
def get_hdf5_length(hdf5_file, keys_to_ignore=[]):
    """Return the number of timesteps, asserting all (non-ignored) datasets agree on it."""
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError(f"Unsupported HDF5 node type: {type(curr_data)}")

        if length is None:
            length = curr_length
        assert curr_length == length

    return length
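
# For example, a file containing datasets "action/joint_velocity" of shape (T, 7) and
# "observation/timestamp" of shape (T,) yields get_hdf5_length(file) == T.

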
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):
    """Recursively load a single timestep from an HDF5 file into a nested dict."""
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError(f"Unsupported HDF5 node type: {type(curr_data)}")

    return data_dict
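
# For example, load_hdf5_to_dict(file, 3) mirrors the file's group structure,
# e.g. {"action": {"joint_velocity": ...}, "observation": {...}}, with every
# leaf dataset sliced at timestep 3.

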
class TrajectoryReader:
    """Sequential reader for the low-dimensional data in a trajectory.h5 file."""

    def __init__(self, filepath, read_images=True):
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):
        # Resolve the timestep index; random access is only supported when not reading images.
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Load the low-dimensional data; videos are read separately from the MP4 files.
        keys_to_ignore = [*keys_to_ignore, "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        self._index += 1
        return timestep

    def close(self):
        self._hdf5_file.close()
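
# A minimal usage sketch (the path is illustrative):
#
#   reader = TrajectoryReader("/path/to/episode/trajectory.h5")
#   for i in range(reader.length()):
#       step = reader.read_timestep(index=i)
#   reader.close()

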
def load_trajectory(
    filepath=None,
    read_cameras=True,
    recording_folderpath=None,
    camera_kwargs={},
    remove_skipped_steps=False,
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    """Load a DROID trajectory from HDF5, optionally joined with its MP4 camera recordings."""
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Decide which timesteps to load: all of them, or a random subsample.
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            # Oversample to compensate for the steps filtered out below.
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    for i in indices_to_save:
        # Read the low-dimensional data for this timestep.
        timestep = traj_reader.read_timestep(index=i)

        # Join it with the decoded camera frames.
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            # Stop early if any camera stream ran out of frames.
            if camera_obs is None:
                break
            timestep["observation"].update(camera_obs)

        # Optionally drop steps where the teleoperator had movement disabled.
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        if not (step_skipped and remove_skipped_steps):
            timestep_list.append(timestep)

    # If we oversampled above, subsample back down to the requested number of steps.
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    traj_reader.close()

    return timestep_list
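
# A standalone usage sketch (paths are illustrative):
#
#   steps = load_trajectory(
#       "/path/to/episode/trajectory.h5",
#       recording_folderpath="/path/to/episode/recordings/MP4",
#   )
#   print(len(steps), steps[0]["observation"]["image"].keys())

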
if __name__ == "__main__":
    tyro.cli(main)