diff --git "a/ui.py" "b/ui.py" new file mode 100644--- /dev/null +++ "b/ui.py" @@ -0,0 +1,3422 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: I001 +import math +import os +import threading +import time +from typing import Optional + +from kimodo.constraints import load_constraints_lst, save_constraints_lst +from kimodo.exports.bvh import motion_to_bvh_bytes, save_motion_bvh +from kimodo.exports.motion_io import ( + amass_npz_to_bytes, + g1_csv_to_bytes, + kimodo_npz_to_bytes, + load_motion_file, + save_kimodo_npz, +) +from kimodo.model.registry import kimodo_short_key_for_skeleton_dataset, registry_skeleton_for_joint_count +from kimodo.tools import to_torch +from kimodo.viz import viser_utils +from kimodo.viz.viser_utils import GuiElements +import numpy as np +import torch +import viser +from viser._timeline_api import PROMPT_COLORS + +from . import generation +from ._qwen_prompts import call_qwen_for_prompts +from .config import ( + DEFAULT_CUR_DURATION, + DEMO_UI_INSTRUCTIONS_TAB_MD, + get_datasets, + get_model_info, + get_models_for_dataset_skeleton, + get_skeleton_display_name, + get_skeleton_display_names_for_dataset, + get_skeleton_key_from_display_name, + get_short_key_from_display_name, + HF_MODE, + INIT_POSTPROCESSING, + MODEL_NAMES, + NB_TRANSITION_FRAMES, + SHOW_TRANSITION_PARAMS, +) +from .state import ClientSession +from kimodo.skeleton import G1Skeleton34, SOMASkeleton30, SOMASkeleton77 + + +QWEN_EXAMPLE_NAME = "09_qwen_agentic_actions" +QWEN_EXAMPLE_LEGACY_NAME = "09_qwen_agentic_10_actions" + + +def extract_intervals_and_singles(t: torch.Tensor): + intervals = [] + intervals_indices = [] + single_frames = [] + single_frames_indices = [] + + start_idx = 0 + + for i in range(1, len(t) + 1): + # End of run if: + # - end of tensor + # - non-consecutive value + if i == len(t) or t[i] != t[i - 1] + 1: + run_length = i - start_idx + + if run_length >= 2: + intervals.append((int(t[start_idx]), int(t[i - 1]))) + intervals_indices.append((start_idx, i - 1)) + else: + single_frames.append(int(t[start_idx])) + single_frames_indices.append(start_idx) + + start_idx = i + + return intervals, intervals_indices, single_frames, single_frames_indices + + +def create_gui( + demo, + client: viser.ClientHandle, + model_name: str, + model_fps: float, +): + """Create GUI elements for a specific client.""" + client_id = client.client_id + + def get_active_session(event_client: viser.ClientHandle | None): + if event_client is None: + return None + if not demo.client_active(event_client.client_id): + return None + return demo.client_sessions[event_client.client_id] + + def build_timeline_tracks(): + timeline = client.timeline + demo.set_timeline_defaults(timeline, model_fps) + timeline.set_visible(True) + timeline.set_current_frame(0) + + timeline_tracks = {} + fullbody_id = timeline.add_track( + "Full-Body", + track_type="keyframe", + color=(219, 148, 86), + height_scale=0.5, + ) + timeline_tracks[fullbody_id] = { + "name": "Full-Body", + "track_type": "keyframe", + "color": (219, 148, 86), + "height_scale": 0.5, + } + + root2d_id = timeline.add_track( + "2D Root", + track_type="keyframe", + color=(150, 100, 200), + height_scale=0.5, + ) + timeline_tracks[root2d_id] = { + "name": "2D Root", + "track_type": "keyframe", + "color": (150, 100, 200), + "height_scale": 0.5, + } + lefthand_id = timeline.add_track( + "Left Hand", + track_type="keyframe", + color=(100, 200, 150), + height_scale=0.5, + ) + timeline_tracks[lefthand_id] = { + "name": "Left Hand", + "track_type": "keyframe", + "color": (100, 200, 150), + "height_scale": 0.5, + } + righthand_id = timeline.add_track( + "Right Hand", + track_type="keyframe", + color=(200, 100, 150), + height_scale=0.5, + ) + timeline_tracks[righthand_id] = { + "name": "Right Hand", + "track_type": "keyframe", + "color": (200, 100, 150), + "height_scale": 0.5, + } + leftfoot_id = timeline.add_track( + "Left Foot", + track_type="keyframe", + color=(219, 148, 86), + height_scale=0.5, + ) + timeline_tracks[leftfoot_id] = { + "name": "Left Foot", + "track_type": "keyframe", + "color": (219, 148, 86), + "height_scale": 0.5, + } + rightfoot_id = timeline.add_track( + "Right Foot", + track_type="keyframe", + color=(150, 100, 200), + height_scale=0.5, + ) + timeline_tracks[rightfoot_id] = { + "name": "Right Foot", + "track_type": "keyframe", + "color": (150, 100, 200), + "height_scale": 0.5, + } + return timeline, timeline_tracks + + timeline, timeline_tracks = build_timeline_tracks() + # These handles are part of GuiElements, but the demo currently uses timeline + buttons + # embedded in the Viser UI instead of custom controls. + gui_play_pause_button = None + gui_next_frame_button = None + gui_prev_frame_button = None + gui_timeline = None + gui_duration_slider = None + + # now other gui elements + tab_group = client.gui.add_tab_group() + + # + # Playback and Motion generation controls + # + with tab_group.add_tab("Generate", viser.Icon.WALK): + with client.gui.add_folder("Model Selection", expand_by_default=True): + info = get_model_info(model_name) + if info is None: + info = get_model_info(next(iter(MODEL_NAMES))) + + def get_allowed_skeleton_labels(dataset_ui_label: str) -> list[str]: + labels = get_skeleton_display_names_for_dataset(dataset_ui_label, family="Kimodo") + if HF_MODE: + labels = [label for label in labels if get_skeleton_key_from_display_name(label) != "SMPLX"] + return labels + + dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label + datasets = ["Rigplay"] if HF_MODE else get_datasets(family="Kimodo") + skeleton_labels = get_allowed_skeleton_labels(dataset_ui_label) + initial_skeleton_label = get_skeleton_display_name(info.skeleton) + if initial_skeleton_label not in skeleton_labels and skeleton_labels: + initial_skeleton_label = skeleton_labels[0] + initial_skeleton_key = ( + get_skeleton_key_from_display_name(initial_skeleton_label) if skeleton_labels else None + ) + models_for_pair = ( + get_models_for_dataset_skeleton(dataset_ui_label, initial_skeleton_key, family="Kimodo") + if initial_skeleton_key is not None + else [] + ) + version_options = [m.display_name for m in models_for_pair] + initial_version = ( + info.display_name + if info.display_name in version_options + else (version_options[0] if version_options else "") + ) + gui_dataset_selector = client.gui.add_dropdown( + "Training dataset", + options=datasets, + initial_value=dataset_ui_label, + visible=not HF_MODE, + ) + gui_skeleton_selector = client.gui.add_dropdown( + "Model" if HF_MODE else "Skeleton", + options=skeleton_labels, + initial_value=initial_skeleton_label, + ) + gui_version_selector = client.gui.add_dropdown( + "Version", + options=version_options, + initial_value=initial_version, + ) + gui_version_selector.visible = len(models_for_pair) > 1 + gui_model_display = client.gui.add_markdown( + content=f"**Model:** {initial_version}", + ) + gui_load_model_button = client.gui.add_button( + "Load model", + hint="Load the selected model (dataset, skeleton, version).", + ) + + class ModelSelectorHandle: + """Wrapper so session and callbacks can treat three dropdowns as one.""" + + def __init__(self): + self._dataset = gui_dataset_selector + self._skeleton = gui_skeleton_selector + self._version = gui_version_selector + self._display = gui_model_display + + @property + def value(self) -> str: + return get_short_key_from_display_name(self._version.value) or "" + + def set_from_short_key(self, short_key: str) -> None: + info = get_model_info(short_key) + if info is None: + return + dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label + self._dataset.value = dataset_ui_label + self._skeleton.options = get_allowed_skeleton_labels(dataset_ui_label) + skeleton_label = get_skeleton_display_name(info.skeleton) + if skeleton_label not in self._skeleton.options and self._skeleton.options: + skeleton_label = self._skeleton.options[0] + self._skeleton.value = skeleton_label + skeleton_key = get_skeleton_key_from_display_name(skeleton_label) + if skeleton_key is None: + return + models = get_models_for_dataset_skeleton(dataset_ui_label, skeleton_key, family="Kimodo") + self._version.options = [m.display_name for m in models] + self._version.value = ( + info.display_name if info.display_name in self._version.options else self._version.options[0] + ) + self._version.visible = len(models) > 1 + self._display.content = f"**Model:** {self._version.value}" + + gui_model_selector = ModelSelectorHandle() + + with client.gui.add_folder("Examples", expand_by_default=True): + examples_base_dir = demo.get_examples_base_dir(model_name, absolute=True) + example_dict = viser_utils.load_example_cases(examples_base_dir) + example_names = list(example_dict.keys()) + example_names.append(QWEN_EXAMPLE_NAME) + gui_examples_dropdown = client.gui.add_dropdown( + "Example", + options=example_names, + initial_value=example_names[0], + ) + gui_load_example_button = client.gui.add_button( + "Load Example", + hint="Load the selected example (or Qwen agentic prompt plan).", + disabled=False, + ) + + def update_examples_dropdown( + new_example_dict: dict[str, str], + keep_selection: bool = True, + ) -> None: + example_names_local = list(new_example_dict.keys()) + if QWEN_EXAMPLE_NAME not in example_names_local: + example_names_local.append(QWEN_EXAMPLE_NAME) + if QWEN_EXAMPLE_LEGACY_NAME not in example_names_local: + example_names_local.append(QWEN_EXAMPLE_LEGACY_NAME) + gui_examples_dropdown.options = example_names_local + if keep_selection and gui_examples_dropdown.value in example_names_local: + return + gui_examples_dropdown.value = example_names_local[0] + + with client.gui.add_folder("Generate", expand_by_default=True): + gui_duration = client.gui.add_markdown(content=f"Total duration: {DEFAULT_CUR_DURATION:.1f} (sec)") + + def update_duration_gui(duration): + gui_duration.content = f"Total duration: {duration:.1f} (sec)" + + def compute_prompt_num_frames(prompt_values): + """Convert timeline prompt bounds to per-prompt frame counts. + + Convention in this demo: + - All prompts except the last are treated as [start_frame, end_frame) + (end is exclusive). + - The last prompt is treated as [start_frame, end_frame] (end is inclusive). + - This assumes the prompts values are sorted by start_frame. + """ + if len(prompt_values) == 0: + return [] + num_frames = [] + for i, x in enumerate(prompt_values): + cur = x.end_frame - x.start_frame + if i == len(prompt_values) - 1: + cur += 1 + num_frames.append(cur) + return num_frames + + def update_duration_auto(): + session = demo.client_sessions[client_id] + prompt_values = sorted( + [x for x in timeline._prompts.values()], + key=lambda x: x.start_frame, + ) + num_frames = compute_prompt_num_frames(prompt_values) + total_nb_frames = sum(num_frames) + cur_duration = total_nb_frames / session.model_fps + set_new_duration(client_id, cur_duration) + update_duration_gui(cur_duration) + + gui_num_samples_slider = client.gui.add_slider( + "Num Samples", + min=1, + max=10, + step=1, + initial_value=1, + visible=not HF_MODE, + ) + + gui_use_soma_layer_checkbox = client.gui.add_checkbox( + "SOMA layer", + initial_value=False, + visible="soma" in (model_name or ""), + ) + + with client.gui.add_folder("Model Parameters", expand_by_default=False): + gui_seed = client.gui.add_number("Seed", initial_value=42) + + with client.gui.add_folder("Diffusion", expand_by_default=False): + gui_diffusion_steps_slider = client.gui.add_slider( + "Denoising Steps", + min=2, + max=1000, + step=10, + initial_value=100, + ) + with client.gui.add_folder("Classifier-Free Guidance", expand_by_default=False): + gui_cfg_checkbox = client.gui.add_checkbox( + "Enable", + initial_value=True, + visible=True, + ) + + gui_cfg_text_weight_slider = client.gui.add_slider( + "Text Weight", + min=0.0, + max=5.0, + step=0.1, + initial_value=2.0, + visible=True, + ) + gui_cfg_constraint_weight_slider = client.gui.add_slider( + "Constraint Weight", + min=0.0, + max=5.0, + step=0.1, + initial_value=2.0, + visible=True, + ) + with client.gui.add_folder( + "Transitions", + expand_by_default=False, + visible=SHOW_TRANSITION_PARAMS, + ): + gui_num_transition_frames_slider = client.gui.add_slider( + "Transition frames", + min=1, + max=10, + step=1, + initial_value=NB_TRANSITION_FRAMES, + visible=True, + ) + gui_share_transition_checkbox = client.gui.add_checkbox( # noqa + "Override previous frames", + initial_value=False, + visible=True, + ) + gui_percentage_transition_sharing_slider = client.gui.add_slider( + "Percentage overriding frames", + min=0, + max=30, + step=1, + initial_value=10, + visible=True, + ) + + @gui_share_transition_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + if get_active_session(event.client) is None: + return + # disable the slider if sharing transition is False + gui_percentage_transition_sharing_slider.visible = gui_share_transition_checkbox.value + + with client.gui.add_folder("Post Processing", expand_by_default=False): + _model_name = model_name or "" + _postprocess_visible = "g1" not in _model_name + gui_postprocess_checkbox = client.gui.add_checkbox( + "Enable", + initial_value=INIT_POSTPROCESSING, + hint="Apply motion post-processing (not available for G1)", + visible=_postprocess_visible, + ) + gui_root_margin = client.gui.add_number( + "Root Margin", + min=0.0, + # max=0.5, + step=0.01, + initial_value=0.04, + hint="Margin for root position (meters). Lower values pin root closer to target.", + visible=INIT_POSTPROCESSING and _postprocess_visible, + ) + + @gui_postprocess_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + if get_active_session(event.client) is None: + return + # disable the slider if sharing transition is False + gui_root_margin.visible = gui_postprocess_checkbox.value + + gui_real_robot_rotations_checkbox = client.gui.add_checkbox( + "Real robot rotations", + initial_value=False, + hint="Project joint rotations to G1 real robot DoF (1-DoF per joint) and clamp to axis limits from the MuJoCo XML.", + visible="g1" in _model_name, + ) + + with client.gui.add_folder("Qwen Auto-Prompts", expand_by_default=True): + gui_qwen_scene = client.gui.add_text( + "Scene context", + initial_value="A lone figure moving through an empty plaza", + hint="Describe the scene or character context for Qwen to generate motion prompts.", + ) + gui_qwen_actions = client.gui.add_slider( + "Target actions", + min=1, + max=10, + step=1, + initial_value=6, + hint="Number of prompt segments to place on the timeline.", + ) + gui_qwen_auto_run = client.gui.add_checkbox( + "Auto-run Generate after loading prompts", + initial_value=False, + ) + gui_qwen_status = client.gui.add_markdown(content="") + gui_qwen_button = client.gui.add_button("Fill Timeline via Qwen", color="blue") + + @gui_qwen_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + def _run_qwen_fill_and_maybe_generate() -> None: + gui_qwen_button.disabled = True + gui_qwen_status.content = "⏳ Calling Qwen…" + target_actions = int(gui_qwen_actions.value) + history: list[str] = [] + all_texts: list[str] = [] + all_durations: list[float] = [] + rounds = 0 + + # Keep requesting batches until we fill target actions (max 10). + while len(all_texts) < target_actions and rounds < 6: + remaining = target_actions - len(all_texts) + batch, history = call_qwen_for_prompts( + scene=gui_qwen_scene.value, + history=history, + requested_actions=min(5, remaining), + ) + batch_texts = batch.get("texts", []) + batch_durations = batch.get("durations", []) + for t, d in zip(batch_texts, batch_durations): + if len(all_texts) >= target_actions: + break + all_texts.append(t) + all_durations.append(float(d)) + rounds += 1 + + if len(all_texts) == 0: + gui_qwen_status.content = "⚠ Qwen did not return usable prompts" + gui_qwen_button.disabled = False + return + + fps = session.model_fps + event_client.timeline.clear_prompts() + frame_cursor = 0 + for i, (txt, dur) in enumerate(zip(all_texts, all_durations)): + n = max(1, int(round(dur * fps))) + start_f = frame_cursor + end_f = frame_cursor + n if i < len(all_texts) - 1 else frame_cursor + n - 1 + color = PROMPT_COLORS[i % len(PROMPT_COLORS)] + event_client.timeline.add_prompt(txt, start_f, end_f, color=color) + frame_cursor += n + + # Keep timeline readable by expanding zoom to planned sequence length. + target_visible_frames = int(math.ceil(1.10 * frame_cursor)) + event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames)) + + update_duration_auto() + gui_qwen_status.content = f"✓ Loaded {len(all_texts)} Qwen prompt segments" + + if gui_qwen_auto_run.value: + gui_qwen_status.content = f"✓ Loaded {len(all_texts)} prompts, generating motion…" + try: + demo.generate( + event_client, + all_texts, + [max(1, int(round(d * fps))) for d in all_durations], + gui_num_samples_slider.value, + gui_seed.value, + gui_diffusion_steps_slider.value, + cfg_weight=[ + gui_cfg_text_weight_slider.value, + gui_cfg_constraint_weight_slider.value, + ], + cfg_type="separated" if gui_cfg_checkbox.value else "nocfg", + postprocess_parameters={ + "post_processing": gui_postprocess_checkbox.value, + "root_margin": gui_root_margin.value, + }, + transitions_parameters={ + "num_transition_frames": gui_num_transition_frames_slider.value, + "share_transition": gui_share_transition_checkbox.value, + "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100, + }, + real_robot_rotations=gui_real_robot_rotations_checkbox.value, + ) + gui_qwen_status.content = f"✓ Generated motion from {len(all_texts)} Qwen actions" + except Exception as exc: + gui_qwen_status.content = f"⚠ Generate error: {exc}" + + gui_qwen_button.disabled = False + + threading.Thread(target=_run_qwen_fill_and_maybe_generate, daemon=True).start() + + gui_generate_button = client.gui.add_button("Generate", color="green") + with client.gui.add_folder("Constraints", expand_by_default=False): + gui_gizmo_space_dropdown = client.gui.add_dropdown( + "Gizmo space", + ("Local", "World"), + initial_value="Local", + visible="g1" not in _model_name, + ) + gui_edit_constraint_button = client.gui.add_button("Enter Editing Mode") + gui_snap_to_constraint_button = client.gui.add_button( + "Snap to Constraint", + disabled=True, + ) + gui_reset_constraint_button = client.gui.add_button( + "Reset Constraint", + disabled=True, + ) + gui_undo_drag_button = client.gui.add_button( + "Undo Move", + disabled=True, + ) + + with client.gui.add_folder("Root 2D Options", expand_by_default=True): + gui_dense_path_checkbox = client.gui.add_checkbox( + "Make Smooth Path", + initial_value=False, + visible=True, + ) + + gui_show_only_current_constraint_checkbox = client.gui.add_checkbox( + "Show only Current", + initial_value=False, + hint="Show only constraint overlays at the current frame; uncheck to show all.", + ) + + def apply_constraint_overlay_visibility(session: ClientSession) -> None: + demo._apply_constraint_overlay_visibility(session) + + @gui_show_only_current_constraint_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + session = get_active_session(event.client) + if session is None: + return + session.show_only_current_constraint = gui_show_only_current_constraint_checkbox.value + apply_constraint_overlay_visibility(session) + + gui_clear_all_constraints_button = client.gui.add_button( + "Clear All Constraints", + color="red", + ) + + def has_constraint_at_frame(session: ClientSession, frame_idx: int) -> bool: + for constraint_name in ["Full-Body", "End-Effectors", "2D Root"]: + constraint = session.constraints.get(constraint_name) + if constraint is None: + continue + if frame_idx in constraint.keyframes: + return True + return False + + def update_snap_to_constraint_button(session: ClientSession) -> None: + gui_snap_to_constraint_button.disabled = not has_constraint_at_frame(session, session.frame_idx) + + def ensure_edit_snapshot(session: ClientSession, motion, frame_idx: int) -> None: + if session.edit_mode_snapshot is None: + session.edit_mode_snapshot = {} + if frame_idx in session.edit_mode_snapshot: + return + session.edit_mode_snapshot[frame_idx] = { + "joints_pos": motion.get_joints_pos(frame_idx), + "joints_rot": motion.get_joints_rot(frame_idx), + } + + def _update_dense_path(motion, session): + constraint_info = session.constraints["2D Root"].get_constraint_info() + + if len(constraint_info["frame_idx"]) > 0: + min_root_frame = min(constraint_info["frame_idx"]) + max_root_frame = max(constraint_info["frame_idx"]) + motion.set_projected_root_pos_path( + constraint_info["root_pos"][:, [0, 2]], + min_frame_idx=min_root_frame, + max_frame_idx=max_root_frame, + ) + + # Delay (ms) after last keyframe/interval move before updating path = "on release". + DENSE_PATH_AFTER_RELEASE_MS = 300 + + def _schedule_dense_path_after_release(session): + """Schedule a single path update to run after user stops dragging.""" + if "2D Root" not in session.constraints or not session.constraints["2D Root"].dense_path: + return + tdata = session.timeline_data + if tdata.get("dense_path_after_release_timer"): + tdata["dense_path_after_release_timer"].cancel() + delay = DENSE_PATH_AFTER_RELEASE_MS / 1000.0 + + def run(): + if not demo.client_active(client_id): + return + sess = demo.client_sessions[client_id] + tdata["dense_path_after_release_timer"] = None + if "2D Root" not in sess.constraints or not sess.constraints["2D Root"].dense_path: + return + mot = list(sess.motions.values())[0] + _update_dense_path(mot, sess) + + t = threading.Timer(delay, run) + tdata["dense_path_after_release_timer"] = t + t.start() + + @gui_dense_path_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + session = get_active_session(event.client) + if session is None: + return + + if gui_dense_path_checkbox.value: + # Make sure 0 and max_frame_idx keyframes are added to the constraint + # since dense path should cover full duration for best model performance + root_2d_track = session.timeline_data["tracks_ids"]["2D Root"] + + # add a locked keyframe at 0 + start_keyframe_id = client.timeline.add_locked_keyframe( # noqa + root_2d_track, + 0, + opacity=0.0, + ) + session.timeline_data["keyframes"][start_keyframe_id] = { + "frame": 0, + "track_id": root_2d_track, + "locked": True, + "opacity": 0.0, + "value": None, + } + add_constraint_callback( + start_keyframe_id, + "2D Root", + (0, 0), + verbose=False, + ) + + # add a locked keyframe at max_frame_idx + end_keyframe_id = client.timeline.add_locked_keyframe( + root_2d_track, + session.max_frame_idx, + opacity=0.0, + ) + session.timeline_data["keyframes"][end_keyframe_id] = { + "frame": session.max_frame_idx, + "track_id": root_2d_track, + "locked": True, + "opacity": 0.0, + "value": None, + } + add_constraint_callback( + end_keyframe_id, + "2D Root", + (session.max_frame_idx, session.max_frame_idx), + verbose=False, + ) + + # add a locked interval only for visual purposes + locked_interval = client.timeline.add_locked_interval( # noqa + root_2d_track, + start_frame=0, + end_frame=session.max_frame_idx, + ) + session.timeline_data["intervals"][locked_interval] = { + "track_id": root_2d_track, + "start_frame_idx": 0, + "end_frame_idx": session.max_frame_idx, + "locked": True, + "opacity": 0.3, + "value": None, + } + + session.constraints["2D Root"].set_dense_path(gui_dense_path_checkbox.value) + if session.constraints["2D Root"].dense_path: + # update the character motion to reflect the full path + # will be full length by construction, no need to specify min/max frame idx + motion = list(session.motions.values())[0] + _update_dense_path(motion, session) + + # remove locked interval and locked keyframes + if not gui_dense_path_checkbox.value: + # Get all locked keyframes + keyframes_to_remove = [] + for uuid, keyframe in client.timeline._keyframes.items(): + if keyframe.locked: + keyframes_to_remove.append(uuid) + _data = session.timeline_data["keyframes"][uuid] + remove_constraint_callback( + uuid, + constraint_type=session.timeline_data["tracks"][_data["track_id"]]["name"], + frame_range=(_data["frame"], _data["frame"]), + verbose=False, + ) + + intervals_to_remove = [] + # remove all locked intervals + for uuid, interval in client.timeline._intervals.items(): + if interval.locked: + intervals_to_remove.append(uuid) + + # removing keyframes and intervals + for uuid in keyframes_to_remove: + client.timeline.remove_keyframe(uuid) + + for uuid in intervals_to_remove: + client.timeline.remove_interval(uuid) + + apply_constraint_overlay_visibility(session) + + with client.gui.add_folder( + "Load/Save", + expand_by_default=False, + visible=not HF_MODE, + ): + with client.gui.add_folder("Motion", expand_by_default=False): + gui_save_motion_path_text = client.gui.add_text("Save Path", initial_value="output") + gui_save_motion_format_dropdown = client.gui.add_dropdown( + "Save Format", + options=( + ["NPZ", "CSV"] + if "g1" in model_name.lower() + else ["NPZ", "AMASS NPZ"] + if "smplx" in model_name.lower() + else ["NPZ", "BVH"] + ), + initial_value="NPZ", + ) + gui_save_motion_button = client.gui.add_button( + "Save Motion", + hint="Save the current motion (format + path above)", + ) + gui_load_motion_path_text = client.gui.add_text( + "Load Path", + initial_value="output.npz", + hint="SOMA .bvh, Kimodo or AMASS .npz, or G1 MuJoCo .csv", + ) + gui_load_motion_button = client.gui.add_button( + "Load Motion", + hint="Load the selected motion", + ) + with client.gui.add_folder("Constraints", expand_by_default=False): + gui_save_constraints_path_text = client.gui.add_text( + "Save Path", initial_value="output_constraints.json" + ) + gui_save_constraints_button = client.gui.add_button("Save Constraints") + gui_load_constraints_path_text = client.gui.add_text( + "Load Path", initial_value="output_constraints.json" + ) + gui_load_constraints_button = client.gui.add_button("Load Constraints") + with client.gui.add_folder("Example", expand_by_default=False): + gui_save_example_path_text = client.gui.add_text( + "Save Dir", + initial_value=os.path.join( + demo.get_examples_base_dir(model_name, absolute=True), + "custom_example_1", + ), + ) + gui_save_example_button = client.gui.add_button("Save Example") + gui_load_example_path_text = client.gui.add_text( + "Load Dir", + initial_value=os.path.join( + demo.get_examples_base_dir(model_name, absolute=True), + "custom_example_1", + ), + ) + gui_load_gt_checkbox = client.gui.add_checkbox( + "Load GT instead", + initial_value=False, + ) + gui_load_example_from_path_button = client.gui.add_button("Load Example") + + def _get_primary_motion(session: ClientSession): + return list(session.motions.values())[0] + + def _motion_to_numpy_dict(motion) -> dict[str, np.ndarray]: + joints_pos = motion.joints_pos.detach().cpu().numpy() + joints_rot = motion.joints_rot.detach().cpu().numpy() + joints_local_rot = motion.joints_local_rot.detach().cpu().numpy() + + if joints_pos.ndim != 3: + raise ValueError(f"Expected unbatched joints_pos with shape [T, J, 3], got {joints_pos.shape}") + if joints_rot.ndim != 4: + raise ValueError(f"Expected unbatched joints_rot with shape [T, J, 3, 3], got {joints_rot.shape}") + if joints_local_rot.ndim != 4: + raise ValueError( + "Expected unbatched joints_local_rot with shape " f"[T, J, 3, 3], got {joints_local_rot.shape}" + ) + + motion_data = { + "posed_joints": joints_pos, + "global_rot_mats": joints_rot, + "local_rot_mats": joints_local_rot, + "root_positions": joints_pos[:, motion.skeleton.root_idx, :], + } + if motion.foot_contacts is not None: + foot_contacts = motion.foot_contacts.detach().cpu().numpy() + if foot_contacts.ndim != 2: + raise ValueError( + f"Expected unbatched foot_contacts with shape [T, C], got {foot_contacts.shape}" + ) + motion_data["foot_contacts"] = foot_contacts + return motion_data + + def _coerce_save_path(raw_path: str, *, ext: str) -> str: + """Ensure the save path ends with the correct extension for the chosen format.""" + name = (raw_path or "").strip() + if name == "": + return f"output{ext}" + known_exts = (".npz", ".bvh", ".csv") + if name.lower().endswith(known_exts): + return os.path.splitext(name)[0] + ext + if os.path.splitext(name)[1] == "": + return name + ext + return name + + def save_motion(client, save_path, fmt): + session = demo.client_sessions[client.client_id] + motion = _get_primary_motion(session) + motion_data = _motion_to_numpy_dict(motion) + + if fmt == "BVH": + save_path = _coerce_save_path(save_path, ext=".bvh") + save_motion_bvh( + save_path, + motion.joints_local_rot, + motion.joints_pos[:, session.skeleton.root_idx, :], + skeleton=session.skeleton, + fps=float(session.model_fps), + ) + elif fmt == "CSV": + save_path = _coerce_save_path(save_path, ext=".csv") + data = g1_csv_to_bytes(motion_data, session.skeleton, demo.device) + with open(save_path, "wb") as f: + f.write(data) + elif fmt == "AMASS NPZ": + save_path = _coerce_save_path(save_path, ext=".npz") + data = amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps) + with open(save_path, "wb") as f: + f.write(data) + else: + save_path = _coerce_save_path(save_path, ext=".npz") + save_kimodo_npz(save_path, motion_data) + return save_path + + @gui_save_motion_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + if get_active_session(event_client) is None: + return + + raw_path = gui_save_motion_path_text.value + fmt = str(gui_save_motion_format_dropdown.value).upper() + try: + saved_path = save_motion(event_client, raw_path, fmt) + event_client.add_notification( + title="Motion saved!", + body=f"Saved motion to {saved_path}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to save motion!", + body=str(e), + auto_close_seconds=5.0, + color="red", + ) + + def load_motion(client, load_path): + session = demo.client_sessions[client.client_id] + + fps_arg = session.model_fps if session.model_fps and session.model_fps > 0 else None + motion_dict, num_joints_motion = load_motion_file(load_path, target_fps=fps_arg) + + target_skel = registry_skeleton_for_joint_count(num_joints_motion) + current_info = get_model_info(session.model_name) + current_skel = current_info.skeleton if current_info is not None else None + + if current_skel != target_skel: + dataset = current_info.dataset if current_info is not None else "RP" + new_key = kimodo_short_key_for_skeleton_dataset(target_skel, dataset) + if new_key is None: + new_key = kimodo_short_key_for_skeleton_dataset(target_skel, "RP") + if new_key is None: + raise ValueError( + f"No Kimodo model found for skeleton {target_skel} (motion has J={num_joints_motion})." + ) + if new_key != session.model_name: + gui_model_selector.set_from_short_key(new_key) + apply_model_selection(new_key) + _update_visibility_for_loaded_model(new_key) + client.add_notification( + title="Model switched", + body=f"Switched to {new_key} to match loaded motion (J={num_joints_motion}).", + auto_close_seconds=5.0, + color="blue", + ) + session = demo.client_sessions[client.client_id] + + joints_pos = motion_dict["posed_joints"].to(device=demo.device, dtype=torch.float32) + joints_rot = motion_dict["global_rot_mats"].to(device=demo.device, dtype=torch.float32) + foot_contacts = motion_dict.get("foot_contacts") + if foot_contacts is not None: + foot_contacts = foot_contacts.to(device=demo.device, dtype=torch.float32) + + # Support both batched [B, T, J, 3] and unbatched [T, J, 3]; take first sample if batched + if joints_pos.ndim == 4: + joints_pos = joints_pos[0] + if joints_rot.ndim == 5: + joints_rot = joints_rot[0] + if foot_contacts is not None and foot_contacts.ndim == 3: + foot_contacts = foot_contacts[0] + + # Motion must match the current model's skeleton after auto-switch + num_joints_loaded = joints_pos.shape[1] + num_joints_skeleton = session.skeleton.nbjoints + if num_joints_loaded != num_joints_skeleton: + # Backward compat: expand 30-joint SOMA motion to 77 + if ( + num_joints_loaded == 30 + and num_joints_skeleton == 77 + and isinstance(session.skeleton, SOMASkeleton77) + ): + from kimodo.skeleton import global_rots_to_local_rots + + skel30 = SOMASkeleton30().to(demo.device) + if "local_rot_mats" in motion_dict: + local_rot_30 = motion_dict["local_rot_mats"].to(device=demo.device, dtype=torch.float32) + if local_rot_30.ndim == 4: + local_rot_30 = local_rot_30[0] + else: + local_rot_30 = global_rots_to_local_rots(joints_rot, skel30) + local_rot_77 = skel30.to_SOMASkeleton77(local_rot_30) + root_positions = joints_pos[:, skel30.root_idx, :] + joints_rot, joints_pos, _ = session.skeleton.fk(local_rot_77, root_positions) + + if foot_contacts is not None and foot_contacts.shape[-1] == 4: + foot_contacts = torch.cat( + [ + foot_contacts[..., :2], + foot_contacts[..., 1:2], + foot_contacts[..., 2:4], + foot_contacts[..., 3:4], + ], + dim=-1, + ) + else: + raise ValueError( + f"The loaded motion has {num_joints_loaded} joints but the current model " + f"({session.model_name}) has {num_joints_skeleton} joints. " + "Load a motion generated with the same skeleton, or switch the model to match the motion." + ) + elif joints_rot.shape[1] != num_joints_skeleton: + raise ValueError( + f"Rotation data has {joints_rot.shape[1]} joints but the current model has " + f"{num_joints_skeleton} joints. The NPZ may be corrupted or from a different skeleton." + ) + + # Apply G1 real robot projection (1-DoF per joint + axis limits) if enabled. + if ( + "g1" in session.model_name + and isinstance(session.skeleton, G1Skeleton34) + and gui_real_robot_rotations_checkbox.value + ): + joints_pos, joints_rot = generation.apply_g1_real_robot_projection( + session.skeleton, joints_pos, joints_rot + ) + + # Update duration and frame range based on loaded motion + num_frames = joints_pos.shape[0] + duration = num_frames / session.model_fps + + # Update GUI elements + session.cur_duration = duration + session.max_frame_idx = num_frames - 1 + + # Clear existing motions and add the loaded one + demo.clear_motions(client.client_id) + demo.add_character_motion( + client, + session.skeleton, + joints_pos, + joints_rot, + foot_contacts, + ) + + # Reset to frame 0 + demo.set_frame(client.client_id, 0) + + @gui_load_motion_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + load_path = gui_load_motion_path_text.value + loading_notif = event_client.add_notification( + title="Loading motion...", + body=f"Loading from {load_path}", + loading=True, + with_close_button=False, + auto_close_seconds=None, + ) + try: + load_motion(event_client, load_path) + + loading_notif.title = "Motion loaded!" + loading_notif.body = f"Loaded motion from {load_path} ({session.max_frame_idx + 1} frames, {session.cur_duration:.2f}s)" + loading_notif.loading = False + loading_notif.with_close_button = True + loading_notif.auto_close_seconds = 5.0 + loading_notif.color = "green" + except Exception as e: + import traceback + + traceback.print_exc() + loading_notif.title = "Failed to load motion!" + loading_notif.body = str(e) + loading_notif.loading = False + loading_notif.with_close_button = True + loading_notif.auto_close_seconds = 10.0 + loading_notif.color = "red" + + def save_constraints(client, save_path): + session = demo.client_sessions[client.client_id] + # Keep save behavior aligned with demo frame convention: + # valid frame indices are [0, max_frame_idx], so count is +1. + num_frames = session.max_frame_idx + 1 + model_bundle = demo.load_model(session.model_name) + constraints_lst = demo.compute_model_constraints_lst(session, model_bundle, num_frames) + save_constraints_lst(save_path, constraints_lst) + + @gui_save_constraints_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + if get_active_session(event_client) is None: + return + + try: + save_path = gui_save_constraints_path_text.value + save_constraints(event_client, save_path) + event_client.add_notification( + title="Constraints saved!", + body=f"Saved constraints to {save_path}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to save constraints!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + def load_constraints(client, load_path): + session = demo.client_sessions[client.client_id] + constraints_lst = load_constraints_lst(load_path, skeleton=session.skeleton) + + # Clear existing constraints first + with session.timeline_data["keyframe_update_lock"]: + for constraint in list(session.constraints.values()): + constraint.clear() + client.timeline.clear_keyframes() + client.timeline.clear_intervals() + + # Add loaded constraints to the session + # We need to directly add constraint data, not read from current motion + device = demo.device + for constraint_obj in constraints_lst: + constraint_type = constraint_obj.name + + # decompose the frame indices into intervals or single keyframes + frame_indices = constraint_obj.frame_indices + ( + intervals, + intervals_indices, + single_frames, + single_frames_indices, + ) = extract_intervals_and_singles(frame_indices) + + load_targets: list[dict] = [] + root_pos = None + + if constraint_type == "root2d": + # smooth_root_2d is [T, 2] (x, z), convert to [T, 3] (x, 0, z) + num_frames = constraint_obj.smooth_root_2d.shape[0] + root_pos = torch.zeros(num_frames, 3, device=device) + root_pos[:, 0] = constraint_obj.smooth_root_2d[:, 0] + root_pos[:, 2] = constraint_obj.smooth_root_2d[:, 1] + load_targets = [ + { + "track_name": "2D Root", + "constraint_track": session.constraints["2D Root"], + } + ] + elif constraint_type == "fullbody": + load_targets = [ + { + "track_name": "Full-Body", + "constraint_track": session.constraints["Full-Body"], + } + ] + elif constraint_type in { + "left-hand", + "right-hand", + "left-foot", + "right-foot", + }: + track_name = { + "left-hand": "Left Hand", + "right-hand": "Right Hand", + "left-foot": "Left Foot", + "right-foot": "Right Foot", + }[constraint_type] + load_targets = [ + { + "track_name": track_name, + "constraint_track": session.constraints["End-Effectors"], + "joint_names": constraint_obj.joint_names, + "end_effector_type": constraint_type, + } + ] + elif constraint_type in {"end-effector", "end-effectors"}: + # Backward-compatible loader: + # split a generic end-effector constraint into per-limb timeline tracks. + joint_names_set = set(constraint_obj.joint_names) + for jname, track_name, eff_type in [ + ("LeftHand", "Left Hand", "left-hand"), + ("RightHand", "Right Hand", "right-hand"), + ("LeftFoot", "Left Foot", "left-foot"), + ("RightFoot", "Right Foot", "right-foot"), + ]: + if jname not in joint_names_set: + continue + target_joint_names = [jname] + if "Hips" in joint_names_set: + target_joint_names.append("Hips") + load_targets.append( + { + "track_name": track_name, + "constraint_track": session.constraints["End-Effectors"], + "joint_names": target_joint_names, + "end_effector_type": eff_type, + } + ) + if not load_targets: + raise KeyError( + "No recognized end-effector joint in constraint " + f"joint_names={constraint_obj.joint_names}" + ) + else: + raise KeyError(f"Unsupported constraint type in loader: {constraint_type}") + + for target in load_targets: + track_id = session.timeline_data["tracks_ids"][target["track_name"]] + constraint_track = target["constraint_track"] + + # add intervals + for (start_idx, end_idx), (start_idx_t, end_idx_t) in zip(intervals, intervals_indices): + # Add to timeline + interval_id = client.timeline.add_interval(track_id, start_idx, end_idx) + session.timeline_data["intervals"][interval_id] = { + "track_id": track_id, + "start_frame_idx": start_idx, + "end_frame_idx": end_idx, + "locked": False, + "opacity": 1.0, + "value": None, + } + if constraint_type == "root2d": + constraint_track.add_interval( + interval_id, + start_idx, + end_idx, + root_pos[start_idx_t : end_idx_t + 1], + ) + elif constraint_type == "fullbody": + constraint_track.add_interval( + interval_id, + start_idx, + end_idx, + constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1], + constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1], + ) + else: + constraint_track.add_interval( + interval_id, + start_idx, + end_idx, + constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1], + constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1], + target["joint_names"], + target["end_effector_type"], + ) + + # add keyframes + for frame, frame_t in zip(single_frames, single_frames_indices): + # Add to timeline + keyframe_id = client.timeline.add_keyframe(track_id, frame) + session.timeline_data["keyframes"][keyframe_id] = { + "track_id": track_id, + "frame": frame, + "locked": False, + "opacity": 1.0, + "value": None, + } + if constraint_type == "root2d": + constraint_track.add_keyframe( + keyframe_id, + frame, + root_pos[frame_t], + ) + elif constraint_type == "fullbody": + constraint_track.add_keyframe( + keyframe_id, + frame, + constraint_obj.global_joints_positions[frame_t], + constraint_obj.global_joints_rots[frame_t], + ) + else: + constraint_track.add_keyframe( + keyframe_id, + frame, + constraint_obj.global_joints_positions[frame_t], + constraint_obj.global_joints_rots[frame_t], + target["joint_names"], + target["end_effector_type"], + ) + + @gui_load_constraints_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + if get_active_session(event_client) is None: + return + + try: + load_path = gui_load_constraints_path_text.value + load_constraints(event_client, load_path) + session = demo.client_sessions[event_client.client_id] + apply_constraint_overlay_visibility(session) + + event_client.add_notification( + title="Constraints loaded!", + body=f"Loaded constraints from {load_path}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to load constraints!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + with client.gui.add_folder("Exports", expand_by_default=False): + with client.gui.add_folder("Screenshot", expand_by_default=False, visible=not HF_MODE): + gui_screenshot_path_text = client.gui.add_text( + "Save Path", + initial_value="render.png", + hint="Filename for the screenshot (PNG).", + ) + gui_screenshot_button = client.gui.add_button( + "Download Screenshot", + hint="Capture the current canvas and download a PNG.", + ) + with client.gui.add_folder("Video", expand_by_default=False, visible=not HF_MODE): + gui_video_path_text = client.gui.add_text( + "Save Path", + initial_value="render.mp4", + hint="Filename for the video (MP4).", + ) + gui_video_button = client.gui.add_button( + "Download Video", + hint="Render every frame and download as MP4.", + ) + with client.gui.add_folder("Motion", expand_by_default=True): + gui_download_name_text = client.gui.add_text( + "Name", + initial_value="output", + hint="Base filename to save as (extension will be added based on format if omitted).", + ) + gui_download_format_dropdown = client.gui.add_dropdown( + "Format", + options=( + ["NPZ", "CSV"] + if "g1" in model_name.lower() + else ["NPZ", "AMASS NPZ"] + if "smplx" in model_name.lower() + else ["NPZ", "BVH"] + ), + initial_value="NPZ", + ) + gui_download_button = client.gui.add_button( + "Download", + hint="Download the current motion (format + name above).", + ) + + def _download_bytes_to_browser( + event_client: viser.ClientHandle, + *, + data: bytes, + filename: str, + mime_type: str = "application/octet-stream", + ) -> None: + """Trigger a browser download for an in-memory byte payload. + + Important: this intentionally does NOT use `showSaveFilePicker()` to avoid + Chrome/Edge's file-write permission prompt ("this site can see edits you make"). + If you want "always ask where to save", configure your browser download settings. + """ + import base64 + import json + + # Base64 is the most robust way to move binary over our websocket JS channel. + b64 = base64.b64encode(data).decode("ascii") + js = f""" +(() => {{ + const filename = {json.dumps(filename)}; + const mimeType = {json.dumps(mime_type)}; + const b64 = {json.dumps(b64)}; + + // Decode base64 -> Uint8Array. + const binStr = atob(b64); + const bytes = new Uint8Array(binStr.length); + for (let i = 0; i < binStr.length; i++) bytes[i] = binStr.charCodeAt(i); + const blob = new Blob([bytes], {{ type: mimeType }}); + + // Standard browser download behavior. + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); +}})(); +""" + # Reuse viser’s JS execution mechanism (used for Plotly setup). + from viser import _messages as _viser_messages + + event_client.gui._websock_interface.queue_message( # type: ignore[attr-defined] + _viser_messages.RunJavascriptMessage(source=js) + ) + + def _motion_to_npz_bytes(motion) -> bytes: + motion_data = _motion_to_numpy_dict(motion) + return kimodo_npz_to_bytes(motion_data) + + def _motion_to_csv_bytes(motion, session: ClientSession) -> bytes: + motion_data = _motion_to_numpy_dict(motion) + return g1_csv_to_bytes(motion_data, session.skeleton, demo.device) + + def _motion_to_amass_npz_bytes(motion, session: ClientSession) -> bytes: + motion_data = _motion_to_numpy_dict(motion) + return amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps) + + def _get_motion_export_formats(loaded_model_name: str) -> list[str]: + model_name_lower = (loaded_model_name or "").lower() + if "g1" in model_name_lower: + return ["NPZ", "CSV"] + if "smplx" in model_name_lower: + return ["NPZ", "AMASS NPZ"] + return ["NPZ", "BVH"] + + def _update_format_dropdown(dropdown, loaded_model_name: str) -> None: + new_options = _get_motion_export_formats(loaded_model_name) + current_value = str(dropdown.value) + dropdown.options = new_options + dropdown.value = current_value if current_value in new_options else new_options[0] + + def _update_motion_export_dropdown(loaded_model_name: str) -> None: + _update_format_dropdown(gui_download_format_dropdown, loaded_model_name) + _update_format_dropdown(gui_save_motion_format_dropdown, loaded_model_name) + + def _coerce_download_filename(raw_name: str, *, ext: str) -> str: + """Coerce a user-entered filename to a safe basename with the desired extension. + + - If empty: uses "output{ext}" + - If no extension: appends ext + - If endswith a known export extension: rewrites extension to ext (prevents mismatches) + - Any provided directory components are stripped + """ + import os + + name = (raw_name or "").strip() + name = os.path.basename(name.replace("\\", "/")) + if name == "": + return f"output{ext}" + + known_exts = (".npz", ".bvh", ".csv", ".png", ".mp4") + lower = name.lower() + if lower.endswith(known_exts): + return os.path.splitext(name)[0] + ext + + root, cur_ext = os.path.splitext(name) + if cur_ext == "": + return name + ext + return name + + def _get_render_size(event_client: viser.ClientHandle) -> tuple[int, int]: + width = int(event_client.camera.image_width) + height = int(event_client.camera.image_height) + if width <= 0 or height <= 0: + # Fall back to a reasonable default if the camera hasn't synced yet. + return (1280, 720) + return (width, height) + + def _round_up_to_multiple(value: int, multiple: int) -> int: + if multiple <= 0: + return value + return ((value + multiple - 1) // multiple) * multiple + + def _download_canvas_to_browser(event_client: viser.ClientHandle, *, filename: str) -> None: + """Use the client-side canvas save path to avoid server-side renders.""" + import json + + js = f""" +(() => {{ + const filename = {json.dumps(filename)}; + const canvases = Array.from(document.querySelectorAll("canvas")); + if (!canvases.length) {{ + console.error("No canvases found to save."); + return; + }} + // Pick the largest canvas by area (usually the main 3D view). + const canvas = canvases.reduce((best, cur) => {{ + const bestArea = (best?.width || 0) * (best?.height || 0); + const curArea = (cur?.width || 0) * (cur?.height || 0); + return curArea > bestArea ? cur : best; + }}, null); + if (!canvas) {{ + console.error("No canvas selected to save."); + return; + }} + canvas.toBlob((blob) => {{ + if (!blob) {{ + console.error("Export failed"); + return; + }} + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); + }}, "image/png"); +}})(); +""" + from viser import _messages as _viser_messages + + event_client.gui._websock_interface.queue_message( # type: ignore[attr-defined] + _viser_messages.RunJavascriptMessage(source=js) + ) + + @gui_screenshot_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + if get_active_session(event_client) is None: + return + + try: + filename = _coerce_download_filename( + str(gui_screenshot_path_text.value), + ext=".png", + ) + _download_canvas_to_browser(event_client, filename=filename) + event_client.add_notification( + title="Screenshot download started", + body=f"Saving {filename}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to download screenshot!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + @gui_video_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + recording_notification: viser.NotificationHandle | None = None + try: + recording_notification = event_client.add_notification( + title="Recording video...", + body="Saving frames, please wait.", + loading=True, + with_close_button=False, + auto_close_seconds=None, + color="blue", + ) + event_client.timeline.disable_constraints() + width, height = _get_render_size(event_client) + # Avoid ffmpeg macro block resizing warnings. + width = _round_up_to_multiple(width, 16) + height = _round_up_to_multiple(height, 16) + original_frame = session.frame_idx + frames = [] + for frame_idx in range(session.max_frame_idx + 1): + demo.set_frame( + event_client.client_id, + frame_idx, + update_timeline=True, + ) + frames.append( + event_client.get_render( + height=height, + width=width, + transport_format="jpeg", + ) + ) + + # Restore the original frame (and timeline). + demo.set_frame(event_client.client_id, original_frame) + + import imageio.v3 as iio + + filename = _coerce_download_filename( + str(gui_video_path_text.value), + ext=".mp4", + ) + payload = iio.imwrite( + "", + frames, + extension=".mp4", + fps=float(session.model_fps), + codec="h264", + plugin="pyav", + ) + event_client.send_file_download(filename, payload, save_immediately=True) + event_client.add_notification( + title="Video download started", + body=f"Saving {filename}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to download video!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + finally: + event_client.timeline.enable_constraints() + if recording_notification is not None: + recording_notification.remove() + + @gui_download_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + motion = _get_primary_motion(session) + try: + fmt = str(gui_download_format_dropdown.value).upper() + raw_name = str(gui_download_name_text.value) + + if fmt == "BVH": + filename = _coerce_download_filename(raw_name, ext=".bvh") + payload = motion_to_bvh_bytes( + motion.joints_local_rot, + motion.joints_pos[:, session.skeleton.root_idx, :], # root positions + skeleton=session.skeleton, + fps=float(session.model_fps), + ) + mime = "text/plain" + elif fmt == "CSV": + filename = _coerce_download_filename(raw_name, ext=".csv") + payload = _motion_to_csv_bytes(motion, session) + mime = "text/csv" + elif fmt == "AMASS NPZ": + filename = _coerce_download_filename(raw_name, ext=".npz") + payload = _motion_to_amass_npz_bytes(motion, session) + mime = "application/octet-stream" + else: + # Default to NPZ (most common and matches existing save/load). + filename = _coerce_download_filename(raw_name, ext=".npz") + payload = _motion_to_npz_bytes(motion) + mime = "application/octet-stream" + + _download_bytes_to_browser( + event_client, + data=payload, + filename=filename, + mime_type=mime, + ) + + event_client.add_notification( + title="Download started", + body=f"Saving {filename}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to download motion!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + @gui_save_example_button.on_click + def _(event: viser.GuiEvent) -> None: + from kimodo.tools import save_json + + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + save_dir = gui_save_example_path_text.value + if os.path.exists(save_dir): + event_client.add_notification( + title="Failed to save example!", + body="Example directory already exists", + auto_close_seconds=10.0, + color="red", + ) + return + + try: + os.makedirs(save_dir) + # save the constraints + constraint_path = os.path.join(save_dir, "constraints.json") + save_constraints(event_client, constraint_path) + # save the motion + motion_path = os.path.join(save_dir, "motion.npz") + save_motion(event_client, motion_path, "NPZ") + # save the gui metadata + meta_path = os.path.join(save_dir, "meta.json") + prompt_texts = [] + prompt_durations_sec = [] + prompt_values = sorted( + [x for x in client.timeline._prompts.values()], + key=lambda x: x.start_frame, + ) + for i, prompt in enumerate(prompt_values): + prompt_texts.append(prompt.text) + # Match demo/generation convention: + # non-last prompts: [start, end) ; last prompt: [start, end]. + n_frames = prompt.end_frame - prompt.start_frame + if i == len(prompt_values) - 1: + n_frames += 1 + prompt_durations_sec.append(n_frames / session.model_fps) + if len(prompt_texts) == 1: + meta_info = { + "text": prompt_texts[0], + "duration": prompt_durations_sec[0], + } + else: + meta_info = { + "texts": prompt_texts, + "durations": prompt_durations_sec, + } + meta_info["num_samples"] = gui_num_samples_slider.value + meta_info["seed"] = gui_seed.value + meta_info["diffusion_steps"] = gui_diffusion_steps_slider.value + meta_info["cfg"] = { + "enabled": gui_cfg_checkbox.value, + "text_weight": gui_cfg_text_weight_slider.value, + "constraint_weight": gui_cfg_constraint_weight_slider.value, + } + save_json(meta_path, meta_info) + + # update the example dropdown + session.example_dict = viser_utils.load_example_cases(session.examples_base_dir) + update_examples_dropdown(session.example_dict, keep_selection=True) + + event_client.add_notification( + title="Example saved!", + body=f"Saved example to {save_dir}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to save example!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + def set_new_duration(client_id, new_duration): + session = demo.client_sessions[client_id] + session.cur_duration = new_duration + update_duration_gui(new_duration) + session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) + if session.frame_idx > session.max_frame_idx: + demo.set_frame(client_id, session.max_frame_idx) + + def apply_model_selection(new_model_name: str) -> None: + session = demo.client_sessions[client_id] + if new_model_name == session.model_name: + return + + session.playing = False # Pause playback when switching models. + + old_model_fps = session.model_fps + old_duration = session.cur_duration + old_prompts = [ + (prompt.text, prompt.start_frame, prompt.end_frame) for prompt in client.timeline._prompts.values() + ] + old_default_zoom_frames = client.timeline._default_num_frames_zoom + old_max_zoom_frames = client.timeline._max_frames_zoom + + model_bundle = demo.load_model(new_model_name) + + # Clear motions and constraints when switching models. + if session.edit_mode and session.motions: + exit_editing_mode(session) + session.edit_mode = False + demo.clear_motions(client_id) + with session.timeline_data["keyframe_update_lock"]: + for constraint in list(session.constraints.values()): + constraint.clear() + session.constraints = demo.build_constraint_tracks(client, model_bundle.skeleton) + session.timeline_data["keyframes"] = {} + session.timeline_data["intervals"] = {} + client.timeline.clear_keyframes() + client.timeline.clear_intervals() + + session.model_name = new_model_name + session.model_fps = model_bundle.model_fps + session.skeleton = model_bundle.skeleton + session.motion_rep = model_bundle.motion_rep + session.cur_duration = old_duration + session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) + session.frame_idx = 0 + session.edit_mode = False + + demo.set_timeline_defaults(client.timeline, session.model_fps) + client.timeline.set_current_frame(0) + gui_model_fps.value = session.model_fps + update_duration_gui(session.cur_duration) + + if old_model_fps > 0: + default_zoom_seconds = old_default_zoom_frames / old_model_fps + max_zoom_seconds = old_max_zoom_frames / old_model_fps + new_default_zoom = int(round(default_zoom_seconds * session.model_fps)) + new_max_zoom = int(round(max_zoom_seconds * session.model_fps)) + new_default_zoom = max(1, new_default_zoom) + new_max_zoom = max(new_default_zoom, new_max_zoom) + client.timeline.set_zoom_settings( + default_num_frames_zoom=new_default_zoom, + max_frames_zoom=new_max_zoom, + ) + + client.timeline.clear_prompts() + if old_prompts and old_model_fps > 0: + for i, (prompt_text, start_frame, end_frame) in enumerate(old_prompts): + start_sec = start_frame / old_model_fps + end_sec = end_frame / old_model_fps + new_start = int(round(start_sec * session.model_fps)) + new_end = int(round(end_sec * session.model_fps)) + new_start = max(0, min(new_start, session.max_frame_idx)) + new_end = max(new_start, min(new_end, session.max_frame_idx)) + color = PROMPT_COLORS[i % len(PROMPT_COLORS)] + client.timeline.add_prompt(prompt_text, new_start, new_end, color=color) + + session.examples_base_dir = demo.get_examples_base_dir(new_model_name, absolute=True) + session.example_dict = viser_utils.load_example_cases(session.examples_base_dir) + update_examples_dropdown(session.example_dict, keep_selection=False) + gui_save_example_path_text.value = os.path.join( + demo.get_examples_base_dir(new_model_name, absolute=True), + "custom_example_1", + ) + gui_load_example_path_text.value = os.path.join( + demo.get_examples_base_dir(new_model_name, absolute=True), + "custom_example_1", + ) + + demo.add_character_motion(client, session.skeleton) + apply_constraint_overlay_visibility(session) + + def _update_version_and_display_from_dataset_skeleton() -> None: + dataset_ui = gui_dataset_selector.value + skeleton_display = gui_skeleton_selector.value + skeleton_val = get_skeleton_key_from_display_name(skeleton_display) + if skeleton_val is None: + return + models = get_models_for_dataset_skeleton(dataset_ui, skeleton_val, family="Kimodo") + if not models: + return + gui_version_selector.options = [m.display_name for m in models] + gui_version_selector.value = models[0].display_name + gui_version_selector.visible = len(models) > 1 + gui_model_display.content = f"**Model:** {models[0].display_name}" + + def _update_visibility_for_loaded_model(loaded_model_name: str) -> None: + """Update model-specific controls from the currently loaded model only.""" + if not loaded_model_name: + return + _update_motion_export_dropdown(loaded_model_name) + gui_use_soma_layer_checkbox.visible = "soma" in loaded_model_name + _is_g1 = "g1" in loaded_model_name + gui_real_robot_rotations_checkbox.visible = _is_g1 + gui_postprocess_checkbox.visible = not _is_g1 + gui_root_margin.visible = not _is_g1 and gui_postprocess_checkbox.value + if _is_g1: + gui_gizmo_space_dropdown.value = "Local" + gui_gizmo_space_dropdown.visible = not _is_g1 + gui_gizmo_space_dropdown.disabled = _is_g1 + + def _on_load_model_click(event: viser.GuiEvent) -> None: + """Load the currently selected model (called from Load model button).""" + if get_active_session(event.client) is None: + return + new_model_name = gui_model_selector.value + if not new_model_name: + return + info = get_model_info(new_model_name) + if info is None: + return + session = demo.client_sessions[event.client.client_id] + if new_model_name == session.model_name: + return + loading_notif = event.client.add_notification( + title="Loading model...", + body=f"Loading {info.display_name}", + loading=True, + with_close_button=False, + ) + try: + apply_model_selection(new_model_name) + _update_visibility_for_loaded_model(new_model_name) + loading_notif.title = "Model loaded" + loading_notif.body = f"{info.display_name} is ready." + loading_notif.loading = False + loading_notif.with_close_button = True + loading_notif.auto_close_seconds = 5.0 + loading_notif.color = "green" + except Exception as e: + loading_notif.loading = False + loading_notif.with_close_button = True + event.client.add_notification( + title="Model failed to load", + body=str(e), + color="red", + auto_close_seconds=10.0, + ) + gui_model_selector.set_from_short_key(session.model_name) + + @gui_load_model_button.on_click + def _(event: viser.GuiEvent) -> None: + _on_load_model_click(event) + + @gui_dataset_selector.on_update + def _(event: viser.GuiEvent) -> None: + if get_active_session(event.client) is None: + return + skeleton_labels = get_allowed_skeleton_labels(gui_dataset_selector.value) + gui_skeleton_selector.options = skeleton_labels + gui_skeleton_selector.value = skeleton_labels[0] if skeleton_labels else "" + _update_version_and_display_from_dataset_skeleton() + + @gui_skeleton_selector.on_update + def _(event: viser.GuiEvent) -> None: + if get_active_session(event.client) is None: + return + _update_version_and_display_from_dataset_skeleton() + + @gui_version_selector.on_update + def _(event: viser.GuiEvent) -> None: + if get_active_session(event.client) is None: + return + info = get_model_info(gui_model_selector.value) + if info is not None: + gui_model_display.content = f"**Model:** {info.display_name}" + + @gui_use_soma_layer_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + session = get_active_session(event.client) + if session is None or "soma" not in (session.model_name or ""): + return + + loading_notif = event.client.add_notification( + title="Applying SOMA layer...", + body="Updating mesh.", + loading=True, + with_close_button=False, + ) + try: + current_motion = list(session.motions.values())[0] if session.motions else None + current_frame_idx = session.frame_idx + + # Recreate the character to apply the new SOMA mesh mode selection. + demo.clear_motions(event.client.client_id) + if current_motion is None: + demo.add_character_motion(event.client, session.skeleton) + else: + demo.add_character_motion( + event.client, + session.skeleton, + current_motion.joints_pos, + current_motion.joints_rot, + current_motion.foot_contacts, + ) + + demo.set_frame(event.client.client_id, current_frame_idx) + except Exception as e: + print(e) + event.client.add_notification( + title="SOMA layer failed", + body=str(e), + color="red", + auto_close_seconds=10.0, + ) + gui_use_soma_layer_checkbox.value = not gui_use_soma_layer_checkbox.value + finally: + loading_notif.loading = False + loading_notif.with_close_button = True + loading_notif.auto_close_seconds = 2.0 + + @gui_real_robot_rotations_checkbox.on_update + def _(event: viser.GuiEvent) -> None: + session = get_active_session(event.client) + if session is None or "g1" not in session.model_name: + return + if not isinstance(session.skeleton, G1Skeleton34) or not session.motions: + return + if not gui_real_robot_rotations_checkbox.value: + return + # Reproject all displayed G1 motions to real robot DoF (1-DoF per joint + axis limits). + from kimodo.skeleton import global_rots_to_local_rots + + current_frame_idx = session.frame_idx + for motion in session.motions.values(): + if motion.length <= 1: + continue + rest_pos = motion.joints_pos[0:1] + rest_rot = motion.joints_rot[0:1] + same_as_rest = (motion.joints_pos - rest_pos).abs().max().item() < 1e-6 and ( + motion.joints_rot - rest_rot + ).abs().max().item() < 1e-6 + if same_as_rest: + continue + new_pos, new_rot = generation.apply_g1_real_robot_projection( + session.skeleton, + motion.joints_pos, + motion.joints_rot, + ) + motion.joints_pos = new_pos + motion.joints_rot = new_rot + motion.joints_local_rot = global_rots_to_local_rots(new_rot, session.skeleton) + # Refresh skeleton and skinned mesh caches so the viz uses new positions. + motion.precompute_mesh_info() + demo.set_frame(event.client.client_id, current_frame_idx) + event.client.add_notification( + title="Real robot projection applied", + body="The motion is projected to G1 real robot DoF (1-DoF per joint, clamped to axis limits).", + auto_close_seconds=4.0, + color="green", + ) + + def load_example_from_path( + event_client: viser.ClientHandle, + example_path: str, + load_gt: bool = False, + ) -> None: + from kimodo.meta import parse_prompts_from_meta + from kimodo.tools import load_json + + session = get_active_session(event_client) + if session is None: + return + + # Pause playback when loading an example. + session.playing = False + + if not os.path.isdir(example_path): + event_client.add_notification( + title="Example path not found", + body=f"Directory does not exist: {example_path}", + auto_close_seconds=5.0, + color="red", + ) + return + + try: + # constraints + constraints_path = os.path.join(example_path, "constraints.json") + if os.path.exists(constraints_path): + load_constraints(event_client, constraints_path) + else: + # clear all existing constraints + with session.timeline_data["keyframe_update_lock"]: + for constraint in list(session.constraints.values()): + constraint.clear() + event_client.timeline.clear_keyframes() + event_client.timeline.clear_intervals() + # motion + motion_filename = "gt_motion.npz" if load_gt else "motion.npz" + motion_path = os.path.join(example_path, motion_filename) + if os.path.exists(motion_path): + load_motion(event_client, motion_path) + # metadata + meta_path = os.path.join(example_path, "meta.json") + if os.path.exists(meta_path): + meta_info = load_json(meta_path) + event_client.timeline.clear_prompts() + + texts, durations_sec = parse_prompts_from_meta(meta_info) + fps = session.model_fps + # Convert durations (seconds) to consecutive frame bounds + num_frames = 0 + frame_bounds = [] + for i, d in enumerate(durations_sec): + n_frames = max(1, int(round(d * fps))) + start_frame = num_frames + # Inverse of compute_prompt_num_frames(): + # non-last prompts end at next prompt start (exclusive), + # last prompt includes its end frame. + if i == len(durations_sec) - 1: + end_frame = num_frames + n_frames - 1 + else: + end_frame = num_frames + n_frames + frame_bounds.append((start_frame, end_frame)) + num_frames += n_frames + + # Adapt timeline zoom to the loaded motion. + target_visible_frames = int(math.ceil(1.10 * num_frames)) + event_client.timeline.set_zoom_settings( + default_num_frames_zoom=target_visible_frames, + ) + + for i, (prompt_text, (start_frame, end_frame)) in enumerate(zip(texts, frame_bounds)): + color = PROMPT_COLORS[i % len(PROMPT_COLORS)] + event_client.timeline.add_prompt(prompt_text, start_frame, end_frame, color=color) + + update_duration_auto() + + # Only load optional fields if present + if "num_samples" in meta_info: + gui_num_samples_slider.value = meta_info["num_samples"] + if "seed" in meta_info: + gui_seed.value = meta_info["seed"] + if "diffusion_steps" in meta_info: + gui_diffusion_steps_slider.value = meta_info["diffusion_steps"] + if "cfg" in meta_info: + cfg = meta_info["cfg"] + if "enabled" in cfg: + gui_cfg_checkbox.value = cfg["enabled"] + if "text_weight" in cfg: + gui_cfg_text_weight_slider.value = cfg["text_weight"] + if "constraint_weight" in cfg: + gui_cfg_constraint_weight_slider.value = cfg["constraint_weight"] + + # Set frame to 0 when example is loaded. + session.frame_idx = 0 + event_client.timeline.set_current_frame(0) + demo.set_frame(event_client.client_id, 0) + + event_client.add_notification( + title="Example loaded!", + body=f"Loaded example from {example_path}", + auto_close_seconds=5.0, + color="green", + ) + except Exception as e: + import traceback + + traceback.print_exc() + event_client.add_notification( + title="Failed to load example!", + body=str(e), + auto_close_seconds=10.0, + color="red", + ) + + def load_qwen_example_plan(event_client: viser.ClientHandle) -> None: + """Load a Qwen-generated 10-action prompt plan into the timeline. + + This preserves the native UI flow: + 1) Load Example -> fills timeline prompt segments + 2) Generate -> synthesizes motion from loaded prompts + """ + session = get_active_session(event_client) + if session is None: + return + + def _thread_fn() -> None: + try: + history: list[str] = [] + all_texts: list[str] = [] + all_durations: list[float] = [] + target_actions = 10 + rounds = 0 + + while len(all_texts) < target_actions and rounds < 8: + remaining = target_actions - len(all_texts) + batch, history = call_qwen_for_prompts( + scene="Agentic demo: keep one character in continuous motion", + history=history, + requested_actions=min(5, remaining), + ) + texts = batch.get("texts", []) + durations = batch.get("durations", []) + for t, d in zip(texts, durations): + if len(all_texts) >= target_actions: + break + all_texts.append(t) + all_durations.append(float(d)) + rounds += 1 + + if len(all_texts) == 0: + event_client.add_notification( + title="Qwen example load failed", + body="No prompt segments were produced.", + auto_close_seconds=6.0, + color="red", + ) + return + + fps = session.model_fps + event_client.timeline.clear_prompts() + frame_cursor = 0 + for i, (txt, dur) in enumerate(zip(all_texts, all_durations)): + n_frames = max(1, int(round(dur * fps))) + start_frame = frame_cursor + end_frame = frame_cursor + n_frames if i < len(all_texts) - 1 else frame_cursor + n_frames - 1 + color = PROMPT_COLORS[i % len(PROMPT_COLORS)] + event_client.timeline.add_prompt(txt, start_frame, end_frame, color=color) + frame_cursor += n_frames + + target_visible_frames = int(math.ceil(1.10 * frame_cursor)) + event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames)) + update_duration_auto() + + event_client.add_notification( + title="Qwen example loaded", + body=f"Loaded {len(all_texts)} prompt segments. Click Generate to synthesize motion.", + auto_close_seconds=6.0, + color="green", + ) + except Exception as e: + event_client.add_notification( + title="Qwen example load failed", + body=str(e), + auto_close_seconds=8.0, + color="red", + ) + + threading.Thread(target=_thread_fn, daemon=True).start() + + @gui_load_example_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + if gui_examples_dropdown.value in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME): + load_qwen_example_plan(event_client) + return + + if not session.example_dict or (gui_examples_dropdown.value not in session.example_dict): + event_client.add_notification( + title="No examples available", + body="No examples found for the selected model.", + auto_close_seconds=5.0, + color="red", + ) + return + + example_path = session.example_dict[gui_examples_dropdown.value] + load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value) + + @gui_load_example_from_path_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + example_path = gui_load_example_path_text.value + if not example_path: + event_client.add_notification( + title="No example path", + body="Please provide an example directory.", + auto_close_seconds=5.0, + color="red", + ) + return + load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value) + + @gui_cfg_checkbox.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + val = gui_cfg_checkbox.value + gui_cfg_text_weight_slider.visible = val + gui_cfg_constraint_weight_slider.visible = val + + def exit_editing_mode(session: ClientSession): + gui_edit_constraint_button.label = "Enter Editing Mode" + gui_generate_button.disabled = False + gui_generate_button.label = "Generate" + gui_reset_constraint_button.disabled = True + if "g1" in session.model_name: + gui_gizmo_space_dropdown.value = "Local" + gui_gizmo_space_dropdown.disabled = True + gui_gizmo_space_dropdown.visible = False + else: + gui_gizmo_space_dropdown.disabled = False + gui_gizmo_space_dropdown.visible = True + gui_undo_drag_button.disabled = True + gui_use_soma_layer_checkbox.disabled = False + session.edit_mode_snapshot = None + session.undo_drag_snapshot = None + + motion = list(session.motions.values())[0] + motion.clear_all_gizmos() + motion.character.set_skinned_mesh_wireframe(False) + motion.character.set_skeleton_visibility(False) + motion.character.set_skinned_mesh_visibility(True) + motion.character.set_skinned_mesh_opacity(1.0) + session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 1.0 + + # If the path is dense, put the motion back on the path + if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path: + _update_dense_path(motion, session) + + gui_viz_skinned_mesh_checkbox.value = True + gui_viz_skeleton_checkbox.value = False + + # enter editing mode callback + @gui_edit_constraint_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + session.edit_mode = not session.edit_mode + + edit_alert = "Entered editing mode" + no_edit_alert = "Exited editing mode" + edit_message = "You can now modify pose or path constraints." + no_edit_message = "Can now generate motions." + event_client.add_notification( + title=edit_alert if session.edit_mode else no_edit_alert, + body=edit_message if session.edit_mode else no_edit_message, + auto_close_seconds=10.0, + color="blue", + ) + + if session.edit_mode: + gui_edit_constraint_button.label = "Exit Editing Mode" + gui_generate_button.disabled = True + gui_generate_button.label = "Generate Disabled In Editing Mode" + if "g1" in session.model_name: + gui_gizmo_space_dropdown.value = "Local" + gui_gizmo_space_dropdown.disabled = True + gui_use_soma_layer_checkbox.disabled = True + + assert len(session.motions) == 1, "Only one motion allowed in edit mode" + motion = list(session.motions.values())[0] + snapshot_frame_idx = min(session.frame_idx, motion.length - 1) + session.edit_mode_snapshot = {} + ensure_edit_snapshot(session, motion, snapshot_frame_idx) + gui_reset_constraint_button.disabled = False + + motion.character.set_skeleton_visibility(True) + # motion.character.set_skinned_mesh_wireframe(True) + motion.character.set_skinned_mesh_opacity(0.65) + session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 0.65 + motion.character.set_skinned_mesh_visibility(True) + gui_viz_skinned_mesh_checkbox.value = True + gui_viz_skeleton_checkbox.value = True + + # need gizmos for root translation and individual joints + def _on_root2d_gizmo_release(): + if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path: + mot = list(session.motions.values())[0] + _update_dense_path(mot, session) + + def _on_gizmo_drag_start(): + mot = list(session.motions.values())[0] + frame_idx = min(session.frame_idx, mot.length - 1) + session.undo_drag_snapshot = { + "frame_idx": frame_idx, + "joints_pos": mot.get_joints_pos(frame_idx), + "joints_rot": mot.get_joints_rot(frame_idx), + } + gui_undo_drag_button.disabled = False + + motion.add_root_translation_gizmo( + session.constraints, + on_2d_root_drag_end=_on_root2d_gizmo_release, + on_drag_start=_on_gizmo_drag_start, + ) + gizmo_space = "local" if "g1" in session.model_name else gui_gizmo_space_dropdown.value.lower() + motion.add_joint_gizmos( + session.constraints, + space=gizmo_space, + on_drag_start=_on_gizmo_drag_start, + ) + else: + exit_editing_mode(session) + + @gui_reset_constraint_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None or not session.edit_mode_snapshot: + return + + if not session.motions: + return + motion = list(session.motions.values())[0] + snapshot_frame_idx = min(session.frame_idx, motion.length - 1) + if snapshot_frame_idx not in session.edit_mode_snapshot: + return + motion.update_pose_at_frame( + snapshot_frame_idx, + joints_pos=session.edit_mode_snapshot[snapshot_frame_idx]["joints_pos"], + joints_rot=session.edit_mode_snapshot[snapshot_frame_idx]["joints_rot"], + ) + demo.set_frame(event_client.client_id, snapshot_frame_idx, update_timeline=False) + + @gui_undo_drag_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None or session.undo_drag_snapshot is None: + return + + if not session.motions: + return + motion = list(session.motions.values())[0] + frame_idx = session.undo_drag_snapshot["frame_idx"] + motion.update_pose_at_frame( + frame_idx, + joints_pos=session.undo_drag_snapshot["joints_pos"], + joints_rot=session.undo_drag_snapshot["joints_rot"], + ) + demo.set_frame(event_client.client_id, frame_idx, update_timeline=False) + session.undo_drag_snapshot = None + gui_undo_drag_button.disabled = True + + def validate_interval(start_frame_idx: int, end_frame_idx: int, max_frame_idx: int) -> bool: + if start_frame_idx < 0 or start_frame_idx > max_frame_idx: + return False + if end_frame_idx < 0 or end_frame_idx > max_frame_idx: + return False + if end_frame_idx < start_frame_idx: + return False + return True + + def clamp_interval_to_range( + start_frame_idx: int, end_frame_idx: int, max_frame_idx: int + ) -> Optional[tuple[int, int]]: + if end_frame_idx < 0 or start_frame_idx > max_frame_idx: + return None + start_clamped = max(0, start_frame_idx) + end_clamped = min(max_frame_idx, end_frame_idx) + if end_clamped < start_clamped: + return None + return start_clamped, end_clamped + + # add constraint callback + def add_constraint_callback( + constraint_id: str, + constraint_type: str, + frame_range: tuple[int, int], + joint_names: list[str] = None, + verbose: bool = True, + ): + """Add a constraint to the session. + + Args: + constraint_type: str, the type of constraint to add + frame_range: tuple[int, int], the frame range to add the constraint to + joint_names: list[str], the names of the joints to constraint if the constraint type is End-Effectors + """ + # Check if session still exists + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + + assert len(session.motions) == 1, "Only one motion allowed for adding constraints" + motion = list(session.motions.values())[0] + + end_effector_type = None + if constraint_type in [ + "Left Hand", + "Right Hand", + "Left Foot", + "Right Foot", + ]: + joint_names = [constraint_type.replace(" ", ""), "Hips"] + # Hips are required because of smooth root representation + end_effector_type = constraint_type.replace(" ", "-").lower() + constraint_type = "End-Effectors" + + # check to make sure interval is valid + is_interval = frame_range[1] != frame_range[0] + start_frame_idx = int(frame_range[0]) + end_frame_idx = int(frame_range[1]) + + if is_interval: + clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx) + if clamped is None: + print("Interval outside range! Couldn't add constraint.") + return + start_frame_idx, end_frame_idx = clamped + else: + if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx): + print("Invalid interval! Couldn't add constraint.") + return + + # collect input args for the constraint based on which track it is + if is_interval: + constraint_kwargs = { + "interval_id": constraint_id, + "start_frame_idx": start_frame_idx, + "end_frame_idx": end_frame_idx, + } + else: + constraint_kwargs = { + "keyframe_id": constraint_id, + "frame_idx": start_frame_idx, + } + + if constraint_type in ["Full-Body", "End-Effectors"]: + constraint_kwargs["joints_pos"] = motion.get_joints_pos(start_frame_idx, end_frame_idx) + constraint_kwargs["joints_rot"] = motion.get_joints_rot(start_frame_idx, end_frame_idx) + if constraint_type == "End-Effectors": + constraint_kwargs["joint_names"] = joint_names + constraint_kwargs["end_effector_type"] = end_effector_type + + elif constraint_type == "2D Root": + constraint_kwargs["root_pos"] = motion.get_projected_root_pos(start_frame_idx, end_frame_idx) + + # add the keyframe(s) to the constraint track + constraint = session.constraints[constraint_type] + if is_interval: + constraint.add_interval(**constraint_kwargs) + else: + constraint.add_keyframe(**constraint_kwargs) + + apply_constraint_overlay_visibility(session) + + if verbose: + client.add_notification( + title="Constraint added", + body="", + auto_close_seconds=5.0, + color="blue", + ) + + # timeline callbacks for keyframes and intervals + @client.timeline.on_keyframe_add + def _(keyframe_id: str, track_id: str, frame: int): + """Called when a keyframe is added to a track.""" + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + with session.timeline_data["keyframe_update_lock"]: + constraint_type = session.timeline_data["tracks"][track_id]["name"] + add_constraint_callback( + keyframe_id, + constraint_type, + (frame, frame), + verbose=False, + ) + keyframe_data = client.timeline._keyframes.get(keyframe_id) + session.timeline_data["keyframes"][keyframe_id] = { + "frame": frame, + "track_id": track_id, + "locked": bool(keyframe_data.locked) if keyframe_data is not None else False, + "opacity": keyframe_data.opacity if keyframe_data is not None else 1.0, + "value": keyframe_data.value if keyframe_data is not None else None, + } + # Update smooth path when adding a keyframe (single action, not drag). + if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: + motion = list(session.motions.values())[0] + _update_dense_path(motion, session) + + @client.timeline.on_interval_add + def handle_interval_add(interval_id: str, track_id: str, start_frame: int, end_frame: int): + """Called when an interval is added to a track.""" + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + with session.timeline_data["keyframe_update_lock"]: + constraint_type = session.timeline_data["tracks"][track_id]["name"] + add_constraint_callback( + interval_id, + constraint_type, + (start_frame, end_frame), + verbose=False, + ) + interval_data = client.timeline._intervals.get(interval_id) + session.timeline_data["intervals"][interval_id] = { + "track_id": track_id, + "start_frame_idx": start_frame, + "end_frame_idx": end_frame, + "locked": bool(interval_data.locked) if interval_data is not None else False, + "opacity": interval_data.opacity if interval_data is not None else 1.0, + "value": interval_data.value if interval_data is not None else None, + } + if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: + motion = list(session.motions.values())[0] + _update_dense_path(motion, session) + + def remove_constraint_callback( + constraint_id: str, + constraint_type: str, + frame_range: tuple[int, int], + verbose: bool = True, + ) -> None: + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + session.updating_motions = True + + is_interval = frame_range[1] != frame_range[0] + start_frame_idx = int(frame_range[0]) + end_frame_idx = int(frame_range[1]) + + if is_interval: + clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx) + if clamped is None: + return + start_frame_idx, end_frame_idx = clamped + else: + if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx): + print("Invalid interval! Couldn't remove constraint.") + return + + if constraint_type in [ + "Left Hand", + "Right Hand", + "Left Foot", + "Right Foot", + ]: + constraint_type = "End-Effectors" + + constraint = session.constraints[constraint_type] + if is_interval: + constraint.remove_interval(constraint_id, start_frame_idx, end_frame_idx) + else: + constraint.remove_keyframe(constraint_id, start_frame_idx) + + if verbose: + client.add_notification( + title="Constraint removed", + body="", + auto_close_seconds=5.0, + color="blue", + ) + + @client.timeline.on_keyframe_move + def handle_keyframe_move(keyframe_id: str, new_frame: int): + """Called when a keyframe is moved to a new frame.""" + # print(f"Keyframe moved: {keyframe_id} to frame {new_frame}") + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + + # Cancel any pending timer for this keyframe + timeline_data = session.timeline_data + with timeline_data["keyframe_update_lock"]: + if keyframe_id in timeline_data["keyframe_move_timers"]: + timeline_data["keyframe_move_timers"][keyframe_id].cancel() + + # Store the latest target frame + timeline_data["pending_keyframe_moves"][keyframe_id] = new_frame + # Create a new timer to execute the actual move after a delay + # This debounces rapid movements - only execute when user stops moving + timer = threading.Timer( + 0.03, # 10ms delay - adjust as needed + _execute_keyframe_move, + args=(client_id, keyframe_id, new_frame, session), + ) + timeline_data["keyframe_move_timers"][keyframe_id] = timer + timer.start() + + def _execute_keyframe_move( + client_id: int, + keyframe_id: str, + new_frame: int, + session: ClientSession, + ): + """Actually execute the keyframe move operation (called after debounce delay).""" + + timeline_data = session.timeline_data + with timeline_data["keyframe_update_lock"]: + # Check if this move is still the latest one + if keyframe_id not in timeline_data["pending_keyframe_moves"]: + return # Move was cancelled + + if timeline_data["pending_keyframe_moves"][keyframe_id] != new_frame: + return # A newer move superseded this one + + # Remove from pending + del timeline_data["pending_keyframe_moves"][keyframe_id] + if keyframe_id in timeline_data["keyframe_move_timers"]: + del timeline_data["keyframe_move_timers"][keyframe_id] + + # Now execute the actual move (keep it in the lock so we don't delete it while moving) + if keyframe_id not in timeline_data["keyframes"]: + # double check + return + keyframe_data = timeline_data["keyframes"][keyframe_id] + if not keyframe_data: + return + + # if the frame did not move, don't do anything + if keyframe_data["frame"] == new_frame: + return + + track_id = keyframe_data["track_id"] + constraint_type = timeline_data["tracks"][track_id]["name"] + cur_frame = keyframe_data["frame"] + + # Remove constraint at old frame + remove_constraint_callback( + keyframe_id, + constraint_type, + (cur_frame, cur_frame), + verbose=False, + ) + # Add constraint at new frame + add_constraint_callback( + keyframe_id, + constraint_type, + (new_frame, new_frame), + verbose=False, + ) + + # update our data + keyframe_data["frame"] = new_frame + + # Schedule path update only after user stops dragging (no move for 300ms). + if constraint_type == "2D Root": + _schedule_dense_path_after_release(session) + + @client.timeline.on_keyframe_delete + def handle_keyframe_delete(keyframe_id: str): + """Called when a keyframe is deleted.""" + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + with session.timeline_data["keyframe_update_lock"]: + if keyframe_id not in session.timeline_data["keyframes"]: + return + keyframe_data = session.timeline_data["keyframes"][keyframe_id] + track_id = keyframe_data["track_id"] + constraint_type = session.timeline_data["tracks"][track_id]["name"] + cur_frame = keyframe_data["frame"] + remove_constraint_callback( + keyframe_id, + constraint_type, + (cur_frame, cur_frame), + verbose=False, + ) + del session.timeline_data["keyframes"][keyframe_id] + if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: + motion = list(session.motions.values())[0] + _update_dense_path(motion, session) + + @client.timeline.on_interval_move + def handle_interval_move(interval_id: str, new_start: int, new_end: int): + """Called when an interval is moved or resized.""" + # print(f"Interval moved: {interval_id} to {new_start}-{new_end}") + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + + # Cancel any pending timer for this interval + # We share the same lock for keyframe and interval moves assuming the user can't move both at the same time + timeline_data = session.timeline_data + with timeline_data["keyframe_update_lock"]: + if interval_id in timeline_data["keyframe_move_timers"]: + timeline_data["keyframe_move_timers"][interval_id].cancel() + + # Store the latest target frame + new_interval = (new_start, new_end) + timeline_data["pending_keyframe_moves"][interval_id] = new_interval + # Create a new timer to execute the actual move after a delay + # This debounces rapid movements - only execute when user stops moving + timer = threading.Timer( + 0.5, # 100ms delay - adding interval is much slower than moving a keyframe + _execute_interval_move, + args=(client_id, interval_id, new_interval, session), + ) + timeline_data["keyframe_move_timers"][interval_id] = timer + timer.start() + + def _execute_interval_move( + client_id: int, + interval_id: str, + new_interval: tuple[int, int], + session: ClientSession, + ): + """Actually execute the interval move operation (called after debounce delay).""" + + timeline_data = session.timeline_data + with timeline_data["keyframe_update_lock"]: + # Check if this move is still the latest one + if interval_id not in timeline_data["pending_keyframe_moves"]: + return # Move was cancelled + + if timeline_data["pending_keyframe_moves"][interval_id] != new_interval: + return # A newer move superseded this one + + # Remove from pending + del timeline_data["pending_keyframe_moves"][interval_id] + if interval_id in timeline_data["keyframe_move_timers"]: + del timeline_data["keyframe_move_timers"][interval_id] + + # Now execute the actual move + if interval_id not in timeline_data["intervals"]: + return + interval_data = timeline_data["intervals"][interval_id] + if not interval_data: + return + + # if the interval did not move, don't do anything + if ( + interval_data["start_frame_idx"] == new_interval[0] + and interval_data["end_frame_idx"] == new_interval[1] + ): + return + + track_id = interval_data["track_id"] + constraint_type = timeline_data["tracks"][track_id]["name"] + cur_range = ( + interval_data["start_frame_idx"], + interval_data["end_frame_idx"], + ) + + # Remove constraint at old frame + remove_constraint_callback( + interval_id, + constraint_type, + cur_range, + verbose=False, + ) + # Add constraint at new frame + add_constraint_callback( + interval_id, + constraint_type, + new_interval, + verbose=False, + ) + + # update our data + interval_data["start_frame_idx"] = new_interval[0] + interval_data["end_frame_idx"] = new_interval[1] + + # Schedule path update only after user stops dragging (no move for 300ms). + if constraint_type == "2D Root": + _schedule_dense_path_after_release(session) + + @client.timeline.on_interval_delete + def handle_interval_delete(interval_id: str): + """Called when an interval is deleted.""" + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + with session.timeline_data["keyframe_update_lock"]: + if interval_id not in session.timeline_data["intervals"]: + return + interval_data = session.timeline_data["intervals"][interval_id] + track_id = interval_data["track_id"] + constraint_type = session.timeline_data["tracks"][track_id]["name"] + remove_constraint_callback( + interval_id, + constraint_type, + ( + interval_data["start_frame_idx"], + interval_data["end_frame_idx"], + ), + verbose=False, + ) + del session.timeline_data["intervals"][interval_id] + if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: + motion = list(session.motions.values())[0] + _update_dense_path(motion, session) + + @gui_snap_to_constraint_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + target_character_motion = list(session.motions.values())[0] + frame_idx = session.frame_idx + + if frame_idx >= target_character_motion.length: + # frame idx larger than the motion, could not snap + return + + for constraint_name in ["Full-Body", "End-Effectors"]: + if ( + constraint_name in session.constraints + and frame_idx in session.constraints[constraint_name].keyframes + ): + pos = session.constraints[constraint_name].keyframes[frame_idx]["joints_pos"] + rot = session.constraints[constraint_name].keyframes[frame_idx]["joints_rot"] + + # update the full joints_pos of the character to match the constraints + target_character_motion.update_pose_at_frame( + frame_idx, + joints_pos=pos, + joints_rot=rot, + ) + target_character_motion.set_frame(frame_idx) + return # motion already fully changed + + if "2D Root" in session.constraints and frame_idx in session.constraints["2D Root"].keyframes: + # update only the root position + new_root_pos = session.constraints["2D Root"].keyframes[frame_idx] + old_root_pos = target_character_motion.get_projected_root_pos(frame_idx) + root_diff = new_root_pos - old_root_pos + root_diff[1] = 0.0 # don't change height + + new_joints_pos = ( + target_character_motion.joints_pos[frame_idx] + + to_torch( + root_diff, + device=target_character_motion.joints_pos.device, + dtype=target_character_motion.joints_pos.dtype, + )[None] + ) + rot = target_character_motion.joints_rot[frame_idx] + + target_character_motion.update_pose_at_frame( + frame_idx, + joints_pos=new_joints_pos, + joints_rot=rot, + ) + target_character_motion.set_frame(frame_idx) + + @gui_clear_all_constraints_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + with session.timeline_data["keyframe_update_lock"]: + # use the lock here to wait for any constraint updates to finish + for constraint in list(session.constraints.values()): + constraint.clear() + client.timeline.clear_keyframes() + client.timeline.clear_intervals() + if gui_dense_path_checkbox.value: + gui_dense_path_checkbox.value = False + if "2D Root" in session.constraints: + session.constraints["2D Root"].set_dense_path(False) + + # generation callback + @gui_generate_button.on_click + def _(event: viser.GuiEvent) -> None: + event_client = event.client + session = get_active_session(event_client) + if session is None: + return + + generating_notif = event_client.add_notification( + title="Generating motion...", + body="Generating motions for the given prompt!", + loading=True, + with_close_button=False, + ) + gui_generate_button.disabled = True + client.timeline.disable_constraints() + + num_samples = gui_num_samples_slider.value + timeline = session.client.timeline + + # sort them to avoid issues: + prompt_values = sorted([x for x in timeline._prompts.values()], key=lambda x: x.start_frame) + + texts = [x.text for x in prompt_values] + num_frames = compute_prompt_num_frames(prompt_values) + + # compute the total duration + total_nb_frames = sum(num_frames) + total_duration = total_nb_frames / session.model_fps + + # update just in case + set_new_duration(client_id, total_duration) + + transitions_parameters = { + "num_transition_frames": gui_num_transition_frames_slider.value, + "share_transition": gui_share_transition_checkbox.value, + "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100, + } + + # G1: postprocessing is disabled (does not work well for this model). + postprocess_parameters = { + "post_processing": (False if "g1" in session.model_name else gui_postprocess_checkbox.value), + "root_margin": gui_root_margin.value, + } + try: + demo.generate( + event_client, + texts, + num_frames, + num_samples, + gui_seed.value, + gui_diffusion_steps_slider.value, + cfg_weight=[ + gui_cfg_text_weight_slider.value, + gui_cfg_constraint_weight_slider.value, + ], + cfg_type="separated" if gui_cfg_checkbox.value else "nocfg", + postprocess_parameters=postprocess_parameters, + transitions_parameters=transitions_parameters, + real_robot_rotations=gui_real_robot_rotations_checkbox.value, + ) + session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) + session.max_frame_idx = int(session.cur_duration * session.model_fps) - 1 + if session.frame_idx > session.max_frame_idx: + session.frame_idx = session.max_frame_idx + + if num_samples > 1: + # add mesh selector to choose character to commit + def commit_motion(event: viser.GuiEvent) -> None: + target = event.target + commit_name = target.name.split("/")[1] # e.g. /character0/simple_skinned + print(f"Committing motion for character: {commit_name}") + # delete non-selected motions + new_motion_kwargs = None + for character_name, motion in session.motions.items(): + if character_name == commit_name: + new_motion_kwargs = { + "skeleton": session.skeleton, + "joints_rot": motion.joints_rot, + "foot_contacts": motion.foot_contacts, + } + root_x_offset = motion.joints_pos[0, session.skeleton.root_idx, 0] + new_joints_pos = motion.joints_pos.clone() + new_joints_pos[..., 0] -= root_x_offset + new_motion_kwargs["joints_pos"] = new_joints_pos + break + # clear and re-add the selected motion + demo.clear_motions(event_client.client_id) + demo.add_character_motion(event_client, **new_motion_kwargs) + gui_edit_constraint_button.disabled = False + gui_generate_button.disabled = False + gui_snap_to_constraint_button.disabled = False + client.timeline.enable_constraints() + gui_generate_button.label = "Generate" + gui_save_example_button.disabled = False + gui_save_motion_button.disabled = False + gui_download_button.disabled = False + gui_save_constraints_button.disabled = False + gui_load_example_button.disabled = False + + for motion in session.motions.values(): + char = motion.character + character_name = char.name # e.g. "character0" + if char.skinned_mesh is not None: + char.skinned_mesh.on_click(commit_motion) + elif char.g1_mesh_rig is not None: + # Register click on every part so any part can be clicked, + # and use highlight_group so the whole robot highlights together. + for handle in char.g1_mesh_rig.mesh_handles: + handle.on_click(commit_motion, highlight_group=character_name) + + gui_edit_constraint_button.disabled = True + gui_generate_button.disabled = True + gui_snap_to_constraint_button.disabled = True + gui_generate_button.label = "Choose Sample Before Generating" + gui_save_example_button.disabled = True + gui_save_motion_button.disabled = True + gui_download_button.disabled = True + gui_save_constraints_button.disabled = True + gui_load_example_button.disabled = True + else: + gui_edit_constraint_button.disabled = False + gui_generate_button.disabled = False + gui_snap_to_constraint_button.disabled = False + client.timeline.enable_constraints() + + generating_notif.title = "Motion generation finished!" + generating_notif.body = "Motions have been generated successfully for the given prompt." + if num_samples > 1: + generating_notif.body += " Now choose which sample to commit." + generating_notif.loading = False + generating_notif.with_close_button = True + generating_notif.auto_close_seconds = 5.0 + generating_notif.color = "green" + + # put the motion at zero + demo.set_frame(client_id, 0) + + except Exception as e: + import traceback + + traceback.print_exc() + print(f"Error during generation for client {event_client.client_id}: {e}") + # Re-enable buttons and notify the user + if event_client.client_id in demo.client_sessions: + session = demo.client_sessions[event_client.client_id] + gui_generate_button.disabled = False + gui_load_example_button.disabled = False + gui_save_example_button.disabled = False + gui_save_motion_button.disabled = False + gui_download_button.disabled = False + # Reuse persistent notification instead of creating a new one + try: + generating_notif.title = "Generation failed!" + generating_notif.body = f"Error: {str(e)}" + generating_notif.loading = False + generating_notif.with_close_button = True + generating_notif.auto_close_seconds = 6.0 + generating_notif.color = "red" + except Exception: + pass + demo.check_cuda_health() + + # + # Visualization settings + # + with tab_group.add_tab("Visualize", viser.Icon.EYE): + with client.gui.add_folder("Playback", expand_by_default=True): + gui_model_fps = client.gui.add_number("Model FPS", initial_value=model_fps, disabled=True) + gui_playback_speed_buttons = client.gui.add_button_group( + "Playback Speed", + options=[ + "0.5x", + "1x", + "2x", + ], + ) + gui_playback_speed_buttons.value = "1x" + + @client.timeline.on_frame_change + def handle_timeline_frame_change(new_frame_idx: int): + """Update the frame when the user clicks on the timeline.""" + demo.set_frame(client_id, new_frame_idx, update_timeline=False) + session = demo.client_sessions.get(client_id) + if session is not None: + if session.edit_mode and session.motions: + motion = list(session.motions.values())[0] + snapshot_frame_idx = min(session.frame_idx, motion.length - 1) + ensure_edit_snapshot(session, motion, snapshot_frame_idx) + update_snap_to_constraint_button(session) + + @client.timeline.on_prompt_add + async def _on_add( + prompt_id: str, + start_frame: int, + end_frame: int, + text: str, + color: tuple[int, int, int] | None, + ) -> None: + update_duration_auto() + + @client.timeline.on_prompt_update + async def _on_update(prompt_id: str, new_text: str) -> None: + update_duration_auto() + + @client.timeline.on_prompt_resize + async def _on_resize(prompt_id: str, new_start: int, new_end: int) -> None: + update_duration_auto() + + @client.timeline.on_prompt_move + async def _on_move(prompt_id: str, new_start: int, new_end: int) -> None: + update_duration_auto() + + @client.timeline.on_prompt_delete + async def _on_delete(prompt_id: str) -> None: + update_duration_auto() + + def play_pause_button_callback(session: ClientSession): + session.playing = not session.playing + + def next_frame_callback(session: ClientSession): + if session.frame_idx < session.max_frame_idx: + session.frame_idx += 1 + if session.frame_idx == session.max_frame_idx: + pass + demo.set_frame(client_id, session.frame_idx) + + def prev_frame_callback(session: ClientSession): + if session.frame_idx > 0: + session.frame_idx -= 1 + if session.frame_idx == 0: + pass + demo.set_frame(client_id, session.frame_idx) + + @gui_playback_speed_buttons.on_click + def _(_) -> None: + if not demo.client_active(client_id): + return + speed_map = { + "0.5x": 0.5, + "1x": 1.0, + "2x": 2.0, + } + session = demo.client_sessions[client_id] + session.playback_speed = speed_map[gui_playback_speed_buttons.value] + + with client.gui.add_folder("Body options", expand_by_default=True): + gui_viz_skinned_mesh_checkbox = client.gui.add_checkbox("Show Mesh", initial_value=True) + gui_viz_skinned_mesh_opacity_slider = client.gui.add_slider( + "Mesh Opacity", min=0.0, max=1.0, step=0.01, initial_value=1.0 + ) + gui_viz_skeleton_checkbox = client.gui.add_checkbox("Show Skeleton", initial_value=False) + gui_viz_foot_contacts_checkbox = client.gui.add_checkbox("Show Foot Contacts", initial_value=False) + gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value + with client.gui.add_folder("Camera options", expand_by_default=True): + gui_camera_fov_slider = client.gui.add_slider( + "Camera FOV (deg)", + min=30.0, + max=90.0, + step=1.0, + initial_value=45.0, + ) + client.camera.fov = np.deg2rad(gui_camera_fov_slider.value) + with client.gui.add_folder("Interface options", expand_by_default=True): + gui_show_timeline_checkbox = client.gui.add_checkbox( + "Show Timeline", + initial_value=True, + ) + gui_show_constraint_tracks_checkbox = client.gui.add_checkbox( + "Show Constraint tracks", + initial_value=True, + ) + gui_show_constraint_labels_checkbox = client.gui.add_checkbox( + "Show Constraint labels", + initial_value=True, + ) + gui_show_starting_direction_checkbox = client.gui.add_checkbox( + "Show Starting Direction", + initial_value=True, + ) + gui_dark_mode_checkbox = client.gui.add_checkbox( + "Dark Mode", + initial_value=False, # Default to light mode + ) + gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value + demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value) + + @gui_dark_mode_checkbox.on_update + def _(_): + # Apply the theme using configure_theme (pass uuid so titlebar toggle stays) + demo.configure_theme( + client, + gui_dark_mode_checkbox.value, + titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid, + ) + session = demo.client_sessions[client.client_id] + for motion in session.motions.values(): + motion.character.change_theme(gui_dark_mode_checkbox.value) + + # Show dark mode toggle in titlebar (right of Github), hide sidebar checkbox + demo.configure_theme( + client, + gui_dark_mode_checkbox.value, + titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid, + ) + gui_dark_mode_checkbox.visible = False + + @gui_show_constraint_labels_checkbox.on_update + def _(_): + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + for constraint in session.constraints.values(): + constraint.set_label_visibility(gui_show_constraint_labels_checkbox.value) + + @gui_show_timeline_checkbox.on_update + def _(_): + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + session.client.timeline.set_visible(gui_show_timeline_checkbox.value) + gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value + if gui_show_timeline_checkbox.value: + demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value) + + @gui_show_constraint_tracks_checkbox.on_update + def _(_): + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value) + + @gui_show_starting_direction_checkbox.on_update + def _(_): + if not demo.client_active(client_id): + return + demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value) + + @gui_viz_skeleton_checkbox.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value + if not gui_viz_skeleton_checkbox.value: + gui_viz_foot_contacts_checkbox.value = False + for motion in session.motions.values(): + motion.character.set_skeleton_visibility(gui_viz_skeleton_checkbox.value) + + @gui_viz_foot_contacts_checkbox.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + for motion in session.motions.values(): + motion.character.set_show_foot_contacts( + gui_viz_foot_contacts_checkbox.value, frame_idx=motion.cur_frame_idx + ) + + @gui_viz_skinned_mesh_checkbox.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + for motion in session.motions.values(): + motion.character.set_skinned_mesh_visibility(gui_viz_skinned_mesh_checkbox.value) + + @gui_viz_skinned_mesh_opacity_slider.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + session = demo.client_sessions[client_id] + for motion in session.motions.values(): + motion.character.set_skinned_mesh_opacity(gui_viz_skinned_mesh_opacity_slider.value) + + @gui_camera_fov_slider.on_update + def _(_) -> None: + if not demo.client_active(client_id): + return + client.camera.fov = np.deg2rad(gui_camera_fov_slider.value) + + # + + # Instructions tab + # + with tab_group.add_tab("Instructions", viser.Icon.INFO_CIRCLE): + client.gui.add_markdown(DEMO_UI_INSTRUCTIONS_TAB_MD) + + # + # Keyboard events + # + @client.scene.on_keyboard_event("keydown", debounce_ms=100) + def handle_key(event: viser.KeyboardEvent) -> None: + # Check if client session still exists + if client_id not in demo.client_sessions: + return + + session = demo.client_sessions[client_id] + + # Space bar: only toggle on FIRST press + if event.key == " ": + now = time.monotonic() + if now - session.last_space_toggle_time >= 0.2: + session.last_space_toggle_time = now + play_pause_button_callback(session) + return + + # Handle arrow keys: frame navigation (fast OS repeat with 50ms debounce). + elif event.key == "ArrowLeft": + prev_frame_callback(session) + elif event.key == "ArrowRight": + next_frame_callback(session) + + gui_elements = GuiElements( + gui_play_pause_button=gui_play_pause_button, + gui_next_frame_button=gui_next_frame_button, + gui_prev_frame_button=gui_prev_frame_button, + gui_generate_button=gui_generate_button, + gui_model_fps=gui_model_fps, + gui_timeline=gui_timeline, + gui_viz_skeleton_checkbox=gui_viz_skeleton_checkbox, + gui_viz_foot_contacts_checkbox=gui_viz_foot_contacts_checkbox, + gui_viz_skinned_mesh_checkbox=gui_viz_skinned_mesh_checkbox, + gui_viz_skinned_mesh_opacity_slider=gui_viz_skinned_mesh_opacity_slider, + gui_camera_fov_slider=gui_camera_fov_slider, + gui_duration_slider=gui_duration_slider, + gui_num_samples_slider=gui_num_samples_slider, + gui_cfg_checkbox=gui_cfg_checkbox, + gui_cfg_text_weight_slider=gui_cfg_text_weight_slider, + gui_cfg_constraint_weight_slider=gui_cfg_constraint_weight_slider, + gui_diffusion_steps_slider=gui_diffusion_steps_slider, + gui_seed=gui_seed, + gui_postprocess_checkbox=gui_postprocess_checkbox, + gui_root_margin=gui_root_margin, + gui_real_robot_rotations_checkbox=gui_real_robot_rotations_checkbox, + gui_dark_mode_checkbox=gui_dark_mode_checkbox, + gui_use_soma_layer_checkbox=gui_use_soma_layer_checkbox, + ) + return ( + gui_elements, + timeline_tracks, + example_dict, + gui_examples_dropdown, + gui_save_example_path_text, + gui_model_selector, + )