# VLAdaptorBench/external/rlbench/tools/cinematic_recorder.py
# lsnu's picture
# Add files using upload-large-folder tool
# a32fcea verified
import os
from typing import Type
import numpy as np
from absl import app
from absl import flags
from pyrep.objects.dummy import Dummy
from pyrep.objects.vision_sensor import VisionSensor
from rlbench import Environment
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.observation import Observation
from rlbench.backend.task import TASKS_PATH
from rlbench.backend.task import Task
from rlbench.backend.utils import task_file_to_task_class
from rlbench.observation_config import ObservationConfig
from rlbench.sim2real.domain_randomization import RandomizeEvery, \
VisualRandomizationConfig
# Command-line flags (absl). Values are read through FLAGS once app.run()
# has parsed the command line.
FLAGS = flags.FLAGS

# Output location and task selection.
flags.DEFINE_string(
    'save_dir',
    '/tmp/rlbench_videos/',
    'Where to save the generated videos.',
)
flags.DEFINE_list(
    'tasks',
    [],
    'The tasks to record. If empty, all tasks are recorded.',
)
flags.DEFINE_boolean(
    'individual',
    True,
    'One long clip of all the tasks, or individual videos.',
)

# Optional visual domain randomization.
flags.DEFINE_boolean(
    'domain_randomization',
    False,
    'If domain randomization should be applied.',
)
flags.DEFINE_string(
    'textures_path',
    '',
    'Where to locate textures if using domain randomization.',
)

# Simulator / camera settings.
flags.DEFINE_boolean('headless', True, 'Run in headless mode.')
flags.DEFINE_list(
    'camera_resolution',
    [1280, 720],
    'The camera resolution',
)
class CameraMotion(object):
    """Base class for a motion applied to a camera while recording.

    Subclasses implement ``step()``; ``TaskRecorder.take_snap`` in this file
    calls it once before every captured frame. The camera pose can be
    snapshotted with ``save_pose()`` and later put back with
    ``restore_pose()``.
    """

    def __init__(self, cam: VisionSensor):
        """Store the camera that this motion drives.

        Args:
            cam: The vision sensor being moved; exposed as ``self.cam``.
        """
        self.cam = cam
        # No pose has been saved yet; restore_pose() checks for this.
        self._prev_pose = None

    def step(self):
        """Advance the motion by one step. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def save_pose(self):
        """Remember the camera's current pose for a later restore_pose()."""
        self._prev_pose = self.cam.get_pose()

    def restore_pose(self):
        """Set the camera back to the pose recorded by save_pose().

        Raises:
            RuntimeError: If save_pose() has not been called yet.
        """
        if self._prev_pose is None:
            # Previously this surfaced as an opaque AttributeError.
            raise RuntimeError(
                'restore_pose() called before save_pose(); no pose to restore.')
        self.cam.set_pose(self._prev_pose)
class CircleCameraMotion(CameraMotion):
    """Camera motion driven by rotating a separate ``origin`` object.

    Every step calls ``origin.rotate([0, 0, speed])``; the camera itself is
    never moved directly by this class, so the visible effect depends on how
    the camera is attached to ``origin`` in the scene.
    """

    def __init__(self, cam: VisionSensor, origin: Dummy,
                 speed: float, init_rotation: float = np.deg2rad(0)):
        """Set up the motion and apply the initial rotation to ``origin``.

        Args:
            cam: The vision sensor handed to the base class.
            origin: The object that is rotated on every step.
            speed: Rotation applied per step, in radians.
            init_rotation: Rotation applied once, at construction time.
        """
        super().__init__(cam)
        self.origin = origin
        self.speed = speed  # in radians
        self._spin(init_rotation)

    def step(self):
        """Rotate ``origin`` by one ``speed`` increment."""
        self._spin(self.speed)

    def _spin(self, angle):
        """Rotate ``origin`` by ``angle`` in the third rotation component."""
        self.origin.rotate([0, 0, angle])
