# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import os
import unittest
from typing import Tuple

import torch
from visdom import Visdom

from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.renderer.cameras import CamerasBase

from tests.common_testing import interactive_testing_requested

from .common_resources import get_skateboard_data
class TestModelVisualize(unittest.TestCase):
    def test_flyaround_one_sequence(
        self,
        image_size: int = 256,
    ):
        """Render fly-around videos of one skateboard sequence.

        Builds a ``JsonIndexDataset`` over the "skateboard" test data, wraps
        the first sequence in a point-cloud-rendering model, and renders a
        circular fly-around trajectory with ``render_flyaround``, optionally
        pushing the predictions to a running Visdom server.

        The test is interactive-only: it returns immediately unless
        interactive testing has been requested for this run.

        Args:
            image_size: Side length (pixels) at which dataset frames are
                loaded (both height and width).
        """
        if not interactive_testing_requested():
            return

        category = "skateboard"

        # Keep the ExitStack open for the whole test so the dataset files
        # obtained from get_skateboard_data stay available until cleanup.
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)

        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(
            dataset_root, category, "sequence_annotations.jgz"
        )
        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")

        # JsonIndexDataset is an implicitron configurable; its constructor
        # fields must be expanded before instantiation.
        expand_args_fields(JsonIndexDataset)
        train_dataset = JsonIndexDataset(
            frame_annotations_file=frame_file,
            sequence_annotations_file=sequence_file,
            subset_lists_file=subset_lists_file,
            dataset_root=dataset_root,
            image_height=image_size,
            image_width=image_size,
            box_crop=True,
            load_point_clouds=True,
            path_manager=path_manager,
            subsets=[
                "train_known",
            ],
        )

        # Select the first sequence of the dataset for visualization.
        sequence_names = list(train_dataset.seq_annots.keys())
        show_sequence_name = sequence_names[0]

        # Outputs are written next to this test file.
        output_dir = os.path.split(os.path.abspath(__file__))[0]

        # Only show predictions in Visdom when a server is reachable.
        visdom_show_preds = Visdom().check_connection()

        # Exercise both point-cloud sources: the dataset-provided cloud and
        # the one built by the model from the sequence frames.
        for load_dataset_pointcloud in [True, False]:
            model = _PointcloudRenderingModel(
                train_dataset,
                show_sequence_name,
                device="cuda:0",
                load_dataset_pointcloud=load_dataset_pointcloud,
            )

            video_path = os.path.join(
                output_dir,
                f"load_pcl_{load_dataset_pointcloud}",
            )
            os.makedirs(output_dir, exist_ok=True)

            # Run once without and once with dumping individual video frames.
            for output_video_frames_dir in [None, video_path]:
                render_flyaround(
                    train_dataset,
                    show_sequence_name,
                    model,
                    video_path,
                    n_flyaround_poses=10,
                    fps=5,
                    max_angle=2 * math.pi,
                    trajectory_type="circular_lsq_fit",
                    trajectory_scale=1.1,
                    scene_center=(0.0, 0.0, 0.0),
                    up=(0.0, 1.0, 0.0),
                    traj_offset=1.0,
                    n_source_views=1,
                    visdom_show_preds=visdom_show_preds,
                    # NOTE(review): "visalize" looks like a typo for
                    # "visualize", but it only names the Visdom environment,
                    # so it is preserved as-is.
                    visdom_environment="test_model_visalize",
                    visdom_server="http://127.0.0.1",
                    visdom_port=8097,
                    num_workers=10,
                    seed=None,
                    video_resize=None,
                    visualize_preds_keys=[
                        "images_render",
                        "depths_render",
                        "masks_render",
                        "_all_source_images",
                    ],
                    output_video_frames_dir=output_video_frames_dir,
                )
class _PointcloudRenderingModel(torch.nn.Module):
    """Minimal model that renders a fixed per-sequence point cloud.

    At construction, a point cloud of one dataset sequence is obtained via
    ``get_implicitron_sequence_pointcloud`` (either loaded from the dataset
    or built from the sequence frames, depending on
    ``load_dataset_pointcloud``) and moved to ``device``; ``forward`` then
    rasterizes it from a given camera.
    """

    def __init__(
        self,
        train_dataset: JsonIndexDataset,
        sequence_name: str,
        render_size: Tuple[int, int] = (400, 400),
        device=None,
        load_dataset_pointcloud: bool = False,
        max_frames: int = 30,
        num_workers: int = 10,
    ):
        """
        Args:
            train_dataset: Dataset to take the sequence point cloud from.
            sequence_name: Name of the sequence to visualize.
            render_size: (height, width) of the rendered images.
            device: Device the point cloud is moved to (e.g. "cuda:0").
            load_dataset_pointcloud: If True, use the point cloud stored in
                the dataset; otherwise reconstruct one from the frames.
            max_frames: Maximum number of frames used for the reconstruction.
            num_workers: Number of dataloader workers used when gathering
                the sequence frames.
        """
        super().__init__()
        self._render_size = render_size
        point_cloud, _ = get_implicitron_sequence_pointcloud(
            train_dataset,
            sequence_name=sequence_name,
            mask_points=True,
            max_frames=max_frames,
            num_workers=num_workers,
            load_dataset_point_cloud=load_dataset_pointcloud,
        )
        self._point_cloud = point_cloud.to(device)

    def forward(
        self,
        camera: CamerasBase,
        **kwargs,
    ):
        """Render the stored point cloud from the first camera in `camera`.

        Args:
            camera: Camera batch; only the first camera is used.
            **kwargs: Ignored; accepted for interface compatibility with
                ``render_flyaround``'s model callback.

        Returns:
            Dict with the rendered RGB image (clamped to [0, 1]), the
            rendered mask, and the rendered depth.
        """
        image_render, mask_render, depth_render = render_point_cloud_pytorch3d(
            camera[0],
            self._point_cloud,
            render_size=self._render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        return {
            "images_render": image_render.clamp(0.0, 1.0),
            "masks_render": mask_render,
            "depths_render": depth_render,
        }