| import gradio as gr |
| import sys |
| import site |
| from PIL import Image |
| from pathlib import Path |
| from omegaconf import DictConfig, OmegaConf |
| from tqdm import tqdm, trange |
| import random |
| import math |
| import hydra |
| import numpy as np |
| import glob |
| import os |
| import subprocess |
| import time |
| import cv2 |
| import copy |
| import yaml |
| import matplotlib.pyplot as plt |
| from sklearn.neighbors import NearestNeighbors |
|
|
# Start the `spaces.zero` runtime at import time, before torch is imported below.
# NOTE(review): presumably this ordering is required by the hosting runtime so
# that it can hook in before any CUDA library is loaded — confirm before moving.
import spaces
from spaces import zero
zero.startup()
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| import kornia |
|
|
| from diff_gaussian_rasterization import GaussianRasterizer |
| from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera |
|
|
# Extend the import path with the bundled `src` tree (highest priority) and with
# `src/experiments` (lowest priority) before importing the project packages.
# NOTE(review): presumably `real_world` lives under `src/experiments` — confirm.
import sys
sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.append(str(Path(__file__).parent / "src" / "experiments"))


from real_world.utils.render_utils import interpolate_motions
from real_world.gs.helpers import setup_camera
from real_world.gs.convert import save_to_splat, read_splat


# Base directory under which this module reads configs/checkpoints
# (`log/<task>/...`) and writes temporary render outputs (`log/gs/temp`, ...).
root = Path(__file__).parent / "src" / "experiments"
|
|
def make_video(
        image_root: Path,
        video_path: Path,
        image_pattern: str = '%04d.png',
        frame_rate: int = 10):
    """Encode a numbered image sequence into an H.264 video using ffmpeg.

    Args:
        image_root: Directory that contains the input frames.
        video_path: Destination file; an existing file is overwritten (``-y``).
        image_pattern: printf-style frame name pattern relative to ``image_root``.
        frame_rate: Input frame rate handed to ffmpeg.

    The ffmpeg exit status is not inspected, so a failed encode does not raise.
    """
    frames = str(image_root / image_pattern)

    command = ['ffmpeg', '-y', '-hide_banner']
    command += ['-loglevel', 'error']
    command += ['-framerate', str(frame_rate)]
    command += ['-i', frames]
    command += ['-c:v', 'libx264']
    command += ['-pix_fmt', 'yuv420p']
    command.append(str(video_path))

    subprocess.run(command)
|
|
def quat2mat(quat):
    """Return the rotation matrices for ``quat``.

    Thin wrapper around kornia's ``quaternion_to_rotation_matrix``; kornia is
    imported lazily inside the function.
    """
    from kornia.geometry import conversions

    rotation = conversions.quaternion_to_rotation_matrix(quat)
    return rotation
|
|
|
|
def mat2quat(mat):
    """Return the quaternions for the rotation matrices ``mat``.

    Thin wrapper around kornia's ``rotation_matrix_to_quaternion``; kornia is
    imported lazily inside the function.
    """
    from kornia.geometry import conversions

    quaternion = conversions.rotation_matrix_to_quaternion(mat)
    return quaternion
|
|
|
|
def fps(x, enabled, n, device, random_start=False):
    """Farthest-point-sample ``n`` indices from the enabled rows of ``x``.

    Args:
        x: Point tensor; rows are selected with the boolean mask ``enabled``.
        enabled: Boolean mask over the rows of ``x``.
        n: Number of points to sample.
        device: Unused; the result is moved to ``x.device``.
        random_start: When true, the first sample is drawn uniformly from the
            enabled rows with the ``random`` module; otherwise index 0 is used.

    Returns:
        Index tensor (into ``x[enabled]``) located on ``x.device``.
    """
    import torch
    from dgl.geometry import farthest_point_sampler

    # The summed consecutive differences of the mask must be 0 or -1.
    mask_delta = torch.diff(enabled * 1.0).sum()
    assert mask_delta in [0.0, -1.0]

    if random_start:
        start_idx = random.randint(0, enabled.sum() - 1)
    else:
        start_idx = 0

    candidates = x[enabled][None]
    sampled = farthest_point_sampler(candidates, n, start_idx=start_idx)
    return sampled[0].to(x.device)
|
|
|
|
| class DynamicsVisualizer: |
|
|
def __init__(self, wp_device='cuda', torch_device='cuda'):
    """Create the visualizer and configure it for the default task (``'rope'``).

    Args:
        wp_device: Device handed to the simulator/collider constructors.
        torch_device: Device used for torch tensors and the material model.

    Both devices are stored on the instance so that methods reading
    ``self.wp_device`` / ``self.torch_device`` work even before ``reset()``
    (which replaces them with concrete devices) has been called.
    """
    self.wp_device = wp_device
    self.torch_device = torch_device

    # task -> [log sub-directory, run name, checkpoint iteration,
    #          [start_episode, end_episode]]
    self.best_models = {
        'cloth': ['cloth', 'train', 100000, [610, 650]],
        'rope': ['rope', 'train', 100000, [651, 691]],
        'paperbag': ['paperbag', 'train', 100000, [200, 220]],
        'sloth': ['sloth', 'train', 100000, [113, 133]],
        'box': ['box', 'train', 100000, [306, 323]],
        'bread': ['bread', 'train', 100000, [143, 163]],
    }
    task_name = 'rope'
    self.init(task_name)
|
|
def init(self, task_name):
    """Load the hydra config of ``task_name`` and reset all cached state.

    Reads ``log/<dir>/<run>/hydra.yaml`` below ``root``, overrides a few
    simulation settings, seeds ``random``/``numpy``/``torch`` with ``cfg.seed``
    and finally calls ``clear()``.

    Raises:
        KeyError: If ``task_name`` is not a key of ``self.best_models``.
    """
    self.width, self.height = 640, 480
    self.task_name = task_name

    log_dir, run_name, iteration, episode_range = self.best_models[task_name]

    # NOTE: yaml.CLoader is a full (non-safe) loader; the file is a local,
    # trusted training artifact.
    with open(root / f'log/{log_dir}/{run_name}/hydra.yaml', 'r') as f:
        raw_config = yaml.load(f, Loader=yaml.CLoader)
    cfg = OmegaConf.create(raw_config)

    cfg.iteration = iteration
    cfg.start_episode = episode_range[0]
    cfg.end_episode = episode_range[1]
    cfg.sim.num_steps = 1000
    cfg.sim.gripper_forcing = False
    cfg.sim.uniform = True
    cfg.sim.use_pv = False

    self.cfg = cfg
    self.device = torch.device('cuda')

    # Neighbour counts used by knn_relations / knn_weights_brute.
    self.k_rel, self.k_wgt = 8, 16

    # Rendering switches.
    self.with_bg = True
    self.render_gripper = True
    self.render_direction = True
    self.verbose = False

    self.dt_base = cfg.sim.dt
    self.high_freq_pred = True

    seed = cfg.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    self.clear()
|
|
def clear(self, clear_params=True):
    """Drop all cached scene, simulation and model state.

    Args:
        clear_params: When true, ``self.params`` is also reset to ``None``;
            otherwise it is left untouched.
    """
    self.metadata, self.config = {}, {}
    if clear_params:
        self.params = None

    particle_keys = (
        'x', 'v', 'x_his', 'v_his', 'x_pred', 'v_pred', 'clip_bound', 'enabled',
    )
    gripper_keys = (
        'prev_key_pos', 'prev_key_pos_timestamp', 'sub_pos',
        'sub_pos_timestamps', 'gripper_radius',
    )
    # Every state entry starts out as None.
    self.state = dict.fromkeys(particle_keys + gripper_keys)

    for attr in (
        'preprocess_metadata', 'table_params', 'gripper_params',
        'sim', 'statics', 'colliders', 'material', 'friction',
    ):
        setattr(self, attr, None)
