Spaces:
Sleeping
Sleeping
| import spaces | |
| from pickle import FALSE | |
| import gradio as gr | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from sam2point import dataset | |
| import sam2point.configs as configs | |
| from demo_utils import run_demo, create_box | |
# UI catalogue: dropdown label -> scene names offered in the scene radio.
# The dataset key used everywhere else is the text after the '-' in the label
# with spaces removed (e.g. "3D Indoor Scene - S3DIS" -> "S3DIS").
samples = {
    "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
    "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Raw LiDAR - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Outdoor Scene - Semantic3D": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7"],
    # NOTE(review): "Lego" is paired by position with "human.npy" in PATH below -- confirm
    # that this label/file pairing is intended.
    "3D Object - Objaverse": ["Plant", "Lego", "Lock", "Elephant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse"],
}

# Dataset key -> data file names. Entries are matched to `samples` purely by
# position: the scene radio reports an index and that index selects the file,
# so both tables must keep the same length and order per dataset.
PATH = {
    "S3DIS": ['Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt', 'Area_5_office_3.txt', 'Area_6_office_9.txt'],
    "ScanNet": ['scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth'],
    "Objaverse": ["plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy"],
    "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
    "Semantic3D": ["scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy", "patch50.npy", "patch62.npy"]
}

# Prompt kinds offered in the UI; handlers lower-case these before use.
prompt_types = ["Point", "Box", "Mask"]
# Scenes with more points than this are randomly subsampled before plotting.
_MAX_DISPLAY_POINTS = 100000

# Prompt kinds load_3d_scene knows how to draw.
_PROMPT_KINDS = ("box", "point", "mask")


def _subsample(point, color, limit=_MAX_DISPLAY_POINTS):
    """Randomly keep at most ``limit`` points (and matching colours) to speed up rendering."""
    if point.shape[0] > limit:
        indices = np.random.choice(point.shape[0], limit, replace=False)
        point = point[indices]
        color = color[indices]
    return point, color


def _cloud_trace(point, color, size, alpha, name):
    """Build the marker trace that draws the point cloud itself."""
    return go.Scatter3d(
        x=point[:, 0], y=point[:, 1], z=point[:, 2],
        mode='markers',
        marker=dict(size=size, color=color, opacity=alpha),
        name=name
    )


def _legend_entry(color, name):
    """Build a tiny single-marker trace that exists only to add a legend entry."""
    anchor = np.array([[0.1, 0.1, 0.1]])
    return go.Scatter3d(
        x=anchor[:, 0], y=anchor[:, 1], z=anchor[:, 2],
        mode='markers',
        marker=dict(size=0.0001, color=color, opacity=1),
        name=name
    )


def _build_figure(traces, asp):
    """Wrap traces in a figure with hidden axes, a fixed camera and z aspect ``asp``."""
    return go.Figure(
        data=traces,
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectratio=dict(x=1, y=1, z=asp),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            )
        )
    )


