# VLAdaptorBench/external/rlbench/tools/dataset_generator.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit a32fcea (verified).
from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task
import os
import pickle
from PIL import Image
from rlbench.backend import utils
from rlbench.backend.const import *
import numpy as np
from absl import app
from absl import flags
FLAGS = flags.FLAGS

# Command-line configuration for dataset collection.
flags.DEFINE_string('save_path',
                    '/tmp/rlbench_data/',
                    'Where to save the demos.')
flags.DEFINE_list('tasks', [],
                  'The tasks to collect. If empty, all tasks are collected.')
# List values may arrive as strings from the command line; the workers
# convert each entry with int() before using it as a camera image size.
flags.DEFINE_list('image_size', [128, 128],
                  'The size of the images to save.')
flags.DEFINE_enum('renderer', 'opengl3', ['opengl', 'opengl3'],
                  'The renderer to use. opengl does not include shadows, '
                  'but is faster.')
flags.DEFINE_integer('processes', 1,
                     'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
                     'The number of episodes to collect per task.')
# NOTE: only read by run(); the all_variations path samples from every
# variation of a task regardless of this value.
flags.DEFINE_integer('variations', -1,
                     'Number of variations to collect per task. -1 for all.')
flags.DEFINE_bool('all_variations', True,
                  'Include all variations when sampling episodes')