|
|
def load_scaniverse(self, data_path):
    """Load a scanned scene (object, table, gripper splats and end-effector).

    Reads ``object.splat``, ``table.splat``, ``gripper.splat`` and
    ``eef_xyz.txt`` from ``data_path`` and fills ``self.params``,
    ``self.table_params``, ``self.gripper_params``, ``self.preprocess_metadata``
    and the ``clip_bound`` / ``enabled`` / ``prev_key_pos`` / ``gripper_radius``
    entries of ``self.state``.
    """
    def _from_numpy_f32(array):
        return torch.from_numpy(array).to(torch.float32).to(self.device)

    def _tensor_f32(array):
        return torch.tensor(array).to(torch.float32).to(self.device)

    params_obj = read_splat(data_path / 'object.splat')
    params_table = read_splat(data_path / 'table.splat')
    params_robot = read_splat(data_path / 'gripper.splat')

    # Object gaussians are stored in their pre-activation parameterisation.
    pts, colors, scales, quats, opacities = params_obj
    self.params = {
        'means3D': _from_numpy_f32(pts),
        'rgb_colors': _from_numpy_f32(colors),
        'log_scales': torch.log(_from_numpy_f32(scales)),
        'unnorm_rotations': _from_numpy_f32(quats),
        'logit_opacities': torch.logit(_from_numpy_f32(opacities)),
    }

    # Table and gripper gaussians are kept as
    # (pts, colors, scales, quats, opacities) tuples.
    t_pts, t_colors, t_scales, t_quats, t_opacities = params_table
    table_tensors = tuple(
        _tensor_f32(a) for a in (t_pts, t_colors, t_scales, t_quats, t_opacities)
    )
    g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot
    gripper_tensors = tuple(
        _tensor_f32(a) for a in (g_pts, g_colors, g_scales, g_quats, g_opacities)
    )
    self.table_params = table_tensors
    self.gripper_params = gripper_tensors

    n_particles = self.cfg.sim.n_particles
    self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
    self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)

    cfg = self.cfg
    dx = cfg.sim.num_grids[-1]

    # Rotate the object points, then derive the translation that centres them
    # at 0.5 along the first and third axes and puts their lowest point at
    # dx * (clip_bound + 0.5) along the second axis.
    p_x = _tensor_f32(pts)
    axis_swap = [[1, 0, 0],
                 [0, 0, -1],
                 [0, 1, 0]]
    R = torch.tensor(axis_swap).to(p_x.device).to(p_x.dtype)
    scale = 1.0
    p_x_model = (p_x @ R.T) * scale

    global_translation = torch.tensor([
        0.5 - p_x_model[:, 0].mean(),
        dx * (cfg.model.clip_bound + 0.5) - p_x_model[:, 1].min(),
        0.5 - p_x_model[:, 2].mean(),
    ], dtype=p_x_model.dtype, device=p_x_model.device)

    R_viewer = torch.tensor(axis_swap).to(p_x.device).to(p_x.dtype)
    t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype)

    self.preprocess_metadata = {
        'R': R,
        'R_viewer': R_viewer,
        't_viewer': t_viewer,
        'scale': scale,
        'global_translation': global_translation,
    }

    # End-effector position: a single xyz row, mapped with the same transform.
    grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
    assert grippers.shape == (1, 3)
    grippers = torch.tensor(grippers).to(self.device).to(torch.float32)

    grippers[:, :3] = grippers[:, :3] @ R.T
    grippers[:, :3] = grippers[:, :3] * scale
    grippers[:, :3] += global_translation

    assert grippers.shape[0] == 1
    self.state['prev_key_pos'] = grippers[:, :3]
    self.state['gripper_radius'] = cfg.model.gripper_radius
| |
def load_eef(self, grippers=None, eef_t=None):
    """Register the initial end-effector position in the model frame.

    Args:
        grippers: Array-like with a single row whose first three columns are
            the end-effector xyz; ``None`` leaves the state untouched.
        eef_t: Offset added to the xyz columns before the transform.

    Requires ``self.state['prev_key_pos']`` to be unset and
    ``self.preprocess_metadata`` to be populated.
    """
    assert self.state['prev_key_pos'] is None

    if grippers is None:
        return

    grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
    eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32)
    grippers[:, :3] = grippers[:, :3] + eef_t

    # Same rotate -> scale -> translate chain that is applied to the particles.
    meta = self.preprocess_metadata
    R = meta['R']
    scale = meta['scale']
    global_translation = meta['global_translation']
    grippers[:, :3] = grippers[:, :3] @ R.T
    grippers[:, :3] = grippers[:, :3] * scale
    grippers[:, :3] += global_translation

    assert grippers.shape[0] == 1
    self.state['prev_key_pos'] = grippers[:, :3]
    self.state['gripper_radius'] = self.cfg.model.gripper_radius
|
|
def load_preprocess_metadata(self, p_x_orig):
    """Derive the data-to-model transform from a batch of point sets.

    Args:
        p_x_orig: Tensor indexed as ``[n, i, xyz]`` (it is contracted with a
            3x3 matrix over its last axis).

    Stores ``R``, ``R_viewer``, ``t_viewer``, ``scale`` and
    ``global_translation`` in ``self.preprocess_metadata``. The translation
    centres the rotated points at 0.5 along the first and third axes and puts
    their minimum at ``dx * (clip_bound + 0.5)`` along the second axis.
    """
    cfg = self.cfg
    dx = cfg.sim.num_grids[-1]

    points = p_x_orig.to(self.device)
    dev, dtype = points.device, points.dtype

    axis_swap = [[1, 0, 0],
                 [0, 0, -1],
                 [0, 1, 0]]
    R = torch.tensor(axis_swap).to(dev).to(dtype)
    R_viewer = torch.tensor(axis_swap).to(dev).to(dtype)
    t_viewer = torch.tensor([0, 0, 0]).to(dev).to(dtype)

    scale = 1.0
    transformed = torch.einsum('nij,jk->nik', points, R.T) * scale

    shift_x = 0.5 - transformed[:, :, 0].mean()
    shift_y = dx * (cfg.model.clip_bound + 0.5) - transformed[:, :, 1].min()
    shift_z = 0.5 - transformed[:, :, 2].mean()
    global_translation = torch.tensor(
        [shift_x, shift_y, shift_z],
        dtype=transformed.dtype,
        device=transformed.device,
    )

    self.preprocess_metadata = {
        'R': R,
        'R_viewer': R_viewer,
        't_viewer': t_viewer,
        'scale': scale,
        'global_translation': global_translation,
    }
|
|
| |
def render(self, render_data, cam_id, bg=None):
    """Rasterize ``render_data`` with the camera stored in ``self.metadata``.

    Args:
        render_data: Mapping of keyword arguments for the rasterizer; every
            value is moved to ``self.device`` first.
        cam_id: Unused; kept for interface compatibility.
        bg: Background colour passed to ``setup_camera``. ``None`` (the
            default) means ``[0.7, 0.7, 0.7]``; a fresh list is created per
            call so the default can never be shared or mutated across calls.

    Returns:
        ``(im, depth)`` — the first and third outputs of the rasterizer.
    """
    if bg is None:
        bg = [0.7, 0.7, 0.7]
    render_data = {k: v.to(self.device) for k, v in render_data.items()}
    w, h = self.metadata['w'], self.metadata['h']
    k, w2c = self.metadata['k'], self.metadata['w2c']
    cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
    im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data)
    return im, depth