def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
    """Load one demo sample and return a Plotly figure of it.

    Args:
        name: Dropdown label such as "3D Indoor Scene - S3DIS"; the dataset key
            is the part after '-' with spaces removed.
        sample_idx: Index into ``PATH[<dataset>]`` selecting the data file.
        type_: ``None``/empty for the plain scene, or one of "box", "point",
            "mask" to overlay that prompt.
        prompt: Prompt payload. For "point" it is wrapped as ``np.array([prompt])``
            and read as x/y/z; for "box" it is forwarded to ``create_box``; for
            "mask" (non-final) ``prompt * 255`` is used as the point colours.
        final: When true, points are coloured with ``new_color`` and a
            "Segmentation Results" legend entry is added.
        new_color: Point colours used when ``final`` is true.

    Returns:
        A ``plotly.graph_objects.Figure``. An empty figure is returned when
        ``name`` or ``sample_idx`` is ``None`` (nothing selected in the UI).

    Raises:
        ValueError: If ``type_`` is not a supported prompt kind, or the dataset
            has no loader.
    """
    # Fail fast with an exception instead of calling exit(), which would take
    # the whole serving process down from inside a request handler.
    if type_ and type_ not in _PROMPT_KINDS:
        raise ValueError("Wrong Prompt Type: %r" % (type_,))

    # The "Show" button can fire before anything has been selected.
    if name is None or sample_idx is None:
        return go.Figure()

    DATASET = name.split('-')[1].replace(" ", "")
    path = 'data/' + DATASET + '/' + PATH[DATASET][sample_idx]
    asp, SIZE = 1., 1
    print(path)

    if DATASET == 'S3DIS':
        point, color = dataset.load_S3DIS_sample(path, sample=True)
        alpha = 1
    elif DATASET == 'ScanNet':
        point, color = dataset.load_ScanNet_sample(path)
        alpha = 1
    elif DATASET == 'Objaverse':
        point, color = dataset.load_Objaverse_sample(path)
        alpha = 1
        SIZE = 2
    elif DATASET == 'KITTI':
        point, color = dataset.load_KITTI_sample(path)
        asp = 0.3
        alpha = 0.7
    elif DATASET == 'Semantic3D':
        point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
        alpha = 0.2
    else:
        raise ValueError("No loader for dataset: %r" % (DATASET,))
    print("Loading Dataset:", DATASET, "Point Cloud Size:", point.shape, "Path:", path)

    ##### Initial Show #####
    if not type_:
        point, color = _subsample(point, color)
        return _build_figure([_cloud_trace(point, color, SIZE, alpha, "")], asp)

    ##### Prompt / Final Results #####
    # Pick the colouring first so that points and colours are subsampled together.
    if final:
        color = new_color
    elif type_ == 'mask':
        color = np.clip(prompt * 255, 0, 255).astype(np.uint8)
    point, color = _subsample(point, color)

    traces = [_cloud_trace(point, color, SIZE, alpha, "3D Object/Scene")]
    if type_ == "box":
        if final:
            traces.append(_legend_entry('green', "Segmentation Results"))
        traces = traces + create_box(prompt)
    elif type_ == "point":
        prompt = np.array([prompt])
        traces.append(go.Scatter3d(
            x=prompt[:, 0], y=prompt[:, 1], z=prompt[:, 2],
            mode='markers',
            marker=dict(size=5, color='red', opacity=1),
            name="Point Prompt"
        ))
        if final:
            traces.append(_legend_entry('green', "Segmentation Results"))
    elif final:  # mask, after segmentation
        traces.append(_legend_entry('green', "Segmentation Results"))
    else:  # mask, prompt preview
        traces.append(_legend_entry('red', "Mask Prompt"))

    return _build_figure(traces, asp)
def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
    """Fetch the selected prompt via ``run_demo`` and draw it over the scene.

    Args:
        name: Dropdown label; the dataset key is the part after '-' with spaces removed.
        sample_idx: Index of the selected scene.
        prompt_type: UI prompt label; lower-cased before use.
        prompt_idx: Index of the selected prompt example.

    Returns:
        ``(plot, response_textbox)``. If any argument is ``None`` the plot is an
        empty ``gr.Plot()`` and the textbox asks the user to finish selecting.
    """
    if any(choice is None for choice in (name, sample_idx, prompt_type, prompt_idx)):
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()

    # Exact membership: the former `DATASET in "S3DIS ScanNet"` was a substring
    # test, which also accepts fragments such as "" or "Scan".
    is_indoor = DATASET in ("S3DIS", "ScanNet")
    theta = 0. if is_indoor else 0.5
    mode = "bilinear" if is_indoor else 'nearest'

    # With ret_prompt=True run_demo hands back only the prompt. The fixed 0.02
    # occupies the argument slot that start_segmentation fills with `vx`.
    prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt)
    return fig, gr.Textbox(label="Response", value="Prompt has been shown in 3D Object/Scene!", visible=True)
