# LongStream / demo_gradio.py
# Cc
# fixed bugs
# 3f1fddf
import os
import gradio as gr
try:
import spaces
HAS_SPACES = True
except ImportError:
HAS_SPACES = False
from longstream.demo import BRANCH_OPTIONS, create_demo_session, load_metadata
from longstream.demo.backend import load_frame_previews
from longstream.demo.export import export_glb
from longstream.demo.viewer import build_interactive_figure
# Fixed inference settings and initial UI values for the demo.
DEFAULT_KEYFRAME_STRIDE = 8  # always passed to create_demo_session; not exposed in the UI
DEFAULT_REFRESH = 3  # initial value of the "Refresh" slider
DEFAULT_WINDOW_SIZE = 48  # initial value of the "Window Size" slider
# The checkpoint path can be overridden with the LONGSTREAM_CHECKPOINT
# environment variable; otherwise the bundled relative path is used.
DEFAULT_CHECKPOINT = os.environ.get(
    "LONGSTREAM_CHECKPOINT",
    "checkpoints/50_longstream.pt",
)
def _run_stable_demo_impl(
    image_dir,
    uploaded_files,
    uploaded_video,
    checkpoint,
    device,
    mode,
    streaming_mode,
    refresh,
    window_size,
    compute_sky,
    branch_label,
    show_cameras,
    mask_sky,
    camera_scale,
    point_size,
    opacity,
    preview_max_points,
    glb_max_points,
):
    """Run a reconstruction and produce every output shown in the UI.

    A demo session is created from whichever input was supplied, then the
    "All Frames" scene figure is built, a GLB file is exported, the previews
    for frame 0 are loaded and the frame slider is reconfigured from the
    session metadata.

    Args:
        image_dir: Value of the "Image Folder" textbox (may be empty).
        uploaded_files: Value of the "Upload Images" file input.
        uploaded_video: Value of the "Upload Video" file input.
        checkpoint, device, mode, streaming_mode: Forwarded unchanged to
            ``create_demo_session``.
        refresh, window_size: Converted to ``int`` before being forwarded.
        compute_sky: Converted to ``bool`` before being forwarded.
        branch_label: Branch passed to the figure builder and GLB exporter.
        show_cameras, mask_sky, camera_scale: Options shared by the figure
            and the GLB export.
        point_size, opacity, preview_max_points: Figure-only options.
        glb_max_points: Point budget passed to the GLB export.

    Returns:
        An 8-tuple ``(figure, glb_path, session_dir, rgb, depth,
        frame_label, slider_update, status)`` matching the outputs wired to
        the run button.

    Raises:
        gr.Error: If no image folder, uploaded images, or uploaded video
            was provided.
    """
    has_input = image_dir or uploaded_files or uploaded_video
    if not has_input:
        raise gr.Error("Provide an image folder, upload images, or upload a video.")

    session_dir = create_demo_session(
        image_dir=image_dir if image_dir else "",
        uploaded_files=uploaded_files,
        uploaded_video=uploaded_video,
        checkpoint=checkpoint,
        device=device,
        mode=mode,
        streaming_mode=streaming_mode,
        keyframe_stride=DEFAULT_KEYFRAME_STRIDE,
        refresh=int(refresh),
        window_size=int(window_size),
        compute_sky=bool(compute_sky),
    )

    # Options that only affect the interactive preview figure.
    figure_options = {
        "point_size": float(point_size),
        "opacity": float(opacity),
        "preview_max_points": int(preview_max_points),
    }
    # Options the figure and the GLB export have in common; both always
    # render the full sequence starting at frame 0.
    scene_options = {
        "session_dir": session_dir,
        "branch": branch_label,
        "display_mode": "All Frames",
        "frame_index": 0,
        "show_cameras": bool(show_cameras),
        "camera_scale": float(camera_scale),
        "mask_sky": bool(mask_sky),
    }

    fig = build_interactive_figure(**scene_options, **figure_options)
    glb_path = export_glb(max_points=int(glb_max_points), **scene_options)

    rgb, depth, frame_label = load_frame_previews(session_dir, 0)

    meta = load_metadata(session_dir)
    num_frames = meta["num_frames"]
    # The slider is only usable when there is more than one frame to pick.
    slider = gr.update(
        minimum=0,
        maximum=max(num_frames - 1, 0),
        value=0,
        step=1,
        interactive=num_frames > 1,
    )

    if meta.get("has_sky_masks"):
        ratio = meta.get("sky_removed_ratio") or 0.0
        removed = float(ratio) * 100.0
        sky_msg = f" | sky_removed={removed:.1f}%"
    else:
        sky_msg = ""
    status = f"Ready: {num_frames} frames | branch={branch_label}{sky_msg}"

    return fig, glb_path, session_dir, rgb, depth, frame_label, slider, status