| |
def knn_relations(self, bones):
    """Return, for every bone, the indices of its ``self.k_rel`` nearest bones.

    ``k_rel + 1`` neighbours are queried and the first column of the result is
    dropped (each bone is queried against the set it belongs to).

    Returns:
        NumPy integer array with ``k_rel`` columns, one row per bone.
    """
    bone_xyz = bones.detach().cpu().numpy()
    searcher = NearestNeighbors(n_neighbors=self.k_rel + 1, algorithm='kd_tree')
    searcher.fit(bone_xyz)
    _, neighbor_idx = searcher.kneighbors(bone_xyz)
    return neighbor_idx[:, 1:]
| |
def knn_weights_brute(self, bones, pts):
    """Inverse-distance weights from each point to its ``self.k_wgt`` nearest bones.

    Args:
        bones: ``(B, D)`` bone positions.
        pts: ``(P, D)`` query points.

    Returns:
        ``(P, B)`` tensor on ``pts.device``; each row holds normalized
        ``1 / (distance + 1e-6)`` weights at the selected bones and zeros
        everywhere else, so every row sums to one.
    """
    n_pts, n_bones = pts.shape[0], bones.shape[0]

    pairwise = torch.norm(pts[:, None] - bones, dim=-1)
    _, nearest = torch.topk(pairwise, self.k_wgt, dim=-1, largest=False)

    nearest_dist = torch.norm(bones[nearest] - pts[:, None], dim=-1)
    inv_dist = 1 / (nearest_dist + 1e-6)
    inv_dist = inv_dist / inv_dist.sum(dim=-1, keepdim=True)

    # Scatter the k weights of every point into its dense row.
    weights_all = torch.zeros((n_pts, n_bones), device=pts.device)
    row_idx = torch.arange(n_pts)[:, None]
    weights_all[row_idx, nearest] = inv_dist
    return weights_all
| |
def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
    """Store camera parameters for subsequent ``render`` calls.

    ``k`` and ``w2c`` always overwrite the stored values; ``w`` and ``h`` are
    only updated when given. ``near`` and ``far`` go to ``self.config``.
    """
    self.metadata.update(k=k, w2c=w2c)
    for key, value in (('w', w), ('h', h)):
        if value is not None:
            self.metadata[key] = value
    self.config.update(near=near, far=far)
| |
def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
    """Build the simulator, statics, colliders and load the material model.

    Args:
        batch_size: Batch dimension of the simulator and its buffers.
        num_steps: Number of simulation steps; also written to
            ``cfg.sim.num_steps``.
        num_particles: Number of particles per batch element.
        ckpt_path: Checkpoint to load. ``None`` (the default) selects
            ``log/<task_name>/train/ckpt/100000.pt`` below ``root``.
            Previously this argument was ignored and the default path was
            always used.

    Requires ``self.state['clip_bound']`` and ``self.state['enabled']`` to be
    set. Populates ``self.sim``, ``self.statics``, ``self.colliders``,
    ``self.material`` and ``self.friction``.
    """
    from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch
    from pgnd.material import PGNDModel

    self.cfg.sim.num_steps = num_steps
    cfg = self.cfg

    sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True)

    statics = StaticsBatch()
    statics.init(shape=(batch_size, num_particles), device=self.wp_device)
    statics.update_clip_bound(self.state['clip_bound'].detach().cpu())
    statics.update_enabled(self.state['enabled'][None].detach().cpu())
    colliders = CollidersBatch()
    colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device)

    self.sim = sim
    self.statics = statics
    self.colliders = colliders

    # Honour an explicit checkpoint; fall back to the bundled one otherwise.
    if ckpt_path is None:
        ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt'
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=self.torch_device)

    # Inference only: weights are frozen and the module is put in eval mode.
    material: nn.Module = PGNDModel(cfg)
    material.to(self.torch_device)
    material.load_state_dict(ckpt['material'])
    material.requires_grad_(False)
    material.eval()

    # Prefer the friction stored in the checkpoint over the config value.
    if 'friction' in ckpt:
        friction = ckpt['friction']['mu'].reshape(-1, 1)
    else:
        friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1)

    self.material = material
    self.friction = friction
|
|
def reload_model(self, num_steps):
    """Recreate ``self.sim`` for a new step count (batch size 1).

    ``cfg.sim.num_steps`` is updated first because the config object itself is
    handed to the simulator constructor.
    """
    from pgnd.sim import CacheDiffSimWithFrictionBatch as SimBatch

    self.cfg.sim.num_steps = num_steps
    batch_size = 1
    self.sim = SimBatch(self.cfg, num_steps, batch_size, self.wp_device, requires_grad=True)
|
|
| |
def step(self):
    """Advance the particle state by one model + simulator step.

    The previous prediction becomes the current state and the history buffers
    are rolled. If no gripper sub-positions are queued
    (``state['sub_pos'] is None``) the method returns right after that roll.
    Otherwise the gripper velocity is derived from the last queued
    sub-position, the material model and the simulator are evaluated once, and
    ``x_pred`` / ``v_pred`` / ``prev_key_pos`` are updated; ``sub_pos`` is
    cleared afterwards.
    """
    cfg = self.cfg
    state = self.state
    batch_size = 1
    num_steps = 1
    num_particles = cfg.sim.n_particles

    # Roll the history (drop oldest, append current) and promote the prediction.
    state['x_his'] = torch.cat([state['x_his'][1:], state['x'][None]], dim=0)
    state['v_his'] = torch.cat([state['v_his'][1:], state['v'][None]], dim=0)
    state['x'] = state['x_pred'].clone()
    state['v'] = state['v_pred'].clone()

    eef_xyz_key = state['prev_key_pos']
    eef_xyz_sub = state['sub_pos']
    if eef_xyz_sub is None:
        return

    # Gripper velocity from the last key position to the newest sub-position.
    eef_xyz_key_next = eef_xyz_sub[-1]
    eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt
    if self.verbose:
        print('delta_t:', np.round(cfg.sim.dt, 4))
        print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist())
        print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist())
        print('v:', eef_v.cpu().numpy().tolist())

    n_grippers = cfg.sim.num_grippers
    if n_grippers > 0:
        # Columns: 0:3 position, 3:6 velocity, 6:10 quaternion,
        # 10:13 angular velocity, 13 radius, 14 gripper value.
        dev = self.torch_device
        grippers = torch.zeros((batch_size, n_grippers, 15), device=dev)
        eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=dev).repeat(batch_size, n_grippers, 1)
        eef_quat_vel = torch.zeros((batch_size, n_grippers, 3), dtype=torch.float32, device=dev)
        eef_gripper = torch.zeros((batch_size, n_grippers), dtype=torch.float32, device=dev)
        grippers[:, :, :3] = eef_xyz_key
        grippers[:, :, 3:6] = eef_v
        grippers[:, :, 6:10] = eef_quat
        grippers[:, :, 10:13] = eef_quat_vel
        grippers[:, :, 13] = cfg.model.gripper_radius
        grippers[:, :, 14] = eef_gripper
        self.colliders.initialize_grippers(grippers)

    def _batched(tensor):
        return tensor.clone()[None].repeat(batch_size, 1, 1)

    def _flat_history(history):
        # (T, N, D) -> (N, T * D), then batched.
        per_particle = history.permute(1, 0, 2).clone()
        assert per_particle.shape[0] == num_particles
        return per_particle.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)

    x = _batched(state['x'])
    v = _batched(state['v'])
    x_his = _flat_history(state['x_his'])
    v_his = _flat_history(state['v_his'])
    enabled = state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)

    for t in range(num_steps):
        pred = self.material(x, v, x_his, v_his, enabled)
        x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)

    state['x_pred'] = x[0].clone()
    state['v_pred'] = v[0].clone()

    # The consumed sub-positions become the new key position.
    state['prev_key_pos'] = eef_xyz_key_next
    state['sub_pos'] = None