def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
    """Run ``run_demo`` for the current selection and plot the segmentation result.

    Args:
        name: Dropdown label; the dataset key is the part after '-' with spaces removed.
        sample_idx: Index of the selected scene.
        prompt_type: UI prompt label; lower-cased before use.
        prompt_idx: Index of the selected prompt example.
        vx: Value forwarded to ``run_demo`` (presumably the voxel size -- confirm
            against ``demo_utils.run_demo``).

    Returns:
        ``(plot, response_textbox)``. If any of the four selections is ``None``
        the plot is an empty ``gr.Plot()`` and the textbox asks the user to
        finish selecting.
    """
    if any(choice is None for choice in (name, sample_idx, prompt_type, prompt_idx)):
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()

    # Exact membership: the former `DATASET in "S3DIS ScanNet"` was a substring
    # test, which also accepts fragments such as "" or "Scan".
    is_indoor = DATASET in ("S3DIS", "ScanNet")
    theta = 0. if is_indoor else 0.5
    mode = "bilinear" if is_indoor else 'nearest'

    # With ret_prompt=False run_demo returns the new point colours and the prompt.
    new_color, prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt, final=True, new_color=new_color)
    return fig, gr.Textbox(label="Response", value="Segmentation completed successfully!", visible=True)
def update1(datasets):
    """Refresh the scene radio for the chosen data type and blank the response box.

    The radio is labelled "Select 3D Object" when the data type label contains
    'Objaverse' and "Select 3D Scene" otherwise; its choices come from
    ``samples[datasets]``.
    """
    radio_label = "Select 3D Object" if 'Objaverse' in datasets else "Select 3D Scene"
    scene_radio = gr.Radio(label=radio_label, choices=samples[datasets])
    cleared_response = gr.Textbox(label="Response", value="", visible=True)
    return scene_radio, cleared_response
def update2(name, sample_idx, prompt_type):
    """Rebuild the "Select Prompt Example" radio for the current selection.

    Args:
        name: Dropdown label; the dataset key is the part after '-' with spaces removed.
        sample_idx: Index of the selected scene.
        prompt_type: UI prompt label; ``<label lower-cased>_prompts`` is the key
            looked up in the sample's config entry.

    Returns:
        ``(prompt_radio, response_textbox)``. The radio lists "Example 1" ..
        "Example N", one per configured prompt. It is empty when a selection is
        missing or the dataset has no ``configs.<DATASET>_samples`` table
        handled here. The response textbox is always cleared.
    """
    choices = []
    if name is not None and sample_idx is not None and prompt_type is not None:
        DATASET = name.split('-')[1].replace(" ", "")
        TYPE = prompt_type.lower() + '_prompts'
        # Only these datasets have a sample table; anything else keeps the
        # radio empty instead of failing on an unbound variable.
        if DATASET in ('S3DIS', 'ScanNet', 'Objaverse', 'KITTI', 'Semantic3D'):
            info = getattr(configs, DATASET + '_samples')[sample_idx][TYPE]
            choices = ['Example ' + str(i) for i in range(1, len(info) + 1)]
    return gr.Radio(label="Select Prompt Example", choices=choices), gr.Textbox(label="Response", value="", visible=True)
def update3(name, sample_idx, prompt_type, prompt_idx):
    """Return a cleared response box and a "Voxel Size" slider preset for the selection.

    Args:
        name: Dropdown label; the dataset key is the part after '-' with spaces removed.
        sample_idx: Index of the selected scene.
        prompt_type: UI prompt label; its lower-cased form is mapped through
            ``configs.VOXEL`` to the key holding the per-prompt values.
        prompt_idx: Index of the selected prompt example.

    Returns:
        ``(response_textbox, voxel_slider)``. The slider value is 0.02 for
        S3DIS/ScanNet, for any dataset without a per-prompt table, and whenever
        a selection (including ``prompt_idx``) is missing. Otherwise it is read
        from ``configs.<DATASET>_samples[sample_idx][key][prompt_idx]``.
    """
    vx_ = 0.02
    # prompt_idx is checked too: it is used as an index below and is None until
    # the user picks a prompt example.
    if all(choice is not None for choice in (name, sample_idx, prompt_type, prompt_idx)):
        DATASET = name.split('-')[1].replace(" ", "")
        TYPE = configs.VOXEL[prompt_type.lower()]
        # Exact membership replaces the former substring test against
        # "S3DIS ScanNet"; those two datasets simply keep the 0.02 default.
        if DATASET in ('Objaverse', 'KITTI', 'Semantic3D'):
            vx_ = getattr(configs, DATASET + '_samples')[sample_idx][TYPE][prompt_idx]
    return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=vx_)