def check_and_make(dir):
    """Create directory ``dir`` (including missing parents) if needed.

    Uses ``exist_ok=True`` rather than a separate existence check. Several
    collection processes may try to create the same folder at the same
    time. With check-then-create, a process could see the folder missing
    and then fail with ``FileExistsError`` because another process created
    it in between.

    Raises:
        FileExistsError: If ``dir`` exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)
def save_demo(demo, example_path, variation):
    """Write one demonstration episode under ``example_path``.

    For every step ``i`` of ``demo`` and each of the five cameras (left
    shoulder, right shoulder, overhead, wrist, front), three images are
    saved as ``IMAGE_FORMAT % i`` into that camera's folders:

    * the RGB frame, unchanged;
    * the depth map, converted with ``utils.float_array_to_rgb_image``
      using ``scale_factor=DEPTH_SCALE``;
    * the mask, multiplied by 255 and cast to ``uint8``.

    Afterwards, each observation's per-camera rgb/depth/point-cloud/mask
    attributes are set to ``None``. The observations are mutated in place.
    The remaining low-dimensional data in ``demo`` is then pickled to
    ``LOW_DIM_PICKLE``, and ``variation`` is pickled to
    ``VARIATION_NUMBER``.

    Args:
        demo: Sequence of observations to save. It is iterated for the
            images and then pickled as a whole.
        example_path: Episode directory. Per-camera sub-folders are created
            if missing.
        variation: Variation index stored next to the demo.
    """
    # (observation attribute prefix, rgb folder, depth folder, mask folder)
    cameras = (
        ('left_shoulder', LEFT_SHOULDER_RGB_FOLDER,
         LEFT_SHOULDER_DEPTH_FOLDER, LEFT_SHOULDER_MASK_FOLDER),
        ('right_shoulder', RIGHT_SHOULDER_RGB_FOLDER,
         RIGHT_SHOULDER_DEPTH_FOLDER, RIGHT_SHOULDER_MASK_FOLDER),
        ('overhead', OVERHEAD_RGB_FOLDER,
         OVERHEAD_DEPTH_FOLDER, OVERHEAD_MASK_FOLDER),
        ('wrist', WRIST_RGB_FOLDER,
         WRIST_DEPTH_FOLDER, WRIST_MASK_FOLDER),
        ('front', FRONT_RGB_FOLDER,
         FRONT_DEPTH_FOLDER, FRONT_MASK_FOLDER),
    )

    # Resolve and create every output folder before writing any image.
    camera_dirs = []
    for prefix, rgb_folder, depth_folder, mask_folder in cameras:
        rgb_dir = os.path.join(example_path, rgb_folder)
        depth_dir = os.path.join(example_path, depth_folder)
        mask_dir = os.path.join(example_path, mask_folder)
        for folder in (rgb_dir, depth_dir, mask_dir):
            check_and_make(folder)
        camera_dirs.append((prefix, rgb_dir, depth_dir, mask_dir))

    for i, obs in enumerate(demo):
        filename = IMAGE_FORMAT % i
        for prefix, rgb_dir, depth_dir, mask_dir in camera_dirs:
            rgb = Image.fromarray(getattr(obs, prefix + '_rgb'))
            depth = utils.float_array_to_rgb_image(
                getattr(obs, prefix + '_depth'), scale_factor=DEPTH_SCALE)
            mask = Image.fromarray(
                (getattr(obs, prefix + '_mask') * 255).astype(np.uint8))
            rgb.save(os.path.join(rgb_dir, filename))
            depth.save(os.path.join(depth_dir, filename))
            mask.save(os.path.join(mask_dir, filename))

        # The images are stored as separate files, so drop the image data
        # from the observation to keep the low-dim pickle small.
        for prefix, *_ in camera_dirs:
            for suffix in ('_rgb', '_depth', '_point_cloud', '_mask'):
                setattr(obs, prefix + suffix, None)

    # Save the low-dimension data
    with open(os.path.join(example_path, LOW_DIM_PICKLE), 'wb') as f:
        pickle.dump(demo, f)
    with open(os.path.join(example_path, VARIATION_NUMBER), 'wb') as f:
        pickle.dump(variation, f)
def _claim_next_variation(lock, task_index, variation_count, tasks,
                          rlbench_env):
    """Reserve the next (task class, variation index) under ``lock``.

    Reads and advances the shared ``task_index`` / ``variation_count``
    counters while holding ``lock``, so two workers never claim the same
    variation. A task is finished once ``var_target`` variations have been
    claimed. ``var_target`` is the task's ``variation_count()``, capped at
    ``FLAGS.variations`` when that flag is non-negative. The counters then
    roll over to variation 0 of the next task.

    Returns:
        ``(task_class, variation_index)``, or ``None`` when no tasks remain.
    """
    num_tasks = len(tasks)
    with lock:
        if task_index.value >= num_tasks:
            return None
        my_variation_count = variation_count.value
        task_env = rlbench_env.get_task(tasks[task_index.value])
        var_target = task_env.variation_count()
        if FLAGS.variations >= 0:
            var_target = np.minimum(FLAGS.variations, var_target)
        if my_variation_count >= var_target:
            # If we have reached the required number of variations for this
            # task, then move on to the next task.
            variation_count.value = my_variation_count = 0
            task_index.value += 1
        variation_count.value += 1
        if task_index.value >= num_tasks:
            return None
        return tasks[task_index.value], my_variation_count


def run(i, lock, task_index, variation_count, results, file_lock, tasks):
    """Worker loop: claim task variations and record demos for each one.

    Launches a headless RLBench environment, then repeatedly claims a
    (task, variation) pair from the shared counters. For each claim it
    records ``FLAGS.episodes_per_task`` demos into
    ``FLAGS.save_path/<task name>/(VARIATIONS_FOLDER % variation)``. Each
    demo is attempted up to 10 times. If all attempts fail, the rest of
    that variation is skipped and the failure is appended to the report.

    Args:
        i: Worker index, used in log lines and as the key into ``results``.
        lock: Lock guarding ``task_index`` and ``variation_count``.
        task_index: Shared value holding the index of the current task.
        variation_count: Shared value holding the next variation to claim.
        results: Shared dict. ``results[i]`` receives the accumulated problem
            report (empty string if nothing failed).
        file_lock: Lock held while writing an episode to disk.
        tasks: Task classes to collect.
    """
    # Initialise each worker with a fresh random seed.
    np.random.seed(None)
    img_size = list(map(int, FLAGS.image_size))
    obs_config = ObservationConfig()
    obs_config.set_all(True)
    for camera in (obs_config.right_shoulder_camera,
                   obs_config.left_shoulder_camera,
                   obs_config.overhead_camera,
                   obs_config.wrist_camera,
                   obs_config.front_camera):
        camera.image_size = img_size
        # Store depth as 0 - 1
        camera.depth_in_meters = False
        # We want to save the masks as rgb encodings.
        camera.masks_as_one_channel = False
        if FLAGS.renderer == 'opengl':
            camera.render_mode = RenderMode.OPENGL

    rlbench_env = Environment(
        action_mode=MoveArmThenGripper(JointVelocity(), Discrete()),
        obs_config=obs_config,
        headless=True)
    rlbench_env.launch()

    tasks_with_problems = results[i] = ''
    try:
        while True:
            claim = _claim_next_variation(
                lock, task_index, variation_count, tasks, rlbench_env)
            if claim is None:
                print('Process', i, 'finished')
                break
            t, my_variation_count = claim

            # Get the env for the task that was just claimed. Previously the
            # env left over from the last task was reused here, so the first
            # variation after a task switch went into the previous task's
            # folder.
            task_env = rlbench_env.get_task(t)
            variation_path = os.path.join(
                FLAGS.save_path, task_env.get_name(),
                VARIATIONS_FOLDER % my_variation_count)
            check_and_make(variation_path)
            episodes_path = os.path.join(variation_path, EPISODES_FOLDER)
            check_and_make(episodes_path)

            abort_variation = False
            for ex_idx in range(FLAGS.episodes_per_task):
                print('Process', i, '// Task:', task_env.get_name(),
                      '// Variation:', my_variation_count, '// Demo:', ex_idx)
                attempts = 10
                while attempts > 0:
                    try:
                        task_env = rlbench_env.get_task(t)
                        task_env.set_variation(my_variation_count)
                        descriptions, _ = task_env.reset()
                        # TODO: for now we do the explicit looping.
                        demo, = task_env.get_demos(
                            amount=1,
                            live_demos=True)
                    except Exception as e:
                        # Collection is best-effort: retry, then report and
                        # skip the rest of this variation.
                        attempts -= 1
                        if attempts > 0:
                            continue
                        problem = (
                            'Process %d failed collecting task %s (variation: %d, '
                            'example: %d). Skipping this task/variation.\n%s\n' % (
                                i, task_env.get_name(), my_variation_count, ex_idx,
                                str(e))
                        )
                        print(problem)
                        tasks_with_problems += problem
                        abort_variation = True
                        break
                    episode_path = os.path.join(
                        episodes_path, EPISODE_FOLDER % ex_idx)
                    with file_lock:
                        save_demo(demo, episode_path, my_variation_count)
                        with open(os.path.join(
                                episode_path, VARIATION_DESCRIPTIONS), 'wb') as f:
                            pickle.dump(descriptions, f)
                    break
                if abort_variation:
                    break
    finally:
        # Record the report and release the simulator even if an
        # unexpected error escapes the loop.
        results[i] = tasks_with_problems
        rlbench_env.shutdown()
def run_all_variations(i, lock, task_index, variation_count, results,
                       file_lock, tasks):
    """Collect demos for each task, drawing a random variation per episode.

    Works through ``tasks`` in order. For each task it records
    ``FLAGS.episodes_per_task`` episodes into
    ``FLAGS.save_path/<task name>/VARIATIONS_ALL_FOLDER/EPISODES_FOLDER``.
    Each episode uses a variation drawn with
    ``np.random.randint(task_env.variation_count())``, and the chosen index
    is saved with that episode. ``FLAGS.variations`` is not consulted.
    Each episode is attempted up to 10 times. If all attempts fail, the
    remaining episodes of that task are skipped and the failure is appended
    to ``results[i]``.

    ``lock`` and ``variation_count`` are accepted so the signature matches
    ``run``, but they are unused. ``main`` calls this function from a single
    process only, so ``task_index`` is advanced without locking.

    Args:
        i: Worker index, used in log lines and as the key into ``results``.
        lock: Unused (see above).
        task_index: Shared value holding the index of the current task.
        variation_count: Unused (see above).
        results: Shared dict. ``results[i]`` receives the problem report.
        file_lock: Lock held while writing an episode to disk.
        tasks: Task classes to collect.
    """
    # Initialise with a fresh random seed.
    np.random.seed(None)
    num_tasks = len(tasks)
    img_size = list(map(int, FLAGS.image_size))
    obs_config = ObservationConfig()
    obs_config.set_all(True)
    for camera in (obs_config.right_shoulder_camera,
                   obs_config.left_shoulder_camera,
                   obs_config.overhead_camera,
                   obs_config.wrist_camera,
                   obs_config.front_camera):
        camera.image_size = img_size
        # Store depth as 0 - 1
        camera.depth_in_meters = False
        # We want to save the masks as rgb encodings.
        camera.masks_as_one_channel = False
        if FLAGS.renderer == 'opengl':
            camera.render_mode = RenderMode.OPENGL

    rlbench_env = Environment(
        action_mode=MoveArmThenGripper(JointVelocity(), Discrete()),
        obs_config=obs_config,
        headless=True)
    rlbench_env.launch()

    tasks_with_problems = results[i] = ''
    try:
        while task_index.value < num_tasks:
            t = tasks[task_index.value]
            task_env = rlbench_env.get_task(t)
            possible_variations = task_env.variation_count()
            variation_path = os.path.join(
                FLAGS.save_path, task_env.get_name(),
                VARIATIONS_ALL_FOLDER)
            check_and_make(variation_path)
            episodes_path = os.path.join(variation_path, EPISODES_FOLDER)
            check_and_make(episodes_path)

            abort_variation = False
            for ex_idx in range(FLAGS.episodes_per_task):
                attempts = 10
                while attempts > 0:
                    try:
                        variation = np.random.randint(possible_variations)
                        task_env = rlbench_env.get_task(t)
                        task_env.set_variation(variation)
                        descriptions, _ = task_env.reset()
                        print('Process', i, '// Task:', task_env.get_name(),
                              '// Variation:', variation, '// Demo:', ex_idx)
                        # TODO: for now we do the explicit looping.
                        demo, = task_env.get_demos(
                            amount=1,
                            live_demos=True)
                    except Exception as e:
                        # Collection is best-effort: retry, then report and
                        # skip the rest of this task.
                        attempts -= 1
                        if attempts > 0:
                            continue
                        problem = (
                            'Process %d failed collecting task %s (variation: %d, '
                            'example: %d). Skipping this task/variation.\n%s\n' % (
                                i, task_env.get_name(), variation, ex_idx,
                                str(e))
                        )
                        print(problem)
                        tasks_with_problems += problem
                        abort_variation = True
                        break
                    episode_path = os.path.join(
                        episodes_path, EPISODE_FOLDER % ex_idx)
                    with file_lock:
                        save_demo(demo, episode_path, variation)
                        with open(os.path.join(
                                episode_path, VARIATION_DESCRIPTIONS), 'wb') as f:
                            pickle.dump(descriptions, f)
                    break
                if abort_variation:
                    break
            task_index.value += 1
        print('Process', i, 'finished')
    finally:
        # Record the report and release the simulator even on error.
        results[i] = tasks_with_problems
        rlbench_env.shutdown()
def main(argv):
    """Collect demonstrations for the selected RLBench tasks.

    Uses the names in ``FLAGS.tasks``, each checked against the ``.py``
    files in ``task.TASKS_PATH``. If that flag is empty, every task file
    found there is used. With ``FLAGS.all_variations``, collection runs in
    this process through ``run_all_variations`` as worker 0. Otherwise
    ``FLAGS.processes`` worker processes run ``run`` and share the
    task/variation counters. Each worker's problem report is printed at the
    end.

    Raises:
        ValueError: If a requested task has no matching task file.
    """
    task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
                  if t != '__init__.py' and t.endswith('.py')]
    if FLAGS.tasks:
        for t in FLAGS.tasks:
            if t not in task_files:
                raise ValueError('Task %s not recognised!.' % t)
        task_files = FLAGS.tasks
    tasks = [task_file_to_task_class(t) for t in task_files]

    manager = Manager()
    result_dict = manager.dict()
    file_lock = manager.Lock()
    task_index = manager.Value('i', 0)
    variation_count = manager.Value('i', 0)
    lock = manager.Lock()

    check_and_make(FLAGS.save_path)

    if FLAGS.all_variations:
        # multiprocessing for all_variations is not supported (for now), so
        # only worker 0 runs, in this process.
        num_workers = 1
        run_all_variations(0, lock, task_index, variation_count, result_dict,
                           file_lock, tasks)
    else:
        num_workers = FLAGS.processes
        processes = [Process(
            target=run, args=(
                i, lock, task_index, variation_count, result_dict, file_lock,
                tasks))
            for i in range(num_workers)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

    print('Data collection done!')
    # Report only the workers that were started. Previously this looped
    # over FLAGS.processes, which raised KeyError in all_variations mode
    # with --processes > 1. It also raised if a worker died before writing
    # its entry.
    for i in range(num_workers):
        print(result_dict.get(i, 'Process %d reported no results.' % i))
if __name__ == '__main__':
    # Script entry point: hand main() to absl's application runner.
    app.run(main)