| |
|
|
def preprocess_x(self, p_x):
    """Map points from the viewer frame into the model frame.

    Computes ``(p_x - t_viewer) @ R_viewer`` with the values stored in
    ``self.preprocess_metadata``; the input tensor is not modified.
    Only the two metadata entries that are actually used are read.
    """
    meta = self.preprocess_metadata
    return (p_x - meta['t_viewer']) @ meta['R_viewer']
| |
def preprocess_gripper(self, grippers):
    """Rotate the xyz columns of ``grippers`` into the model frame, in place.

    Applies ``grippers[:, :3] @ R_viewer``; remaining columns are untouched.
    The input tensor is modified and also returned, so pass a clone if the
    original values are still needed.

    NOTE(review): unlike ``preprocess_x`` no ``t_viewer`` is subtracted here,
    while ``inverse_preprocess_gripper`` does add it — confirm this asymmetry
    is intended.
    """
    R_viewer = self.preprocess_metadata['R_viewer']
    grippers[:, :3] = grippers[:, :3] @ R_viewer
    return grippers
| |
def inverse_preprocess_x(self, p_x):
    """Map points from the model frame back into the viewer frame.

    Computes ``p_x @ R_viewer.T + t_viewer`` with the values stored in
    ``self.preprocess_metadata``; the input tensor is not modified.
    Only the two metadata entries that are actually used are read.
    """
    meta = self.preprocess_metadata
    return p_x @ meta['R_viewer'].T + meta['t_viewer']
| |
def inverse_preprocess_gripper(self, grippers):
    """Map the xyz columns of ``grippers`` to the viewer frame, in place.

    Applies ``grippers[:, :3] @ R_viewer.T + t_viewer``; remaining columns are
    untouched. The input tensor is modified and also returned, so pass a clone
    if the original values are still needed.
    """
    meta = self.preprocess_metadata
    grippers[:, :3] = grippers[:, :3] @ meta['R_viewer'].T + meta['t_viewer']
    return grippers
|
|
def rotate(self, params, rot_mat):
    """Return a copy of the gaussian ``params`` transformed by ``rot_mat``.

    The previous body referenced the undefined names ``pts`` and ``quats`` and
    therefore always raised ``NameError``. This version computes them:

    * ``means3D`` is multiplied by ``rot_mat`` as given.
    * ``unnorm_rotations`` is normalized, converted to matrices, left-multiplied
      by ``rot_mat`` with its rows scaled to unit length, and converted back.
    * colours, log-scales and logit-opacities are passed through unchanged,
      as in the original result dictionary.

    NOTE(review): the intended handling of a non-unit row scale in ``rot_mat``
    (e.g. whether ``log_scales`` should change too) cannot be determined from
    this file — confirm against the callers.

    Args:
        params: Dict with ``means3D``, ``rgb_colors``, ``log_scales``,
            ``unnorm_rotations`` and ``logit_opacities`` tensors.
        rot_mat: 3x3 matrix (array-like or tensor).

    Returns:
        A new dict with the same keys; ``params`` itself is not modified.
    """
    means = params['means3D']
    rot = torch.as_tensor(rot_mat, dtype=means.dtype, device=means.device)

    # Per-row norm; dividing by it leaves the pure rotation part.
    scale = torch.norm(rot, dim=1, keepdim=True)
    rot_unit = rot / scale

    pts = means @ rot.T
    unit_quats = torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)
    quats = mat2quat(rot_unit @ quat2mat(unit_quats))

    params = {
        'means3D': pts,
        'rgb_colors': params['rgb_colors'],
        'log_scales': params['log_scales'],
        'unnorm_rotations': quats,
        'logit_opacities': params['logit_opacities'],
    }
    return params
|
|
def preprocess_gs(self, params):
    """Move object gaussians into the viewer frame and recentre them.

    Args:
        params: Either a dict with ``means3D``, ``rgb_colors``,
            ``unnorm_rotations``, ``logit_opacities`` and ``log_scales``, or a
            tuple ``(xyz, rgb, quat, opa, scales)`` of activated values.

    Returns:
        The same kind of container. A dict is updated in place and returned;
        it keeps the pre-activation parameterisation (``log_scales`` are
        re-encoded with ``log`` and ``logit_opacities`` stay logits).
        Previously the sigmoid-activated opacities were written back under
        ``'logit_opacities'``, so consumers that apply ``sigmoid`` to that key
        (as ``reset_state`` does) activated them twice.

    Side effects:
        Overwrites ``self.preprocess_metadata['t_viewer']`` with the negated
        mean of the transformed points (third component set to zero) and
        prints a notice about it.
    """
    is_dict = isinstance(params, dict)
    if is_dict:
        xyz = params['means3D']
        rgb = params['rgb_colors']
        quat = torch.nn.functional.normalize(params['unnorm_rotations'])
        opa = torch.sigmoid(params['logit_opacities'])
        scales = torch.exp(params['log_scales'])
    else:
        assert isinstance(params, tuple)
        xyz, rgb, quat, opa, scales = params

    quat = torch.nn.functional.normalize(quat, dim=-1)

    R = self.preprocess_metadata['R']
    R_viewer = self.preprocess_metadata['R_viewer']
    scale = self.preprocess_metadata['scale']
    global_translation = self.preprocess_metadata['global_translation']

    # Data frame -> model frame: rotate, scale, translate.
    mat = quat2mat(quat)
    mat = R @ mat
    xyz = xyz @ R.T
    xyz = xyz * scale
    xyz += global_translation
    quat = mat2quat(mat)
    scales = scales * scale

    # Model frame -> viewer frame (rotation only at this point).
    xyz = xyz @ R_viewer.T
    quat = mat2quat(R_viewer @ quat2mat(quat))

    # Recentre: the viewer translation becomes the negated mean of the
    # points, with its third component forced to zero.
    t_viewer = -xyz.mean(dim=0)
    t_viewer[2] = 0
    xyz += t_viewer
    print('Overwriting t_viewer to be the planar mean of the object')
    self.preprocess_metadata['t_viewer'] = t_viewer

    if is_dict:
        params['means3D'] = xyz
        params['rgb_colors'] = rgb
        params['unnorm_rotations'] = quat
        # 'logit_opacities' is intentionally left as stored: the opacities
        # are not changed by this transform and the key must keep holding
        # logits, not activated values.
        params['log_scales'] = torch.log(scales)
    else:
        params = xyz, rgb, quat, opa, scales

    return params