def main() -> None:
    """Build the SAM2Point Gradio interface, wire its callbacks and launch it.

    The page consists of an HTML header, a column of selection widgets (data
    type, scene, prompt type, prompt example, two "show" buttons), a column with
    the "Start Segmentation" button and the 3D plot, and a "Response" textbox.
    """
    # Header markup: logo, project name, tagline, links and short usage notes.
    title = """<h1 style="text-align: center;">
<div style="width: 1.2em; height: 1.2em; display: inline-block;"><img src="https://github.com/ZiyuGuo99/ZiyuGuo99.github.io/blob/main/assets/img/logo.png?raw=true" style='width: 100%; height: 100%; object-fit: contain;' /></div>
<span style="font-variant: small-caps; font-weight: bold;">Sam2Point</span>
</h1>
<h3 align="center"><span style="font-variant: small-caps; ">Segment Any 3D as Videos in Zero-shot and Promptable Manners
</span></h3>
<div style="text-align: center;">
<div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; margin-bottom: 0.5rem; font-size: 1rem; flex-wrap: wrap;">
<a href="https://sam2point.github.io/" target="_blank">[Webpage]</a>
<a href="https://arxiv.org/pdf/2408.16768" target="_blank">[Paper]</a>
<a href="https://github.com/ZiyuGuo99/SAM2Point" target="_blank">[Code]</a>
</div>
</div>
<p style="text-align: center;">
Select an example and a 3D prompt to start segmentation using <span style="font-variant: small-caps;">Sam2Point</span>.
</p>
<p style="text-align: center;">
Custom 3D input and prompts will be supported soon.
</p>
"""
    with gr.Blocks(
        # Full-height column layout plus wrapping rules for <pre> blocks.
        css="""
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#col_container { height: 100%; }
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}""",
        # Redirects to the same URL with `__theme=light` unless it is already set.
        js="""
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}""",
        title="SAM2Point: Segment Any 3D as Videos in Zero-shot and Promptable Manners",
        theme=gr.themes.Soft()
    ) as app:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(elem_id="col_container"):
                # `samples` is a dict; its keys become the dropdown entries.
                sample_dropdown = gr.Dropdown(label="Select 3D Data Type", choices=samples, type="value")
                # type="index": callbacks receive the position of the choice, not its label.
                scene_dropdown = gr.Radio(label="Select 3D Object/Scene", choices=[], type="index")
                show_button = gr.Button("Show 3D Scene/Object")
                prompt_type_dropdown = gr.Radio(label="Select Prompt Type", choices=prompt_types)
                prompt_sample_dropdown = gr.Radio(label="Select Prompt Example", choices=[], type="index")
                show_prompt_button = gr.Button("Show Prompt in 3D Scene/Object")
            with gr.Column():
                start_segment_button = gr.Button("Start Segmentation")
                plot1 = gr.Plot()
        # NOTE(review): the flattened source does not show whether `response` belongs
        # inside the second column or below the row -- confirm against the deployed layout.
        response = gr.Textbox(label="Response")

        # Changing the data type repopulates the scene list; any change to data
        # type, scene or prompt type repopulates the prompt-example list.
        sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response])
        sample_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
        scene_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
        prompt_type_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])

        # The three buttons all draw into the same plot.
        show_button.click(load_3d_scene, inputs=[sample_dropdown, scene_dropdown], outputs=plot1)
        show_prompt_button.click(show_prompt_in_3d, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])
        start_segment_button.click(start_segmentation, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])

    app.queue(max_size=20, api_open=False)
    app.launch(max_threads=400)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()