class TaskRecorder(object):
    """Records frames from a moving camera during task demos and saves video.

    Frames of the demo currently being collected are buffered separately
    from the frames of finished demos, so a failed attempt can be thrown
    away without touching earlier footage.
    """

    def __init__(self, env: Environment, cam_motion: CameraMotion, fps=30):
        """Create a recorder.

        Args:
            env: Environment used to obtain task instances via ``get_task``.
            cam_motion: Camera motion; its ``cam`` attribute is the sensor
                that frames are captured from.
            fps: Frame rate written to the output video.
        """
        self._env = env
        self._cam_motion = cam_motion
        self._fps = fps
        self._snaps = []  # Frames of successfully recorded demos.
        self._current_snaps = []  # Frames of the attempt in progress.

    def take_snap(self, obs: Observation):
        """Step the camera motion and capture one frame.

        Passed to ``get_demos`` as ``callable_each_step``.

        Args:
            obs: The observation supplied by the caller; unused here.
        """
        self._cam_motion.step()
        frame = self._cam_motion.cam.capture_rgb()
        # Scale by 255 and truncate to 8-bit for the video writer.
        self._current_snaps.append((frame * 255.).astype(np.uint8))

    def record_task(self, task: Type[Task], max_attempts=None):
        """Record one live demo of ``task``, retrying on ``RuntimeError``.

        After every failed attempt the camera pose is restored and the
        frames captured during that attempt are discarded.

        Args:
            task: The task class to record.
            max_attempts: Maximum number of demo attempts. ``None`` (the
                default) retries until a demo succeeds. At least one attempt
                is always made.

        Returns:
            True if a demo was recorded, False if ``max_attempts`` was
            exhausted without success.
        """
        task_env = self._env.get_task(task)
        self._cam_motion.save_pose()
        attempts = 0
        while True:
            attempts += 1
            try:
                task_env.get_demos(
                    1, live_demos=True, callable_each_step=self.take_snap,
                    max_attempts=1)
            except RuntimeError:
                self._cam_motion.restore_pose()
                self._current_snaps = []
                if max_attempts is not None and attempts >= max_attempts:
                    return False
            else:
                break
        self._snaps.extend(self._current_snaps)
        self._current_snaps = []
        return True

    def save(self, path):
        """Write all recorded frames to a video at ``path`` and clear them.

        Frames are encoded with the 'mp4v' FourCC at the recorder's fps and
        the camera's resolution. The parent directory is created if needed.

        Args:
            path: Output video file path.
        """
        print('Converting to video ...')
        out_dir = os.path.dirname(path)
        if out_dir:
            # os.makedirs('') raises, so skip it for a bare file name.
            os.makedirs(out_dir, exist_ok=True)
        # OpenCV QT version can conflict with PyRep, so import here
        import cv2
        video = cv2.VideoWriter(
            path, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), self._fps,
            tuple(self._cam_motion.cam.get_resolution()))
        try:
            for image in self._snaps:
                video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        finally:
            # Always release the writer, even if a frame fails to convert.
            video.release()
        self._snaps = []
def main(argv):
    """Record videos of RLBench task demos as configured by FLAGS.

    Launches the environment, attaches a new vision sensor to the
    'cam_cinematic_placeholder' dummy, records one live demo per task and
    writes either one video per task or a single combined video into
    ``FLAGS.save_dir``. The environment is shut down even if recording fails.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.

    Raises:
        ValueError: If --camera_resolution is not exactly two integers.
    """
    del argv  # Unused.

    # List flags given on the command line arrive as strings, while the
    # default holds ints; normalise both to ints before creating the sensor.
    resolution = [int(value) for value in FLAGS.camera_resolution]
    if len(resolution) != 2:
        raise ValueError(
            'camera_resolution must be WIDTH,HEIGHT but got: %r'
            % (FLAGS.camera_resolution,))

    obs_config = ObservationConfig(record_gripper_closing=True)
    obs_config.set_all(False)

    vrc = rand_every = None
    frequency = 0
    if FLAGS.domain_randomization:
        vrc = VisualRandomizationConfig(FLAGS.textures_path)
        rand_every = RandomizeEvery.TRANSITION
        frequency = 10

    action_mode = MoveArmThenGripper(
        arm_action_mode=JointVelocity(), gripper_action_mode=Discrete())
    env = Environment(action_mode, obs_config=obs_config,
                      randomize_every=rand_every, frequency=frequency,
                      visual_randomization_config=vrc, headless=FLAGS.headless)
    env.launch()
    try:
        # Add the camera to the scene
        cam_placeholder = Dummy('cam_cinematic_placeholder')
        cam = VisionSensor.create(resolution)
        cam.set_pose(cam_placeholder.get_pose())
        cam.set_parent(cam_placeholder)

        cam_motion = CircleCameraMotion(
            cam, Dummy('cam_cinematic_base'), 0.005)
        tr = TaskRecorder(env, cam_motion, fps=30)

        if FLAGS.tasks:
            task_names = FLAGS.tasks
        else:
            # Every task module in TASKS_PATH, in a deterministic order.
            task_names = sorted(
                os.path.splitext(task_file)[0]
                for task_file in os.listdir(TASKS_PATH)
                if task_file != '__init__.py' and task_file.endswith('.py'))

        # Resolve all classes up front so an unknown task name fails before
        # any recording starts.
        task_classes = [task_file_to_task_class(name) for name in task_names]

        for name, cls in zip(task_names, task_classes):
            good = tr.record_task(cls)
            if FLAGS.individual and good:
                # NOTE(review): '.avi' here vs. the 'mp4v' FourCC used in
                # TaskRecorder.save and '.mp4' below -- confirm intended.
                tr.save(os.path.join(FLAGS.save_dir, '%s.avi' % name))

        if not FLAGS.individual:
            tr.save(os.path.join(FLAGS.save_dir, 'recorded_tasks.mp4'))
    finally:
        # Shut the simulator down even when recording raised.
        env.shutdown()
# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == '__main__':
    app.run(main)