| |
def preprocess_bg_gs(self):
    """Transform the table and gripper gaussians into the viewer frame.

    Reads ``self.table_params`` / ``self.gripper_params`` (tuples ordered
    ``pts, colors, scales, quats, opacities``), applies the transform stored
    in ``self.preprocess_metadata`` and writes both tuples back. In addition,
    three arrow-shaped strips of small gaussians (one per entry of ``dirs``,
    coloured by the matching entry of ``axes``) are appended to the table
    gaussians.

    Raises:
        NotImplementedError: For tasks other than ``'rope'`` and ``'sloth'``
            (no gripper offset is defined for them). Note that this is raised
            only at the end, after the table gaussians have been processed.
    """
    t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
    g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

    # Mask of gripper points within 0.04 below the maximum z (the maximum
    # itself is excluded by the strict comparison). Computed on the
    # untransformed points and applied after the transform below.
    g_pts_tip_z = g_pts[:, 2].max()
    g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z)

    R = self.preprocess_metadata['R']
    R_viewer = self.preprocess_metadata['R_viewer']
    t_viewer = self.preprocess_metadata['t_viewer']
    scale = self.preprocess_metadata['scale']
    global_translation = self.preprocess_metadata['global_translation']

    # Table: rotate, scale, translate (same chain as preprocess_gs) ...
    t_mat = quat2mat(t_quats)
    t_mat = R @ t_mat
    t_pts = t_pts @ R.T
    t_pts = t_pts * scale
    t_pts += global_translation
    t_quats = mat2quat(t_mat)
    t_scales = t_scales * scale

    # ... then into the viewer frame.
    t_pts = t_pts @ R_viewer.T
    t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
    t_pts += t_viewer

    # One arrow per direction in `dirs`; `axes[ee]` is used as its RGB colour.
    axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
    dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
    for ee in range(3):
        gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
        gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10)

        R = self.preprocess_metadata['R']

        direction = gripper_direction @ R.T

        # N points for the shaft plus N // kk points for each of the two
        # head strokes.
        n_grippers = 1
        N = 200
        length = 0.2
        kk = 5
        xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype)

        # Task-dependent anchor, given in the model frame and mapped to the
        # viewer frame (the helper modifies `pos` in place).
        if self.task_name == 'rope':
            pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
        else:
            pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
        gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos)
        gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)

        # Arrow origin: the anchor shifted by `center_point` (identity rotation).
        center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
        gripper_center_inv_xyz = gripper_now_inv_xyz + \
            torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point)

        # Shaft: N points spaced evenly over `length` along `direction`.
        for i in range(N):
            offset = i / N * length * direction
            xyz_test[:, i] = gripper_center_inv_xyz + offset

        # Head strokes point backwards and sideways; the sideways axis is z
        # unless the arrow itself is (nearly) along z, in which case y is used.
        if direction[0, 2] < 0.9 and direction[0, 2] > -0.9:
            direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype)
            direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10)
            direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype)
            direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10)
        else:
            direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
            direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10)
            direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
            direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10)

        # Both strokes start at the tip (`length * direction`).
        for i in range(N, N + N // kk):
            offset = length * direction + (i - N) / N * length * direction_up
            xyz_test[:, i] = gripper_center_inv_xyz + offset

        for i in range(N + N // kk, N + N // kk + N // kk):
            offset = length * direction + (i - N - N // kk) / N * length * direction_down
            xyz_test[:, i] = gripper_center_inv_xyz + offset

        # Appearance of the arrow gaussians: solid colour, quaternion
        # (1, 0, 0, 0), opacity 1, scale 0.002.
        color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype)
        color_test[:, :, 0] = axes[ee][0]
        color_test[:, :, 1] = axes[ee][1]
        color_test[:, :, 2] = axes[ee][2]
        quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype)
        quat_test[:, :, 0] = 1.0
        opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype)
        scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002

        # Append this arrow to the table gaussians.
        t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0)
        t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0)
        t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0)
        t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0)
        t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0)

    t_pts = t_pts.reshape(-1, 3)
    t_colors = t_colors.reshape(-1, 3)
    t_quats = t_quats.reshape(-1, 4)
    t_opacities = t_opacities.reshape(-1, 1)

    # Gripper: same two-stage transform as the table.
    g_mat = quat2mat(g_quats)
    g_mat = R @ g_mat
    g_pts = g_pts @ R.T
    g_pts = g_pts * scale
    g_pts += global_translation
    g_quats = mat2quat(g_mat)
    g_scales = g_scales * scale

    g_pts = g_pts @ R_viewer.T
    g_quats = mat2quat(R_viewer @ quat2mat(g_quats))
    g_pts += t_viewer

    # Shift the gripper so the xy mean of its masked tip points sits at the
    # origin, with a task-specific z offset.
    g_pts_tip = g_pts[g_pts_tip_mask]
    g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)

    if self.task_name == 'rope':
        g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
    elif self.task_name == 'sloth':
        g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device)
    else:
        raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.")
    g_pts = g_pts + g_pts_translation

    self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
    self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
|
|
def update_rendervar(self, rendervar):
    """Deform the object gaussians by the predicted particle motion.

    Args:
        rendervar: Dict with ``means3D``, ``colors_precomp``, ``rotations``,
            ``opacities`` and ``scales`` of the object gaussians.

    Returns:
        ``(rendervar, rendervar_full)``. ``rendervar`` holds the deformed
        object gaussians only. ``rendervar_full`` additionally contains the
        table gaussians and, depending on ``self.render_gripper`` /
        ``self.render_direction``, the gripper gaussians and a direction
        arrow. When ``self.with_bg`` is false both entries are the same dict.
    """
    # Current and predicted particle positions, mapped to the viewer frame.
    p_x = self.state['x']
    p_x_viewer = self.inverse_preprocess_x(p_x)

    p_x_pred = self.state['x_pred']
    p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred)

    xyz = rendervar['means3D']
    rgb = rendervar['colors_precomp']
    quat = rendervar['rotations']
    opa = rendervar['opacities']
    scales = rendervar['scales']

    # Particles act as bones: their displacement is interpolated onto the
    # gaussians using bone-bone neighbours and gaussian-bone weights.
    relations = self.knn_relations(p_x_viewer)
    weights = self.knn_weights_brute(p_x_viewer, xyz)
    xyz, quat, _ = interpolate_motions(
        bones=p_x_viewer,
        motions=p_x_pred_viewer - p_x_viewer,
        relations=relations,
        weights=weights,
        xyz=xyz,
        quat=quat,
    )

    quat = torch.nn.functional.normalize(quat, dim=-1)

    rendervar = {
        'means3D': xyz,
        'colors_precomp': rgb,
        'rotations': quat,
        'opacities': opa,
        'scales': scales,
        'means2D': torch.zeros_like(xyz),
    }

    if self.with_bg:
        t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params

        # Append the table gaussians.
        xyz = torch.cat([xyz, t_pts], dim=0)
        rgb = torch.cat([rgb, t_colors], dim=0)
        quat = torch.cat([quat, t_quats], dim=0)
        opa = torch.cat([opa, t_opacities], dim=0)
        scales = torch.cat([scales, t_scales], dim=0)

        if self.render_gripper:
            g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

            # Offset the gripper gaussians by the current key position in
            # the viewer frame; the clone protects the stored state from the
            # in-place helper.
            g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0]

            xyz = torch.cat([xyz, g_pts], dim=0)
            rgb = torch.cat([rgb, g_colors], dim=0)
            quat = torch.cat([quat, g_quats], dim=0)
            opa = torch.cat([opa, g_opacities], dim=0)
            scales = torch.cat([scales, g_scales], dim=0)

        if self.render_direction:
            # NOTE(review): `self.gripper_direction` is not assigned anywhere
            # in the visible part of this class — presumably set by a caller
            # before `render_direction` is enabled; confirm.
            gripper_direction = self.gripper_direction
            gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10)

            R = self.preprocess_metadata['R']

            direction = gripper_direction @ R.T

            # N points for the shaft plus N // kk points for each of the two
            # head strokes.
            n_grippers = 1
            N = 200
            length = 0.2
            kk = 5
            xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype)

            gripper_now_inv_xyz = self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())
            gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)

            # Arrow origin: key position shifted by `center_point`
            # (identity rotation).
            center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=xyz.dtype).reshape(1, 3)
            gripper_center_inv_xyz = gripper_now_inv_xyz + \
                torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point)

            # Shaft: N points spaced evenly over `length` along `direction`.
            for i in range(N):
                offset = i / N * length * direction
                xyz_test[:, i] = gripper_center_inv_xyz + offset

            # Head strokes point backwards and sideways; the sideways axis is
            # z unless the arrow itself is (nearly) along z, then y is used.
            if direction[0, 2] < 0.9 and direction[0, 2] > -0.9:
                direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=xyz.dtype)
                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10)
                direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=xyz.dtype)
                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10)
            else:
                direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=xyz.dtype)
                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10)
                direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=xyz.dtype)
                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10)

            # Both strokes start at the tip (`length * direction`).
            for i in range(N, N + N // kk):
                offset = length * direction + (i - N) / N * length * direction_up
                xyz_test[:, i] = gripper_center_inv_xyz + offset

            for i in range(N + N // kk, N + N // kk + N // kk):
                offset = length * direction + (i - N - N // kk) / N * length * direction_down
                xyz_test[:, i] = gripper_center_inv_xyz + offset

            # Arrow appearance: fixed RGB (255, 80, 110) / 255, quaternion
            # (1, 0, 0, 0), opacity 1, scale 0.002.
            color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=xyz.dtype)
            color_test[:, :, 0] = 255 / 255
            color_test[:, :, 1] = 80 / 255
            color_test[:, :, 2] = 110 / 255
            quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=xyz.dtype)
            quat_test[:, :, 0] = 1.0
            opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=xyz.dtype)
            scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) * 0.002

            xyz = torch.cat([xyz, xyz_test.reshape(-1, 3)], dim=0)
            rgb = torch.cat([rgb, color_test.reshape(-1, 3)], dim=0)
            quat = torch.cat([quat, quat_test.reshape(-1, 4)], dim=0)
            opa = torch.cat([opa, opa_test.reshape(-1, 1)], dim=0)
            scales = torch.cat([scales, scales_test.reshape(-1, 3)], dim=0)

        # Re-normalize after concatenating quaternions from several sources.
        quat = torch.nn.functional.normalize(quat, dim=-1)

        rendervar_full = {
            'means3D': xyz,
            'colors_precomp': rgb,
            'rotations': quat,
            'opacities': opa,
            'scales': scales,
            'means2D': torch.zeros_like(xyz),
        }

    else:
        rendervar_full = rendervar

    return rendervar, rendervar_full