# When the optional `spaces` package was importable, the callback is wrapped
# with `spaces.GPU`; otherwise the plain implementation is used directly.
# NOTE(review): presumably this is the Hugging Face Spaces GPU decorator;
# confirm against the deployment target.
_run_stable_demo = (
    spaces.GPU(_run_stable_demo_impl) if HAS_SPACES else _run_stable_demo_impl
)
def _update_stable_scene(
session_dir,
branch_label,
show_cameras,
mask_sky,
camera_scale,
point_size,
opacity,
preview_max_points,
glb_max_points,
):
if not session_dir or not os.path.isdir(session_dir):
return None, None, "Run reconstruction first."
fig = build_interactive_figure(
session_dir=session_dir,
branch=branch_label,
display_mode="All Frames",
frame_index=0,
point_size=float(point_size),
opacity=float(opacity),
preview_max_points=int(preview_max_points),
show_cameras=bool(show_cameras),
camera_scale=float(camera_scale),
mask_sky=bool(mask_sky),
)
glb_path = export_glb(
session_dir=session_dir,
branch=branch_label,
display_mode="All Frames",
frame_index=0,
mask_sky=bool(mask_sky),
show_cameras=bool(show_cameras),
camera_scale=float(camera_scale),
max_points=int(glb_max_points),
)
meta = load_metadata(session_dir)
sky_msg = ""
if meta.get("has_sky_masks"):
removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
sky_msg = f" | sky_removed={removed:.1f}%"
return fig, glb_path, f"Updated preview: {branch_label}{sky_msg}"
def _update_frame_preview(session_dir, frame_index):
if not session_dir or not os.path.isdir(session_dir):
return None, None, ""
rgb, depth, label = load_frame_previews(session_dir, int(frame_index))
return rgb, depth, label
def main():
    """Build the LongStream Gradio interface, wire its events, and launch it.

    The page consists of input selectors (image folder, uploaded images,
    uploaded video), checkpoint/device fields, an "Inference" accordion, a
    "GLB Settings" accordion, a run button, and the outputs (status text,
    scene plot, GLB download, frame slider with RGB/depth previews).

    Event wiring:
        * The run button calls ``_run_stable_demo`` with every input.
        * A change to any "GLB Settings" control calls
          ``_update_stable_scene`` to rebuild the plot and GLB.
        * A change of the frame slider calls ``_update_frame_preview``.
    """
    # NOTE(review): the nesting of rows/accordions below is reconstructed from
    # the order of the components; confirm it matches the intended layout.
    with gr.Blocks(title="LongStream Demo") as demo:
        # Hidden textbox that carries the session directory between callbacks.
        session_dir = gr.Textbox(visible=False)
        gr.Markdown("# LongStream Demo")

        with gr.Row():
            image_dir = gr.Textbox(
                label="Image Folder", placeholder="/path/to/sequence"
            )
            uploaded_files = gr.File(
                label="Upload Images", file_count="multiple", file_types=["image"]
            )
            uploaded_video = gr.File(
                label="Upload Video", file_count="single", file_types=["video"]
            )

        with gr.Row():
            checkpoint = gr.Textbox(label="Checkpoint", value=DEFAULT_CHECKPOINT)
            device = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cuda")

        with gr.Accordion("Inference", open=False):
            with gr.Row():
                mode = gr.Dropdown(
                    label="Mode",
                    choices=["streaming_refresh", "batch_refresh"],
                    value="batch_refresh",
                )
                streaming_mode = gr.Dropdown(
                    label="Streaming Mode", choices=["causal", "window"], value="causal"
                )
            with gr.Row():
                refresh = gr.Slider(
                    label="Refresh", minimum=2, maximum=9, step=1, value=DEFAULT_REFRESH
                )
                window_size = gr.Slider(
                    label="Window Size",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=DEFAULT_WINDOW_SIZE,
                )
                compute_sky = gr.Checkbox(label="Compute Sky Masks", value=True)

        with gr.Accordion("GLB Settings", open=True):
            with gr.Row():
                branch_label = gr.Dropdown(
                    label="Point Cloud Branch",
                    choices=BRANCH_OPTIONS,
                    value="Point Head + Pose",
                )
                show_cameras = gr.Checkbox(label="Show Cameras", value=True)
                mask_sky = gr.Checkbox(label="Mask Sky", value=True)
            with gr.Row():
                point_size = gr.Slider(
                    label="Point Size",
                    minimum=0.05,
                    maximum=2.0,
                    step=0.05,
                    value=0.3,
                )
                opacity = gr.Slider(
                    label="Opacity",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.75,
                )
                # The step must divide (value - minimum) and
                # (maximum - minimum); with the previous step of 10000 the
                # grid was 5000, 15000, ..., 995000, so neither the default
                # (100000) nor the maximum (1000000) could be selected.
                preview_max_points = gr.Slider(
                    label="Preview Max Points",
                    minimum=5000,
                    maximum=1000000,
                    step=5000,
                    value=100000,
                )
            with gr.Row():
                camera_scale = gr.Slider(
                    label="Camera Scale",
                    minimum=0.001,
                    maximum=0.05,
                    step=0.001,
                    value=0.01,
                )
                glb_max_points = gr.Slider(
                    label="GLB Max Points",
                    minimum=20000,
                    maximum=1000000,
                    step=10000,
                    value=400000,
                )

        run_btn = gr.Button("Run Stable Demo", variant="primary")
        status = gr.Markdown("Provide input images, then run reconstruction.")
        plot = gr.Plot(label="Scene Preview")
        glb_file = gr.File(label="Download GLB")

        with gr.Row():
            # Disabled until a run reports more than one frame.
            frame_slider = gr.Slider(
                label="Preview Frame",
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                interactive=False,
            )
            frame_label = gr.Textbox(label="Frame")
        with gr.Row():
            rgb_preview = gr.Image(label="RGB", type="numpy")
            depth_preview = gr.Image(label="Depth Plasma", type="numpy")

        # Display controls, in the positional order expected by both
        # _run_stable_demo (after the inference inputs) and
        # _update_stable_scene (after session_dir).
        scene_controls = [
            branch_label,
            show_cameras,
            mask_sky,
            camera_scale,
            point_size,
            opacity,
            preview_max_points,
            glb_max_points,
        ]

        run_btn.click(
            _run_stable_demo,
            inputs=[
                image_dir,
                uploaded_files,
                uploaded_video,
                checkpoint,
                device,
                mode,
                streaming_mode,
                refresh,
                window_size,
                compute_sky,
                *scene_controls,
            ],
            outputs=[
                plot,
                glb_file,
                session_dir,
                rgb_preview,
                depth_preview,
                frame_label,
                frame_slider,
                status,
            ],
        )

        # Any display control change re-renders the scene from the existing
        # session without re-running inference.
        for component in scene_controls:
            component.change(
                _update_stable_scene,
                inputs=[session_dir, *scene_controls],
                outputs=[plot, glb_file, status],
            )

        frame_slider.change(
            _update_frame_preview,
            inputs=[session_dir, frame_slider],
            outputs=[rgb_preview, depth_preview, frame_label],
        )

    demo.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()