|
|
def reset_state(self, params, visualize_image=False, init=False):
    """Initialise camera, particle state and render variables from ``params``.

    Args:
        params: Dict of pre-activation gaussian parameters (``means3D``,
            ``rgb_colors``, ``unnorm_rotations``, ``logit_opacities``,
            ``log_scales``).
        visualize_image: Unused.
        init: Unused.

    Returns:
        The object-only render variables produced by ``update_rendervar``.

    Side effects:
        Replaces ``self.metadata`` / ``self.config`` with a fresh camera and
        fills the ``x``, ``v``, ``x_his``, ``v_his``, ``x_pred`` and
        ``v_pred`` entries of ``self.state``. One frame is rendered and
        discarded.
    """
    xyz_0 = params['means3D']
    rendervar_init = {
        'means3D': xyz_0,
        'colors_precomp': params['rgb_colors'],
        'rotations': torch.nn.functional.normalize(params['unnorm_rotations']),
        'opacities': torch.sigmoid(params['logit_opacities']),
        'scales': torch.exp(params['log_scales']),
        'means2D': torch.zeros_like(xyz_0),
    }

    # Orbit camera looking at `target` from the given distance/elevation;
    # only the azimuth depends on the task.
    w, h = self.width, self.height
    target = np.array((0, 0, 0.1))
    distance = 0.7
    elevation = 20
    azimuth = 180.0 if self.task_name == 'rope' else 120.0
    theta_rad = math.radians(90 + azimuth)
    elev_rad = math.radians(elevation)
    cam_z = distance * math.sin(elev_rad)
    cam_y = math.cos(theta_rad) * distance * math.cos(elev_rad)
    cam_x = math.sin(theta_rad) * distance * math.cos(elev_rad)
    origin = target + np.array([cam_x, cam_y, cam_z])

    # Camera basis: columns are right, -up, forward; then invert to get w2c.
    forward = target - origin
    forward /= np.linalg.norm(forward)
    world_up = np.array([0.0, 0.0, 1.0])
    right = np.cross(forward, world_up)
    right /= np.linalg.norm(right)
    cam_up = np.cross(right, forward)
    c2w = np.eye(4)
    c2w[:3, 0] = right
    c2w[:3, 1] = -cam_up
    c2w[:3, 2] = forward
    c2w[:3, 3] = origin
    w2c = np.linalg.inv(c2w)

    # Intrinsics: focal length w / 2 on both axes, principal point at centre.
    focal = w / 2 * 1.0
    k = np.array(
        [[focal, 0., w / 2],
         [0., focal, h / 2],
         [0., 0., 1.]],
    )
    self.metadata = {}
    self.config = {}
    self.update_camera(k, w2c, w, h)

    # Particles are a farthest-point subsample of the gaussian centres.
    n_particles = self.cfg.sim.n_particles
    all_enabled = torch.ones_like(xyz_0[:, 0]).to(torch.bool)
    downsample_indices = fps(xyz_0, all_enabled, n_particles, self.torch_device)
    p_x = self.preprocess_x(xyz_0[downsample_indices])

    n_history = self.cfg.sim.n_history
    self.state['x'] = p_x
    self.state['v'] = torch.zeros_like(p_x)
    self.state['x_his'] = p_x[None].repeat(n_history, 1, 1)
    self.state['v_his'] = torch.zeros_like(p_x[None].repeat(n_history, 1, 1))
    self.state['x_pred'] = p_x
    self.state['v_pred'] = torch.zeros_like(p_x)

    rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
    im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
    im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)

    return rendervar_init
|
|
def reset(self, task_name, scene_name):
    """Re-initialise the simulation for a scene and build the initial UI outputs.

    Loads the scene found under ``log/gs/ckpts/<scene_name>`` (relative to the
    module-level ``root``), rebuilds the model and particle state, renders a
    single frame, encodes it as a one-frame video and exports the Gaussians
    to a ``.splat`` file.

    Args:
        task_name: Task identifier forwarded to ``self.init`` and returned
            unchanged as the last element of the result.
        scene_name: Name of the scene directory under ``log/gs/ckpts``.

    Returns:
        A 9-tuple ``(form_video, form_3dgs_pred, preprocess_metadata, state,
        params, table_params, gripper_params, rendervar, task_name)``. The
        first two items are the Gradio video / 3D-model components; every
        tensor in the remaining items has been detached and moved to the CPU
        (``run_command`` moves them back onto the device).

    Raises:
        AssertionError: If ``self.cfg.gpus`` does not list exactly one GPU.
    """
    self.init(task_name)

    import warp as wp
    wp.init()
    gpus = [int(gpu) for gpu in self.cfg.gpus]
    wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
    torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
    assert len(torch_devices) == 1
    self.wp_device = wp_devices[0]
    self.torch_device = torch_devices[0]

    in_dir = root / f'log/gs/ckpts/{scene_name}'
    batch_size = 1
    num_steps = 1
    num_particles = self.cfg.sim.n_particles
    self.load_scaniverse(in_dir)
    self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)
    self.render_direction = False

    params = self.preprocess_gs(self.params)
    if self.with_bg:
        self.preprocess_bg_gs()
    rendervar = self.reset_state(params, visualize_image=False, init=True)
    rendervar, rendervar_full = self.update_rendervar(rendervar)

    im, _ = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
    im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

    # cv2.imwrite only signals a missing directory through its return value
    # and make_video does not check ffmpeg's exit status, so make sure both
    # output directories exist before writing anything.
    frame_dir = root / 'log/temp_init'
    out_dir = root / 'log/gs/temp'
    frame_dir.mkdir(parents=True, exist_ok=True)
    out_dir.mkdir(parents=True, exist_ok=True)

    video_path = out_dir / 'form_video_init.mp4'
    splat_path = out_dir / 'gs_pred.splat'

    cv2.imwrite(str(frame_dir / '0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
    make_video(frame_dir, video_path, '%04d.png', 1)

    # Called for its side effect of writing the .splat file; the return
    # value was never used.
    save_to_splat(
        rendervar_full['means3D'].cpu().numpy(),
        rendervar_full['colors_precomp'].cpu().numpy(),
        rendervar_full['scales'].cpu().numpy(),
        rendervar_full['rotations'].cpu().numpy(),
        rendervar_full['opacities'].cpu().numpy(),
        splat_path,
        rot_rev=True,
    )

    def to_cpu(value):
        # Tensors are detached and copied to host memory; anything else is
        # passed through untouched.
        return value.detach().cpu() if isinstance(value, torch.Tensor) else value

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for key, value in self.preprocess_metadata.items():
        self.preprocess_metadata[key] = to_cpu(value)
    for key, value in self.state.items():
        self.state[key] = to_cpu(value)
    for key, value in self.params.items():
        if isinstance(value, dict):
            # One level of nesting is handled, exactly as before.
            for sub_key, sub_value in value.items():
                self.params[key][sub_key] = to_cpu(sub_value)
        else:
            self.params[key] = to_cpu(value)
    self.table_params = tuple(to_cpu(value) for value in self.table_params)
    self.gripper_params = tuple(to_cpu(value) for value in self.gripper_params)
    for key, value in rendervar.items():
        rendervar[key] = to_cpu(value)

    form_video = gr.Video(
        label='Predicted video',
        value=video_path,
        format='mp4',
        width=self.width,
        height=self.height,
    )
    form_3dgs_pred = gr.Model3D(
        label='Predicted Gaussian Splats',
        height=self.height,
        value=splat_path,
        clear_color=[0, 0, 0, 0],
    )

    return form_video, form_3dgs_pred, \
        self.preprocess_metadata, self.state, self.params, \
        self.table_params, self.gripper_params, rendervar, task_name
|
|
def run_command(self, unit_command, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Apply one gripper command for a fixed number of steps and rebuild the UI outputs.

    The session values passed in are installed on ``self`` and moved onto the
    CUDA device, ``num_steps`` (15) simulation steps are run with one rendered
    frame written to ``log/temp`` per step, every tensor is moved back to the
    CPU, and the frames are encoded into ``log/gs/temp/form_video.mp4``. The
    final Gaussians are exported to ``log/gs/temp/gs_pred.splat``.

    NOTE(review): this file reached review with its indentation stripped. The
    block structure below was reconstructed from the statements themselves;
    the one genuinely ambiguous spot is marked further down -- please confirm
    it against the upstream file.

    Args:
        unit_command: Sequence wrapped as ``torch.tensor([unit_command])``;
            the button handlers in this file pass three-element lists such
            as ``[5.0, 0, 0]``.
        preprocess_metadata: Dict installed as ``self.preprocess_metadata``.
        state: Dict installed as ``self.state``.
        params: Dict (optionally with one level of nested dicts) installed
            as ``self.params``.
        table_params: Iterable rebuilt as a tuple on ``self.table_params``.
        gripper_params: Iterable rebuilt as a tuple on ``self.gripper_params``.
        rendervar: Dict of render variables, updated in place and returned.
        task_name: Stored on ``self.task_name``; selects the camera azimuth.

    Returns:
        A 9-tuple ``(form_video, form_3dgs_pred, preprocess_metadata, state,
        params, table_params, gripper_params, rendervar, task_name)`` with
        all tensors detached and on the CPU.
    """
    self.task_name = task_name
    import warp as wp
    wp.init()
    gpus = [int(gpu) for gpu in self.cfg.gpus]
    wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
    torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
    device_count = len(torch_devices)
    # Exactly one GPU must be configured.
    assert device_count == 1
    self.wp_device = wp_devices[0]
    self.torch_device = torch_devices[0]
    # Clear frames left over from the previous command.
    # NOTE(review): shell string built from a path -- breaks if ``root``
    # contains spaces or shell metacharacters.
    os.system('rm -rf ' + str(root / 'log/temp/*'))

    # Camera placement: a point at ``distance`` from ``center`` given by
    # the elevation/azimuth angles (degrees); the azimuth depends on the task.
    w = 640
    h = 480
    center = (0, 0, 0.1)
    distance = 0.7
    elevation = 20
    azimuth = 180.0 if self.task_name == 'rope' else 120.0
    target = np.array(center)
    theta = 90 + azimuth
    z = distance * math.sin(math.radians(elevation))
    y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
    x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
    origin = target + np.array([x, y, z])

    # Orthonormal camera axes from the viewing direction and world +Z.
    look_at = target - origin
    look_at /= np.linalg.norm(look_at)
    up = np.array([0.0, 0.0, 1.0])
    right = np.cross(look_at, up)
    right /= np.linalg.norm(right)
    up = np.cross(right, look_at)
    # Pose with the axes in the first three columns and the camera position
    # in the last; its inverse is what gets passed on as ``w2c``.
    w2c = np.eye(4)
    w2c[:3, 0] = right
    w2c[:3, 1] = -up
    w2c[:3, 2] = look_at
    w2c[:3, 3] = origin
    w2c = np.linalg.inv(w2c)

    # Intrinsics: focal length w / 2 on both axes, principal point at the
    # image centre.
    k = np.array(
        [[w / 2 * 1.0, 0., w / 2],
         [0., w / 2 * 1.0, h / 2],
         [0., 0., 1.]],
    )
    self.update_camera(k, w2c, w, h)

    # Install the session values and move every tensor onto the device.
    # (The loops below rebind ``k``; the intrinsics were already consumed.)
    self.preprocess_metadata = preprocess_metadata
    self.state = state
    self.params = params
    self.table_params = table_params
    self.gripper_params = gripper_params
    for k, v in self.preprocess_metadata.items():
        self.preprocess_metadata[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
    for k, v in self.state.items():
        self.state[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
    for k, v in self.params.items():
        if isinstance(v, dict):
            for k2, v2 in v.items():
                self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2
        else:
            self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
    self.table_params = tuple(
        v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.table_params
    )
    self.gripper_params = tuple(
        v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.gripper_params
    )
    for k, v in rendervar.items():
        rendervar[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v

    num_steps = 15
    batch_size = 1
    num_particles = self.cfg.sim.n_particles
    self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)
    self.render_direction = True

    for i in range(num_steps):
        dt = 0.1
        # NOTE(review): uses ``self.device`` although ``self.torch_device``
        # was set above -- confirm the two always agree.
        command = torch.tensor([unit_command]).to(self.device).to(torch.float32)
        command = self.preprocess_gripper(command)

        if self.verbose:
            print('command:', command.cpu().numpy().tolist())

        self.gripper_direction = command.clone()

        # ``sub_pos`` must be empty at the start of every step, which makes
        # the two ``else`` branches below unreachable unless asserts are
        # disabled (``python -O``).
        assert self.state['sub_pos'] is None

        if self.state['sub_pos'] is None:
            eef_xyz_latest = self.state['prev_key_pos']
        else:
            eef_xyz_latest = self.state['sub_pos'][-1]

        # Displace the end effector by ``command * dt * 0.01``.
        # presumably 0.01 converts centimetres to metres -- TODO confirm
        eef_xyz_updated = eef_xyz_latest + command * dt * 0.01

        if self.state['sub_pos'] is None:
            self.state['sub_pos'] = eef_xyz_updated[None]
        else:
            self.state['sub_pos'] = torch.cat([self.state['sub_pos'], eef_xyz_updated[None]], dim=0)

        self.step()
        rendervar, rendervar_full = self.update_rendervar(rendervar)

        im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
        im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

        # One RGB->BGR PNG per step; the names match make_video's '%04d.png'.
        cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))

    # NOTE(review): the original indentation of the next five statements is
    # unknown; they are placed AFTER the loop here (bring the state to rest
    # once the command has finished). If upstream has them inside the loop
    # body, indent them one level -- confirm before relying on this.
    # Zero all velocities and collapse the position history onto the latest
    # predicted positions.
    self.state['v'] *= 0.0
    self.state['x'] = self.state['x_pred'].clone()
    self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1)
    self.state['v_his'] *= 0.0
    self.state['v_pred'] *= 0.0

    # Detach every tensor and move it back to the CPU before returning.
    for k, v in self.preprocess_metadata.items():
        self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
    for k, v in self.state.items():
        self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
    for k, v in self.params.items():
        if isinstance(v, dict):
            for k2, v2 in v.items():
                self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
        else:
            self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
    self.table_params = tuple(
        v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params
    )
    self.gripper_params = tuple(
        v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params
    )
    for k, v in rendervar.items():
        rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v

    # Encode the per-step frames at 5 fps.
    make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5)

    form_video = gr.Video(
        label='Predicted video',
        value=root / f'log/gs/temp/form_video.mp4',
        format='mp4',
        width=self.width,
        height=self.height,
    )

    # NOTE(review): ``im``, ``depth`` and ``im_show`` are not read after this
    # render; it looks redundant unless ``self.render`` has side effects.
    im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
    im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

    # ``rendervar_full`` is the value from the last loop iteration.
    gs_pred = save_to_splat(
        rendervar_full['means3D'].cpu().numpy(),
        rendervar_full['colors_precomp'].cpu().numpy(),
        rendervar_full['scales'].cpu().numpy(),
        rendervar_full['rotations'].cpu().numpy(),
        rendervar_full['opacities'].cpu().numpy(),
        root / 'log/gs/temp/gs_pred.splat',
        rot_rev=True,
    )
    form_3dgs_pred = gr.Model3D(
        label='Predicted Gaussian Splats',
        height=self.height,
        value=root / 'log/gs/temp/gs_pred.splat',
        clear_color=[0, 0, 0, 0],
    )
    return form_video, form_3dgs_pred, \
        self.preprocess_metadata, self.state, self.params, \
        self.table_params, self.gripper_params, rendervar, task_name
| |
@spaces.GPU
def reset_rope(self):
    """Reset the demo to the ``'rope'`` task using scene ``'rope_scene_1'``."""
    task, scene = 'rope', 'rope_scene_1'
    return self.reset(task, scene)
|
|
@spaces.GPU
def reset_plush(self):
    """Reset the demo to the ``'sloth'`` task using scene ``'sloth_scene_1'``."""
    task, scene = 'sloth', 'sloth_scene_1'
    return self.reset(task, scene)
|
|
@spaces.GPU
def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[5.0, 0, 0]`` (positive x) on the given session values."""
    unit_command = [5.0, 0, 0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
|
|
@spaces.GPU
def on_click_run_xminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[-5.0, 0, 0]`` (negative x) on the given session values."""
    unit_command = [-5.0, 0, 0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
| |
@spaces.GPU
def on_click_run_yplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[0, 5.0, 0]`` (positive y) on the given session values."""
    unit_command = [0, 5.0, 0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
| |
@spaces.GPU
def on_click_run_yminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[0, -5.0, 0]`` (negative y) on the given session values."""
    unit_command = [0, -5.0, 0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
| |
@spaces.GPU
def on_click_run_zplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[0, 0, 5.0]`` (positive z) on the given session values."""
    unit_command = [0, 0, 5.0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
| |
@spaces.GPU
def on_click_run_zminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
    """Run the gripper command ``[0, 0, -5.0]`` (negative z) on the given session values."""
    unit_command = [0, 0, -5.0]
    session = (preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
    return self.run_command(unit_command, *session)
|
|
def launch(self, share=False):
    """Build the Gradio interface and start serving it.

    The page consists of a title and instructions, two reset buttons
    (plush / rope), a video player next to a 3D viewer, and six buttons that
    move the gripper along +/-x, +/-y and +/-z. Seven ``gr.State`` holders
    carry the simulation data between callbacks; every callback writes the
    video, the 3D model and all seven state holders.

    Args:
        share: Forwarded to ``app.launch(share=...)``.
    """
    with gr.Blocks() as app:
        # Values carried from one callback to the next.
        preprocess_metadata = gr.State(self.preprocess_metadata)
        state = gr.State(self.state)
        params = gr.State(self.params)
        table_params = gr.State(self.table_params)
        gripper_params = gr.State(self.gripper_params)
        rendervar = gr.State(None)
        task_name = gr.State(self.task_name)

        with gr.Row():
            gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")

        with gr.Row():
            gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')

        with gr.Row():
            gr.Markdown('### Instructions:')

        with gr.Row():
            # The backslashes before < and > are meant for Markdown, so they
            # are written as '\\' here: a bare '\<' is an invalid Python
            # escape sequence (SyntaxWarning on 3.12+). The resulting string
            # is unchanged.
            gr.Markdown(' '.join([
                '- Click the "Reset-\\<object\\>" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
                '- Use the buttons to move the gripper in the x, y, z directions. The gripper will move for a fixed length per click. The predicted video and Gaussian splats will be updated accordingly.\n',
                '- X-Y plane is the table surface, and Z is the height.\n',
                '- The predicted video from the previous step to the current step will be shown in the "Predicted video" section.\n',
                '- The Gaussian splats after the current step will be shown in the "Predicted Gaussians" section.\n',
                '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
            ]))

        with gr.Row():
            gr.Markdown('### Select a scene to reset the simulation:')

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        run_reset_plush = gr.Button("Reset - Plush")
                    with gr.Column():
                        run_reset_rope = gr.Button("Reset - Rope")

            # Invisible button: fills the second column of the row.
            with gr.Column(scale=2):
                _ = gr.Button(visible=False)

        with gr.Row():
            with gr.Column(scale=2):
                form_video = gr.Video(
                    label='Predicted video',
                    value=None,
                    format='mp4',
                    width=self.width,
                    height=self.height,
                )

            with gr.Column(scale=2):
                form_3dgs_pred = gr.Model3D(
                    label='Predicted Gaussians',
                    height=self.height,
                    value=None,
                    clear_color=[0, 0, 0, 0],
                )

        with gr.Row():
            gr.Markdown('### Control the gripper to move in the x, y, z directions:')

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        run_xminus = gr.Button("x-")
                    with gr.Column():
                        run_xplus = gr.Button("x+")

                with gr.Row():
                    with gr.Column():
                        run_yminus = gr.Button("y-")
                    with gr.Column():
                        run_yplus = gr.Button("y+")

                with gr.Row():
                    with gr.Column():
                        run_zminus = gr.Button("z-")
                    with gr.Column():
                        run_zplus = gr.Button("z+")

            # Invisible button: fills the second column of the row.
            with gr.Column(scale=2):
                _ = gr.Button(visible=False)

        # Every callback returns (video, 3D model, *session values); the
        # movement callbacks also take the session values as inputs.
        session_state = [preprocess_metadata, state, params,
                         table_params, gripper_params, rendervar, task_name]
        ui_outputs = [form_video, form_3dgs_pred, *session_state]

        reset_bindings = (
            (run_reset_rope, self.reset_rope),
            (run_reset_plush, self.reset_plush),
        )
        for button, handler in reset_bindings:
            button.click(handler, inputs=[], outputs=list(ui_outputs))

        move_bindings = (
            (run_xplus, self.on_click_run_xplus),
            (run_xminus, self.on_click_run_xminus),
            (run_yplus, self.on_click_run_yplus),
            (run_yminus, self.on_click_run_yminus),
            (run_zplus, self.on_click_run_zplus),
            (run_zminus, self.on_click_run_zminus),
        )
        for button, handler in move_bindings:
            button.click(handler, inputs=list(session_state), outputs=list(ui_outputs))

    app.launch(share=share)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: construct the visualizer and serve the Gradio app
    # with ``share=True``.
    visualizer = DynamicsVisualizer()
    visualizer.launch(share=True)
|
|