Kimodo Bot commited on
Commit
e1c31e5
·
1 Parent(s): 8413010

Boot native demo via proxy and include Qwen example UI flow

Browse files
app.py CHANGED
@@ -1,7 +1,11 @@
1
- """Movimento Space: lightweight host for NVIDIA Kimodo native UI."""
2
  from __future__ import annotations
3
 
 
4
  import os
 
 
 
5
  import gradio as gr
6
 
7
  try:
@@ -12,31 +16,86 @@ except Exception:
12
  def GPU(*args, **kwargs):
13
  def _decorator(fn):
14
  return fn
 
15
  return _decorator
16
 
17
  spaces = _SpacesFallback()
18
 
19
 
20
- # Keep a GPU-decorated function so HF Spaces startup checks pass on zero-a10g.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @spaces.GPU(duration=60)
22
  def _gpu_healthcheck() -> str:
23
  return "ok"
24
 
25
 
26
  def _viewer_html() -> str:
27
- src = os.environ.get("KIMODO_UI_URL", "https://nvidia-kimodo.hf.space").strip()
 
 
 
28
  return (
29
  "<div style='border:1px solid #d9e7ef;border-radius:12px;overflow:hidden;'>"
30
- f"<iframe src='{src}' title='Kimodo Native UI' style='width:100%;border:0' "
31
- "height='900' loading='lazy'></iframe>"
32
  "</div>"
33
  )
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  with gr.Blocks(title="Movimento") as demo:
37
  gr.Markdown("# Movimento")
38
- gr.Markdown("Native NVIDIA Kimodo UI is embedded below.")
39
- gr.HTML(_viewer_html())
 
 
40
 
41
 
42
  if __name__ == "__main__":
 
1
+ """Movimento Space: boot native Kimodo demo UI and embed via proxy."""
2
  from __future__ import annotations
3
 
4
+ import importlib.util
5
  import os
6
+ import threading
7
+ import traceback
8
+
9
  import gradio as gr
10
 
11
  try:
 
16
  def GPU(*args, **kwargs):
17
  def _decorator(fn):
18
  return fn
19
+
20
  return _decorator
21
 
22
  spaces = _SpacesFallback()
23
 
24
 
25
+ NATIVE_PORT = int(os.environ.get("KIMODO_NATIVE_PORT", "8080"))
26
+ os.environ.setdefault("SERVER_NAME", "0.0.0.0")
27
+ os.environ.setdefault("SERVER_PORT", str(NATIVE_PORT))
28
+ os.environ.setdefault("HF_MODE", "1")
29
+
30
+ _state: dict[str, object] = {
31
+ "ok": False,
32
+ "error": None,
33
+ "trace": None,
34
+ "demo": None,
35
+ }
36
+
37
+
38
+ def _boot_native_demo() -> None:
39
+ try:
40
+ if importlib.util.find_spec("viser") is None:
41
+ raise RuntimeError("Missing dependency: viser")
42
+
43
+ from kimodo.demo.app import Demo
44
+
45
+ _state["demo"] = Demo()
46
+ _state["ok"] = True
47
+ _state["error"] = None
48
+ _state["trace"] = None
49
+ except Exception as exc: # noqa: BLE001
50
+ _state["ok"] = False
51
+ _state["error"] = str(exc)
52
+ _state["trace"] = traceback.format_exc(limit=8)
53
+
54
+
55
+ threading.Thread(target=_boot_native_demo, daemon=True).start()
56
+
57
+
58
+ # Keep a GPU-decorated function so HF startup checks pass.
59
  @spaces.GPU(duration=60)
60
  def _gpu_healthcheck() -> str:
61
  return "ok"
62
 
63
 
64
  def _viewer_html() -> str:
65
+ if bool(_state.get("ok")):
66
+ src = f"/proxy/{NATIVE_PORT}/"
67
+ else:
68
+ src = os.environ.get("KIMODO_UI_URL", "https://nvidia-kimodo.hf.space").strip()
69
  return (
70
  "<div style='border:1px solid #d9e7ef;border-radius:12px;overflow:hidden;'>"
71
+ f"<iframe src='{src}' title='Kimodo UI' style='width:100%;border:0' "
72
+ "height='920' loading='lazy'></iframe>"
73
  "</div>"
74
  )
75
 
76
 
77
+ def _status_markdown() -> str:
78
+ if bool(_state.get("ok")):
79
+ return f"Native demo running on /proxy/{NATIVE_PORT}/."
80
+ err = _state.get("error")
81
+ if err:
82
+ return (
83
+ "Native demo unavailable, showing fallback UI. "
84
+ f"Reason: {err}"
85
+ )
86
+ return "Starting native demo..."
87
+
88
+
89
+ def _refresh() -> tuple[str, str]:
90
+ return _status_markdown(), _viewer_html()
91
+
92
+
93
  with gr.Blocks(title="Movimento") as demo:
94
  gr.Markdown("# Movimento")
95
+ status_md = gr.Markdown(_status_markdown())
96
+ viewer = gr.HTML(_viewer_html())
97
+ refresh_btn = gr.Button("Refresh UI Status")
98
+ refresh_btn.click(fn=_refresh, inputs=[], outputs=[status_md, viewer])
99
 
100
 
101
  if __name__ == "__main__":
kimodo/demo/__init__.py CHANGED
@@ -1 +1,29 @@
1
- # Demo module for native Kimodo UI integration in Space builds.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # ruff: noqa: I001
5
+ import argparse
6
+
7
+ from kimodo.model import DEFAULT_MODEL
8
+ from kimodo.model.registry import resolve_model_name
9
+
10
+ from .app import Demo
11
+
12
+
13
+ def main() -> None:
14
+ parser = argparse.ArgumentParser(description="Run the kimodo demo UI.")
15
+ parser.add_argument(
16
+ "--model",
17
+ type=str,
18
+ default=DEFAULT_MODEL,
19
+ help="Default model to load (e.g. Kimodo-SOMA-RP-v1, kimodo-soma-rp, or SOMA).",
20
+ )
21
+ args = parser.parse_args()
22
+
23
+ resolved = resolve_model_name(args.model, "Kimodo")
24
+ demo = Demo(default_model_name=resolved)
25
+ demo.run()
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()
kimodo/demo/__main__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Entry point for `python -m kimodo.demo`."""
4
+
5
+ from kimodo.demo import main
6
+
7
+ if __name__ == "__main__":
8
+ main()
kimodo/demo/app.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import base64
5
+ import os
6
+ import shutil
7
+ import threading
8
+ import time
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ import viser
15
+ from kimodo.assets import DEMO_ASSETS_ROOT
16
+ from kimodo.model.load_model import load_model
17
+ from kimodo.model.registry import resolve_model_name
18
+ from kimodo.skeleton import SkeletonBase, SOMASkeleton30
19
+ from kimodo.tools import load_json
20
+ from kimodo.viz import viser_utils
21
+ from kimodo.viz.viser_utils import (
22
+ Character,
23
+ CharacterMotion,
24
+ EEJointsKeyframeSet,
25
+ FullbodyKeyframeSet,
26
+ RootKeyframe2DSet,
27
+ )
28
+ from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage
29
+
30
+ from . import generation, ui
31
+ from .config import (
32
+ DARK_THEME,
33
+ DEFAULT_CUR_DURATION,
34
+ DEFAULT_MODEL,
35
+ DEFAULT_PLAYBACK_SPEED,
36
+ DEFAULT_PROMPT,
37
+ DEMO_UI_QUICK_START_MODAL_MD,
38
+ EXAMPLES_ROOT_DIR,
39
+ HF_MODE,
40
+ LIGHT_THEME,
41
+ MAX_ACTIVE_USERS,
42
+ MAX_DURATION,
43
+ MAX_SESSION_MINUTES,
44
+ MIN_DURATION,
45
+ MODEL_EXAMPLES_DIRS,
46
+ MODEL_NAMES,
47
+ SERVER_NAME,
48
+ SERVER_PORT,
49
+ )
50
+ from .embedding_cache import CachedTextEncoder
51
+ from .queue_manager import QueueManager, UserQueue
52
+ from .state import ClientSession, ModelBundle
53
+
54
+
55
+ class Demo:
56
+ def __init__(self, default_model_name: str = DEFAULT_MODEL):
57
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
58
+ print(f"Using device: {self.device}")
59
+ self.models: dict[str, ModelBundle] = {}
60
+ resolved = resolve_model_name(default_model_name, "Kimodo")
61
+ if resolved not in MODEL_NAMES:
62
+ raise ValueError(f"Unknown model '{default_model_name}'. Expected one of: {MODEL_NAMES}")
63
+ self.default_model_name = resolved
64
+ self.ensure_examples_layout()
65
+ self.load_model(self.default_model_name)
66
+
67
+ # Serialize GPU-bound generation across all clients
68
+ self._generation_lock = threading.Lock()
69
+ self._cuda_healthy = True
70
+
71
+ # Per-client sessions
72
+ self.client_sessions: dict[int, ClientSession] = {}
73
+ self.start_direction_markers: dict[int, viser_utils.WaypointMesh] = {}
74
+ self.grid_handles: dict[int, viser.GridHandle] = {}
75
+
76
+ self.server = viser.ViserServer(
77
+ host=SERVER_NAME,
78
+ port=SERVER_PORT,
79
+ label="Kimodo",
80
+ enable_camera_keyboard_controls=False, # don't move the camera with the arrow keys
81
+ )
82
+ self.server.scene.world_axes.visible = False # used for debugging
83
+ self.server.scene.set_up_direction("+y")
84
+
85
+ # Register callbacks for session handling
86
+ self.server.on_client_connect(self.on_client_connect)
87
+ self.server.on_client_disconnect(self.on_client_disconnect)
88
+
89
+ # HF mode: queue and session limit
90
+ if HF_MODE:
91
+ self.user_queue = UserQueue(MAX_ACTIVE_USERS, MAX_SESSION_MINUTES)
92
+ self.queue_manager = QueueManager(
93
+ queue=self.user_queue,
94
+ server=self.server,
95
+ setup_demo_for_client=self._setup_demo_for_client,
96
+ cleanup_session=self._cleanup_session_for_client,
97
+ )
98
+ else:
99
+ self.user_queue = None
100
+ self.queue_manager = None
101
+
102
+ # create grid and floor
103
+ self.floor_len = 20.0 # meters
104
+
105
+ def ensure_examples_layout(self) -> None:
106
+ os.makedirs(EXAMPLES_ROOT_DIR, exist_ok=True)
107
+ for model_dir in MODEL_EXAMPLES_DIRS.values():
108
+ os.makedirs(model_dir, exist_ok=True)
109
+
110
+ for entry in os.listdir(EXAMPLES_ROOT_DIR):
111
+ if entry in MODEL_EXAMPLES_DIRS:
112
+ continue
113
+ src = os.path.join(EXAMPLES_ROOT_DIR, entry)
114
+ if not os.path.isdir(src):
115
+ continue
116
+ dst = os.path.join(
117
+ MODEL_EXAMPLES_DIRS.get(DEFAULT_MODEL, next(iter(MODEL_EXAMPLES_DIRS.values()))),
118
+ entry,
119
+ )
120
+ if not os.path.exists(dst):
121
+ shutil.move(src, dst)
122
+
123
+ def get_examples_base_dir(self, model_name: str, absolute: bool = True) -> str:
124
+ return MODEL_EXAMPLES_DIRS[model_name]
125
+
126
+ def load_model(self, model_name: str) -> ModelBundle:
127
+ if model_name in self.models:
128
+ return self.models[model_name]
129
+
130
+ print(f"Loading model {model_name}...")
131
+ try:
132
+ model = load_model(modelname=model_name, device=self.device)
133
+ except Exception as e:
134
+ print(f"Error loading model: {e}\nMake sure text encoder server is running!")
135
+ raise e
136
+
137
+ if hasattr(model, "text_encoder"):
138
+ model.text_encoder = CachedTextEncoder(model.text_encoder, model_name=model_name)
139
+
140
+ skeleton = model.motion_rep.skeleton
141
+ if isinstance(skeleton, SOMASkeleton30):
142
+ skeleton = skeleton.somaskel77.to(model.device)
143
+ bundle = ModelBundle(
144
+ model=model,
145
+ motion_rep=model.motion_rep,
146
+ skeleton=skeleton,
147
+ model_fps=model.motion_rep.fps,
148
+ )
149
+ self.models[model_name] = bundle
150
+ print(f"Model {model_name} loaded successfully")
151
+ self.prewarm_embedding_cache(model_name, bundle.model)
152
+ return bundle
153
+
154
+ def prewarm_embedding_cache(self, model_name: str, model: object) -> None:
155
+ encoder = getattr(model, "text_encoder", None)
156
+ if not isinstance(encoder, CachedTextEncoder):
157
+ return
158
+
159
+ prompt_set = set()
160
+ prompt_set.add(DEFAULT_PROMPT)
161
+
162
+ examples_dir = MODEL_EXAMPLES_DIRS.get(model_name)
163
+ if examples_dir and os.path.isdir(examples_dir):
164
+ for entry in os.listdir(examples_dir):
165
+ example_dir = os.path.join(examples_dir, entry)
166
+ if not os.path.isdir(example_dir):
167
+ continue
168
+ meta_path = os.path.join(example_dir, "meta.json")
169
+ if not os.path.exists(meta_path):
170
+ continue
171
+ try:
172
+ meta = load_json(meta_path)
173
+ except Exception:
174
+ continue
175
+ for prompt in meta.get("prompts_text", []):
176
+ if isinstance(prompt, str):
177
+ prompt_set.add(prompt)
178
+
179
+ if prompt_set:
180
+ try:
181
+ encoder.prewarm(list(prompt_set))
182
+ except Exception as error:
183
+ # Startup should not fail if text encoder is still warming up.
184
+ print(f"Warning: embedding prewarm skipped: {error}")
185
+
186
+ def build_constraint_tracks(
187
+ self, client: viser.ClientHandle, skeleton: SkeletonBase
188
+ ) -> dict[str, viser_utils.ConstraintSet]:
189
+ return {
190
+ "Full-Body": FullbodyKeyframeSet(
191
+ name="Full-Body",
192
+ server=client,
193
+ skeleton=skeleton,
194
+ ),
195
+ "End-Effectors": EEJointsKeyframeSet(
196
+ name="End-Effectors",
197
+ server=client,
198
+ skeleton=skeleton,
199
+ ),
200
+ "2D Root": RootKeyframe2DSet(
201
+ name="2D Root",
202
+ server=client,
203
+ skeleton=skeleton,
204
+ ),
205
+ }
206
+
207
+ def set_timeline_defaults(self, timeline, model_fps: float) -> None:
208
+ timeline.set_defaults(
209
+ default_text=DEFAULT_PROMPT,
210
+ default_duration=int(DEFAULT_CUR_DURATION * model_fps - 1),
211
+ min_duration=int(MIN_DURATION * model_fps - 1), # 2 seconds minimum,
212
+ max_duration=int(
213
+ MAX_DURATION * model_fps - 1 # - NB_TRANSITION_FRAMES
214
+ ), # 10 seconds maximum, minus the transition frames, if needed
215
+ default_num_frames_zoom=int(1.10 * 10 * model_fps), # a bit more than the max
216
+ max_frames_zoom=1000,
217
+ fps=model_fps,
218
+ )
219
+
220
+ def _apply_constraint_overlay_visibility(self, session: ClientSession) -> None:
221
+ """Apply show-all vs show-only-current-frame to constraint overlays."""
222
+ only_frame = session.frame_idx if session.show_only_current_constraint else None
223
+ for constraint in session.constraints.values():
224
+ constraint.set_overlay_visibility(only_frame)
225
+
226
+ def set_constraint_tracks_visible(self, session: ClientSession, visible: bool) -> None:
227
+ timeline = session.client.timeline
228
+ timeline_data = session.timeline_data
229
+ if timeline_data.get("constraint_tracks_visible", True) == visible:
230
+ return
231
+
232
+ with timeline_data["keyframe_update_lock"]:
233
+ if visible:
234
+ for track_id, track_info in timeline_data["tracks"].items():
235
+ timeline.add_track(
236
+ track_info["name"],
237
+ track_type=track_info.get("track_type", "keyframe"),
238
+ color=track_info.get("color"),
239
+ height_scale=track_info.get("height_scale", 1.0),
240
+ uuid=track_id,
241
+ )
242
+
243
+ for keyframe_id, keyframe_data in timeline_data["keyframes"].items():
244
+ timeline.add_keyframe(
245
+ track_id=keyframe_data["track_id"],
246
+ frame=keyframe_data["frame"],
247
+ value=keyframe_data.get("value"),
248
+ opacity=keyframe_data.get("opacity", 1.0),
249
+ locked=keyframe_data.get("locked", False),
250
+ uuid=keyframe_id,
251
+ )
252
+
253
+ for interval_id, interval_data in timeline_data["intervals"].items():
254
+ timeline.add_interval(
255
+ track_id=interval_data["track_id"],
256
+ start_frame=interval_data["start_frame_idx"],
257
+ end_frame=interval_data["end_frame_idx"],
258
+ value=interval_data.get("value"),
259
+ opacity=interval_data.get("opacity", 1.0),
260
+ locked=interval_data.get("locked", False),
261
+ uuid=interval_id,
262
+ )
263
+ else:
264
+ for track_id in list(timeline_data["tracks"].keys()):
265
+ timeline.remove_track(track_id)
266
+
267
+ timeline_data["constraint_tracks_visible"] = visible
268
+
269
+ def _cleanup_session_for_client(self, client_id: int) -> None:
270
+ """Remove session and scene state for a client (e.g. on session expiry)."""
271
+ if client_id in self.client_sessions:
272
+ del self.client_sessions[client_id]
273
+ self.start_direction_markers.pop(client_id, None)
274
+ self.grid_handles.pop(client_id, None)
275
+
276
+ def _setup_demo_for_client(self, client: viser.ClientHandle) -> None:
277
+ """Initialize scene, GUI, and session state for a client (no modals)."""
278
+ self.setup_scene(client)
279
+
280
+ model_bundle = self.load_model(self.default_model_name)
281
+
282
+ # Initialize each empty constraint track
283
+ constraint_tracks = self.build_constraint_tracks(client, model_bundle.skeleton)
284
+
285
+ # Create GUI elements for this client
286
+ (
287
+ gui_elements,
288
+ timeline_tracks,
289
+ example_dict,
290
+ gui_examples_dropdown,
291
+ gui_save_example_path_text,
292
+ gui_model_selector,
293
+ ) = ui.create_gui(
294
+ demo=self,
295
+ client=client,
296
+ model_name=self.default_model_name,
297
+ model_fps=model_bundle.model_fps,
298
+ )
299
+ timeline_data = {
300
+ "tracks": timeline_tracks,
301
+ "tracks_ids": {val["name"]: key for key, val in timeline_tracks.items()},
302
+ "keyframes": {},
303
+ "intervals": {},
304
+ "keyframe_update_lock": threading.Lock(),
305
+ "keyframe_move_timers": {},
306
+ "pending_keyframe_moves": {}, # keyframe_id -> new_frame
307
+ "constraint_tracks_visible": True,
308
+ "dense_path_after_release_timer": None,
309
+ }
310
+
311
+ # Initialize session state
312
+ cur_duration = DEFAULT_CUR_DURATION
313
+ max_frame_idx = int(cur_duration * model_bundle.model_fps - 1)
314
+
315
+ session = ClientSession(
316
+ client=client,
317
+ gui_elements=gui_elements,
318
+ motions={},
319
+ constraints=constraint_tracks,
320
+ timeline_data=timeline_data,
321
+ frame_idx=0,
322
+ playing=False,
323
+ playback_speed=DEFAULT_PLAYBACK_SPEED,
324
+ cur_duration=cur_duration,
325
+ max_frame_idx=max_frame_idx,
326
+ updating_motions=False,
327
+ edit_mode=False,
328
+ model_name=self.default_model_name,
329
+ model_fps=model_bundle.model_fps,
330
+ skeleton=model_bundle.skeleton,
331
+ motion_rep=model_bundle.motion_rep,
332
+ examples_base_dir=self.get_examples_base_dir(self.default_model_name, absolute=True),
333
+ example_dict=example_dict,
334
+ gui_examples_dropdown=gui_examples_dropdown,
335
+ gui_save_example_path_text=gui_save_example_path_text,
336
+ gui_model_selector=gui_model_selector,
337
+ )
338
+
339
+ self.client_sessions[client.client_id] = session
340
+
341
+ # Initialize default character for this client
342
+ self.add_character_motion(client, session.skeleton)
343
+
344
+ def on_client_connect(self, client: viser.ClientHandle) -> None:
345
+ """Initialize GUI and state for each new client."""
346
+ print(f"Client {client.client_id} connected")
347
+
348
+ if HF_MODE and self.queue_manager is not None:
349
+ self.queue_manager.on_client_connect(client)
350
+ else:
351
+ # Show quick start popup when a browser client connects (non-HF mode).
352
+ with client.gui.add_modal(
353
+ "Welcome — Quick Start",
354
+ size="xl",
355
+ show_close_button=True,
356
+ save_choice="kimodo.demo.quick_start_ack",
357
+ ) as modal:
358
+ client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
359
+ client.gui.add_button("Got it (don't remind me again)").on_click(lambda _event: modal.close())
360
+ self._setup_demo_for_client(client)
361
+
362
+ def setup_scene(self, client: viser.ClientHandle) -> None:
363
+ self.configure_theme(client)
364
+ client.camera.position = np.array(
365
+ [2.7417358737841426, 1.8790455698853281, 7.675741569777456],
366
+ dtype=np.float64,
367
+ )
368
+ client.camera.look_at = np.array([0.0, 0.0, 0.0], dtype=np.float64)
369
+ client.camera.up_direction = np.array(
370
+ [-1.1102230246251568e-16, 1.0, 1.3596310734468913e-32],
371
+ dtype=np.float64,
372
+ )
373
+ client.camera.fov = np.deg2rad(45.0)
374
+ grid_handle = client.scene.add_grid(
375
+ "/grid",
376
+ width=self.floor_len,
377
+ height=self.floor_len,
378
+ wxyz=viser.transforms.SO3.from_x_radians(-np.pi / 2.0).wxyz,
379
+ position=(0.0, 0.0001, 0.0),
380
+ fade_distance=3 * self.floor_len,
381
+ section_color=LIGHT_THEME["grid"],
382
+ infinite_grid=True,
383
+ )
384
+ self.grid_handles[client.client_id] = grid_handle
385
+ # marker for origin
386
+ origin_waypoint = viser_utils.WaypointMesh(
387
+ "/origin_waypoint",
388
+ client,
389
+ position=np.array([0.0, 0.0, 0.0]),
390
+ heading=np.array([0.0, 1.0]),
391
+ color=(0, 0, 255),
392
+ )
393
+ self.start_direction_markers[client.client_id] = origin_waypoint
394
+
395
+ def on_client_disconnect(self, client: viser.ClientHandle) -> None:
396
+ """Clean up when client disconnects."""
397
+ print(f"Client {client.client_id} disconnected")
398
+ client_id = client.client_id
399
+
400
+ if HF_MODE and self.queue_manager is not None:
401
+ self.queue_manager.on_client_disconnect(client_id)
402
+
403
+ self._cleanup_session_for_client(client_id)
404
+
405
+ def set_start_direction_visible(self, client_id: int, visible: bool) -> None:
406
+ marker = self.start_direction_markers.get(client_id)
407
+ if marker is None:
408
+ return
409
+ marker.set_visible(visible)
410
+
411
+ def client_active(self, client_id: int) -> bool:
412
+ return client_id in self.client_sessions
413
+
414
+ def add_character_motion(
415
+ self,
416
+ client: viser.ClientHandle,
417
+ skeleton: SkeletonBase,
418
+ joints_pos: Optional[torch.Tensor] = None,
419
+ joints_rot: Optional[torch.Tensor] = None,
420
+ foot_contacts: Optional[torch.Tensor] = None,
421
+ ) -> None:
422
+ client_id = client.client_id
423
+ if not self.client_active(client_id):
424
+ return
425
+ session = self.client_sessions[client_id]
426
+
427
+ ci = len(session.motions)
428
+ character_name = f"character{ci}"
429
+ # build character skeleton and skinning mesh
430
+ if "g1" in session.model_name:
431
+ mesh_mode = "g1_stl"
432
+ elif "smplx" in session.model_name:
433
+ mesh_mode = "smplx_skin"
434
+ elif "soma" in session.model_name:
435
+ if session.gui_elements.gui_use_soma_layer_checkbox.value:
436
+ mesh_mode = "soma_layer_skin"
437
+ else:
438
+ mesh_mode = "soma_skin"
439
+ else:
440
+ raise ValueError("The model name is not recognized for skinning.")
441
+
442
+ new_character = Character(
443
+ character_name,
444
+ client,
445
+ skeleton,
446
+ create_skeleton_mesh=True,
447
+ create_skinned_mesh=True,
448
+ visible_skeleton=False, # don't show immediately
449
+ visible_skinned_mesh=False, # don't show immediately
450
+ skinned_mesh_opacity=session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value,
451
+ show_foot_contacts=session.gui_elements.gui_viz_foot_contacts_checkbox.value,
452
+ dark_mode=session.gui_elements.gui_dark_mode_checkbox.value,
453
+ mesh_mode=mesh_mode,
454
+ gui_use_soma_layer_checkbox=session.gui_elements.gui_use_soma_layer_checkbox,
455
+ )
456
+
457
+ # if no motion given, initialize to character default (rest) pose for one frame
458
+ init_joints_pos, init_joints_rot = new_character.get_pose()
459
+ if joints_pos is None:
460
+ joints_pos = init_joints_pos[None].repeat(session.max_frame_idx + 1, 1, 1)
461
+ if joints_rot is None:
462
+ joints_rot = init_joints_rot[None].repeat(session.max_frame_idx + 1, 1, 1, 1)
463
+
464
+ new_motion = CharacterMotion(new_character, joints_pos, joints_rot, foot_contacts)
465
+ # save the motion in our dict
466
+ session.motions[character_name] = new_motion
467
+
468
+ # put the character at the right frame
469
+ new_motion.set_frame(session.frame_idx)
470
+
471
+ # put them visible with a small delay
472
+ # so that the set_frame function has time to finish
473
+ def _set_visibility():
474
+ new_motion.character.set_skinned_mesh_visibility(session.gui_elements.gui_viz_skinned_mesh_checkbox.value)
475
+ new_motion.character.set_skeleton_visibility(session.gui_elements.gui_viz_skeleton_checkbox.value)
476
+
477
+ timer = threading.Timer(
478
+ 0.2, # 0.2s delay
479
+ _set_visibility,
480
+ )
481
+ timer.start()
482
+
483
+ def clear_motions(self, client_id: int) -> None:
484
+ if not self.client_active(client_id):
485
+ return
486
+ session = self.client_sessions[client_id]
487
+ for motion in list(session.motions.values()):
488
+ motion.clear()
489
+ session.motions.clear()
490
+
491
+ def compute_model_constraints_lst(
492
+ self,
493
+ session: ClientSession,
494
+ model_bundle: ModelBundle,
495
+ num_frames: int,
496
+ ):
497
+ return generation.compute_model_constraints_lst(session, model_bundle, num_frames, self.device)
498
+
499
+ def check_cuda_health(self) -> bool:
500
+ """Check if CUDA is still functional.
501
+
502
+ Trigger auto-restart if corrupted.
503
+ """
504
+ if self.device == "cpu":
505
+ return True
506
+ try:
507
+ torch.tensor([1.0], device=self.device) + torch.tensor([1.0], device=self.device)
508
+ return True
509
+ except RuntimeError as e:
510
+ if "device-side assert" in str(e) or "CUDA error" in str(e):
511
+ if self._cuda_healthy:
512
+ self._cuda_healthy = False
513
+ print("FATAL: CUDA context is corrupted (device-side assert). " "The process must be restarted.")
514
+ self._trigger_restart()
515
+ return False
516
+ raise
517
+
518
+ def _trigger_restart(self) -> None:
519
+ """Exit the process so the HF Space (or systemd/Docker) can restart it."""
520
+ import sys
521
+
522
+ print("Initiating automatic restart due to unrecoverable CUDA error...")
523
+ sys.stdout.flush()
524
+ sys.stderr.flush()
525
+ os._exit(1)
526
+
527
+ def generate(
528
+ self,
529
+ client: viser.ClientHandle,
530
+ prompts: list[str],
531
+ num_frames: list[int],
532
+ num_samples: int,
533
+ seed: int,
534
+ diffusion_steps: int,
535
+ cfg_weight: Optional[list[float]] = None,
536
+ cfg_type: Optional[str] = None,
537
+ postprocess_parameters: Optional[dict] = None,
538
+ transitions_parameters: Optional[dict] = None,
539
+ real_robot_rotations: bool = False,
540
+ ) -> None:
541
+ if not self._cuda_healthy:
542
+ raise RuntimeError("CUDA is in a corrupted state. The space is restarting...")
543
+
544
+ locked = self._generation_lock.acquire(blocking=False)
545
+ if not locked:
546
+ waiting_notif = client.add_notification(
547
+ title="Waiting for GPU...",
548
+ body="Another generation is in progress. Yours will start automatically.",
549
+ loading=True,
550
+ with_close_button=False,
551
+ )
552
+ self._generation_lock.acquire()
553
+ waiting_notif.remove()
554
+
555
+ try:
556
+ session = self.client_sessions[client.client_id]
557
+ model_bundle = self.load_model(session.model_name)
558
+ generation.generate(
559
+ client=client,
560
+ session=session,
561
+ model_bundle=model_bundle,
562
+ prompts=prompts,
563
+ num_frames=num_frames,
564
+ num_samples=num_samples,
565
+ seed=seed,
566
+ diffusion_steps=diffusion_steps,
567
+ cfg_weight=cfg_weight,
568
+ cfg_type=cfg_type,
569
+ postprocess_parameters=postprocess_parameters,
570
+ transitions_parameters=transitions_parameters,
571
+ real_robot_rotations=real_robot_rotations,
572
+ device=self.device,
573
+ clear_motions=self.clear_motions,
574
+ add_character_motion=self.add_character_motion,
575
+ )
576
+ finally:
577
+ self._generation_lock.release()
578
+
579
+ def set_frame(self, client_id: int, frame_idx: int, update_timeline: bool = True):
580
+ if not self.client_active(client_id):
581
+ return
582
+
583
+ session = self.client_sessions[client_id]
584
+
585
+ session.frame_idx = frame_idx
586
+ if update_timeline:
587
+ session.client.timeline.set_current_frame(frame_idx)
588
+ for motion in list(session.motions.values()):
589
+ motion.set_frame(frame_idx)
590
+ self._apply_constraint_overlay_visibility(session)
591
+
592
+ def run(self) -> None:
593
+ update_counter = 0
594
+ cuda_check_interval = 300
595
+ while True:
596
+ last_update_time = time.time()
597
+ if self.models:
598
+ # the max playback speed is 2x the model fps (from gui_playback_speed_buttons)
599
+ playback_fps = max(bundle.model_fps for bundle in self.models.values()) * 2.0
600
+ else:
601
+ playback_fps = 60.0
602
+
603
+ # update each client session independently
604
+ # copy to a list first to avoid changing size if client disconnects
605
+ for client_id, session in list(self.client_sessions.items()):
606
+ update_interval = int(playback_fps / (session.playback_speed * session.model_fps))
607
+ new_frame_idx = session.frame_idx
608
+ if session.playing and update_counter % update_interval == 0:
609
+ if session.frame_idx >= session.max_frame_idx:
610
+ new_frame_idx = 0
611
+ else:
612
+ new_frame_idx = session.frame_idx + 1
613
+
614
+ # make sure the client is still active before updating the frame
615
+ if self.client_active(client_id):
616
+ self.set_frame(client_id, new_frame_idx)
617
+
618
+ if update_counter % cuda_check_interval == 0:
619
+ self.check_cuda_health()
620
+
621
+ time_remaining = max(0, 1.0 / playback_fps - (time.time() - last_update_time))
622
+ time.sleep(time_remaining)
623
+ update_counter += 1
624
+ update_counter %= playback_fps # wrap around to 0 every second
625
+
626
+ def configure_theme(
627
+ self,
628
+ client: viser.ClientHandle,
629
+ dark_mode: bool = False,
630
+ titlebar_dark_mode_checkbox_uuid: str | None = None,
631
+ ):
632
+ # Sync grid color with theme (light vs dark)
633
+ theme = DARK_THEME if dark_mode else LIGHT_THEME
634
+ grid_handle = self.grid_handles.get(client.client_id)
635
+ if grid_handle is not None:
636
+ grid_handle.section_color = theme["grid"]
637
+
638
+ #
639
+ # setup theme
640
+ #
641
+ buttons = (
642
+ TitlebarButton(
643
+ text="Documentation",
644
+ icon="Description",
645
+ href="https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html",
646
+ ),
647
+ TitlebarButton(
648
+ text="Project Page",
649
+ icon=None,
650
+ href="https://research.nvidia.com/labs/sil/projects/kimodo/",
651
+ ),
652
+ TitlebarButton(
653
+ text="Github",
654
+ icon="GitHub",
655
+ href="https://github.com/nv-tlabs/kimodo",
656
+ ),
657
+ )
658
+ assets_dir = DEMO_ASSETS_ROOT
659
+ logo_light_path = assets_dir / "nvidia_logo.png"
660
+ logo_dark_path = assets_dir / "nvidia_logo_dark.png"
661
+ if logo_light_path.exists():
662
+ light_b64 = base64.standard_b64encode(logo_light_path.read_bytes()).decode("ascii")
663
+ dark_b64 = (
664
+ base64.standard_b64encode(logo_dark_path.read_bytes()).decode("ascii")
665
+ if logo_dark_path.exists()
666
+ else None
667
+ )
668
+ image = TitlebarImage(
669
+ image_url_light=f"data:image/png;base64,{light_b64}",
670
+ image_url_dark=(f"data:image/png;base64,{dark_b64}" if dark_b64 else None),
671
+ image_alt="NVIDIA",
672
+ href="https://www.nvidia.com/",
673
+ )
674
+ else:
675
+ image = None
676
+ titlebar_theme = TitlebarConfig(buttons=buttons, image=image, title_text="Kimodo")
677
+ client.gui.set_panel_label("Kimodo")
678
+ client.gui.configure_theme(
679
+ titlebar_content=titlebar_theme,
680
+ control_layout="floating", # "floating", # ['floating', 'collapsible', 'fixed']
681
+ control_width="large", # ['small', 'medium', 'large']
682
+ dark_mode=dark_mode,
683
+ show_logo=False, # hide viser logo on bottom left corner
684
+ show_share_button=False,
685
+ titlebar_dark_mode_checkbox_uuid=titlebar_dark_mode_checkbox_uuid,
686
+ brand_color=(152, 189, 255), # (60, 131, 0), # (R, G, B) tuple
687
+ )
kimodo/demo/config.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+
6
+ from kimodo.assets import DEMO_EXAMPLES_ROOT
7
+ from kimodo.model.registry import (
8
+ AVAILABLE_MODELS,
9
+ DEFAULT_MODEL,
10
+ FRIENDLY_NAMES,
11
+ get_datasets,
12
+ get_model_info,
13
+ get_models_for_dataset_skeleton,
14
+ get_short_key_from_display_name,
15
+ get_skeleton_display_name,
16
+ get_skeleton_display_names_for_dataset,
17
+ get_skeleton_key_from_display_name,
18
+ get_skeletons_for_dataset,
19
+ get_versions_for_dataset_skeleton,
20
+ resolve_to_short_key,
21
+ )
22
+
23
+ SERVER_NAME = os.environ.get("SERVER_NAME", "0.0.0.0")
24
+ SERVER_PORT = int(os.environ.get("SERVER_PORT") or os.environ.get("PORT", "7860"))
25
+
26
+
27
+ def _env_bool(name: str, default: bool = False) -> bool:
28
+ raw = os.environ.get(name)
29
+ if raw is None:
30
+ return default
31
+ return str(raw).strip().lower() in ("1", "true", "yes", "on")
32
+
33
+
34
+ HF_MODE = _env_bool("HF_MODE", False)
35
+
36
+ # HF mode: user queue and session limit (override via env in Spaces)
37
+ MAX_ACTIVE_USERS = int(os.environ.get("MAX_ACTIVE_USERS", "5"))
38
+ MAX_SESSION_MINUTES = float(os.environ.get("MAX_SESSION_MINUTES", "5.0"))
39
+
40
+ DEFAULT_PLAYBACK_SPEED = 1.0
41
+ # default start duration is 6.0 sec, but model can handle up to 10 sec
42
+ DEFAULT_CUR_DURATION = 6.0
43
+ DEFAULT_PROMPT = "A person walks forward."
44
+ MIN_DURATION = 2.0
45
+ MAX_DURATION = 10.0
46
+
47
+ SHOW_TRANSITION_PARAMS = False
48
+ INIT_POSTPROCESSING = True
49
+ NB_TRANSITION_FRAMES = 5
50
+
51
+ LIGHT_THEME = dict(
52
+ floor=(220, 220, 220),
53
+ grid=(180, 180, 180),
54
+ )
55
+
56
+ # Dark theme: slightly lighter grid and floor for better visibility and less flat black
57
+ DARK_THEME = dict(
58
+ floor=(48, 48, 52),
59
+ grid=(105, 105, 110),
60
+ )
61
+
62
+ EXAMPLES_ROOT_DIR = str(DEMO_EXAMPLES_ROOT)
63
+
64
+ # Model list and paths from kimodo registry (all models: Kimodo + TMR)
65
+ MODEL_NAMES = tuple(AVAILABLE_MODELS)
66
+ MODEL_EXAMPLES_DIRS = {name: os.path.join(EXAMPLES_ROOT_DIR, name) for name in MODEL_NAMES}
67
+ # Display labels for backward compatibility (short_key -> display name)
68
+ MODEL_LABELS = {name: FRIENDLY_NAMES.get(name, f"Model ({name})") for name in MODEL_NAMES}
69
+ MODEL_LABEL_TO_NAME = {label: name for name, label in MODEL_LABELS.items()}
70
+
71
+ # -----------------------------------------------------------------------------
72
+ # Demo UI copy
73
+ # -----------------------------------------------------------------------------
74
+
75
+ DEMO_UI_QUICK_START_CORE_MD = """
76
+ ### Camera
77
+ - **Left-drag**: rotate
78
+ - **Right-drag**: pan
79
+ - **Scroll**: zoom
80
+
81
+ ### Playback
82
+ - **Space** to play/pause
83
+ - **←/→** to step frames, or click the frame number.
84
+ - **Scroll up/down** in the timeline: move left/right
85
+ - **Shift + scroll** in the timeline: zoom in/out
86
+
87
+ ### Prompts
88
+ - **Double-click** a text prompt to edit it.
89
+ - **Click and drag** the right edge of a prompt box to extend/shorten it.
90
+ - **Click empty space** to add a prompt.
91
+ - **Right-click** a prompt to delete it.
92
+
93
+ ### Generate
94
+ - Go to the **Generate** tab to modify options
95
+ - It is also possible to **load** examples
96
+ - Click **Generate** to generate a motion
97
+
98
+ ### Constraints
99
+ - This is **optional**: should be use after a first generation
100
+ - **Click** in the timeline tracks (Full-Body / 2D root etc) to add a constraint.
101
+ - **Right-click** on a constraint to delete it.
102
+ - To **edit** a constraint:
103
+ - Move playback to the target frame
104
+ - Click **Enter Editing Mode** in the Constraints tab.
105
+ """
106
+
107
+ DEMO_UI_QUICK_START_MODAL_MD = (
108
+ DEMO_UI_QUICK_START_CORE_MD
109
+ + """
110
+
111
+ See the **Instructions** tab for the full user manual.
112
+ """
113
+ )
114
+
115
+ DEMO_UI_INSTRUCTIONS_TAB_MD = (
116
+ """
117
+ ## How to Use This Demo
118
+
119
+ """
120
+ + DEMO_UI_QUICK_START_CORE_MD
121
+ + """
122
+
123
+ ---
124
+
125
+ ### Generating Motion (step-by-step)
126
+
127
+ 1. **Edit the text prompts** in the timeline (e.g., "A person walks forward.")
128
+ 2. **Modify the duration** by moving the right edge of each prompts (2–10 seconds)
129
+ 3. **Add constraints** (optional) to control the motion:
130
+ - Click **Enter Editing Mode** to adjust the character pose
131
+ - Use the timeline to place keyframes or intervals in constraint tracks (see below)
132
+ 4. **Click Generate** to create the motion
133
+ 5. If generating multiple samples, **click on a mesh** to select which one to keep
134
+
135
+ ### Timeline Editing
136
+
137
+ **Adding Constraints:**
138
+ 1. Click anywhere on the timeline to add a keyframe at that frame. The keyframe is created based on the current character motion.
139
+ 2. Ctrl/Cmd+click+drag to add an interval constraint, or expand a keyframe into an interval
140
+ 3. Enter editing mode with the **Enter Editing Mode** button to adjust character pose before/after adding constraints.
141
+
142
+ **Constraint Types:**
143
+ - **Full-Body**: constrains the entire character pose
144
+ - **2D Root**: constrains the character's path on the ground plane
145
+ - Enable **Densify** to create a continuous path
146
+ - **End-Effectors**: constrains hands and feet positions
147
+ - Use separate tracks for Left/Right Hand/Foot
148
+
149
+
150
+ **Moving & Deleting:**
151
+ - **Drag keyframes/intervals** to move them to different frames
152
+ - **Right-click** a keyframe or interval to delete it
153
+ - Use **Clear All Constraints** to remove everything
154
+
155
+ **Tips:**
156
+ - The posing skeleton becomes visible in editing mode for precise positioning
157
+ - Use **Snap to constraint** to align the current frame to a constraint
158
+
159
+ ### Saving & Loading
160
+
161
+ You can save the current constraints or current motion to load in later from the Load/Save menu.
162
+ Saving an **Example** will save the full constraints, motion, and generation metadata.
163
+
164
+ ### Visualization Options
165
+
166
+ Switch to the **Visualize** tab to:
167
+ - Toggle mesh and skeleton visibility
168
+ - Adjust mesh opacity
169
+ - Show/hide foot contact indicators
170
+ - Switch between light and dark modes
171
+ """
172
+ )
kimodo/demo/embedding_cache.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import contextlib
5
+ import contextvars
6
+ import hashlib
7
+ import json
8
+ import os
9
+ import threading
10
+ import time
11
+ from collections import OrderedDict
12
+ from dataclasses import dataclass
13
+ from typing import Iterable, Optional
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from kimodo.sanitize import sanitize_texts
19
+
20
+ _ACTIVE_SESSION = contextvars.ContextVar("kimodo_demo_active_session", default=None)
21
+
22
+
23
+ @dataclass
24
+ class CacheStats:
25
+ hits: int = 0
26
+ misses: int = 0
27
+ disk_hits: int = 0
28
+
29
+
30
+ class EmbeddingCache:
31
+ """Disk-backed text embedding cache with a small in-memory LRU."""
32
+
33
+ def __init__(
34
+ self,
35
+ *,
36
+ model_name: str,
37
+ encoder_id: str,
38
+ base_dir: Optional[str] = None,
39
+ max_mem_entries: int = 128,
40
+ ) -> None:
41
+ cache_root = base_dir or os.environ.get(
42
+ "kimodo_EMBED_CACHE_DIR",
43
+ os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
44
+ )
45
+ self.base_dir = os.path.expanduser(cache_root)
46
+ self.model_name = model_name
47
+ self.encoder_id = encoder_id
48
+ self.max_mem_entries = max_mem_entries
49
+ self.stats = CacheStats()
50
+
51
+ self._lock = threading.Lock()
52
+ self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
53
+ self._index = {}
54
+ self._index_loaded = False
55
+
56
+ def _model_dir(self) -> str:
57
+ return os.path.join(self.base_dir, self.model_name)
58
+
59
+ def _index_path(self) -> str:
60
+ return os.path.join(self._model_dir(), "index.json")
61
+
62
+ def _prewarm_marker_path(self, key: str) -> str:
63
+ return os.path.join(self._model_dir(), f"prewarm_{key}.json")
64
+
65
+ def has_prewarm_marker(self, key: str) -> bool:
66
+ return os.path.exists(self._prewarm_marker_path(key))
67
+
68
+ def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
69
+ os.makedirs(self._model_dir(), exist_ok=True)
70
+ payload = {"prompt_count": prompt_count, "updated_at": time.time()}
71
+ tmp_path = f"{self._prewarm_marker_path(key)}.tmp"
72
+ with open(tmp_path, "w", encoding="utf-8") as f:
73
+ json.dump(payload, f)
74
+ os.replace(tmp_path, self._prewarm_marker_path(key))
75
+
76
+ def _load_index(self) -> None:
77
+ if self._index_loaded:
78
+ return
79
+ index_path = self._index_path()
80
+ if os.path.exists(index_path):
81
+ try:
82
+ with open(index_path, "r", encoding="utf-8") as f:
83
+ self._index = json.load(f)
84
+ except json.JSONDecodeError:
85
+ self._index = {}
86
+ self._index_loaded = True
87
+
88
+ def _save_index(self) -> None:
89
+ os.makedirs(self._model_dir(), exist_ok=True)
90
+ tmp_path = f"{self._index_path()}.tmp"
91
+ with open(tmp_path, "w", encoding="utf-8") as f:
92
+ json.dump(self._index, f)
93
+ os.replace(tmp_path, self._index_path())
94
+
95
+ def _make_key(self, text: str) -> str:
96
+ key_src = f"{self.model_name}|{self.encoder_id}|{text}"
97
+ return hashlib.sha256(key_src.encode("utf-8")).hexdigest()
98
+
99
+ def _entry_path(self, key: str) -> str:
100
+ return os.path.join(self._model_dir(), f"{key}.npy")
101
+
102
+ def _mem_get(self, key: str) -> Optional[np.ndarray]:
103
+ if key in self._mem_cache:
104
+ self._mem_cache.move_to_end(key)
105
+ return self._mem_cache[key]
106
+ return None
107
+
108
+ def _mem_put(self, key: str, value: np.ndarray) -> None:
109
+ self._mem_cache[key] = value
110
+ self._mem_cache.move_to_end(key)
111
+ while len(self._mem_cache) > self.max_mem_entries:
112
+ self._mem_cache.popitem(last=False)
113
+
114
+ def _disk_load(self, key: str) -> Optional[np.ndarray]:
115
+ path = self._entry_path(key)
116
+ if not os.path.exists(path):
117
+ return None
118
+ try:
119
+ return np.load(path)
120
+ except Exception:
121
+ return None
122
+
123
+ def _disk_save(self, key: str, value: np.ndarray) -> None:
124
+ os.makedirs(self._model_dir(), exist_ok=True)
125
+ np.save(self._entry_path(key), value)
126
+ self._index[key] = {
127
+ "length": int(value.shape[0]),
128
+ "dtype": str(value.dtype),
129
+ "updated_at": time.time(),
130
+ }
131
+
132
+ def _maybe_use_session_cache(self, texts: list[str]):
133
+ session = _ACTIVE_SESSION.get()
134
+ if session is None:
135
+ return None
136
+ if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
137
+ return session.last_prompt_embeddings, session.last_prompt_lengths
138
+ return None
139
+
140
+ def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
141
+ session = _ACTIVE_SESSION.get()
142
+ if session is None:
143
+ return
144
+ session.last_prompt_texts = texts
145
+ session.last_prompt_embeddings = tensor
146
+ session.last_prompt_lengths = lengths
147
+
148
+ def get_or_encode(self, texts: Iterable[str], encoder):
149
+ if isinstance(texts, str):
150
+ texts = [texts]
151
+ texts = sanitize_texts(list(texts))
152
+ if len(texts) == 0:
153
+ empty = torch.empty()
154
+ return empty, []
155
+
156
+ session_cache = self._maybe_use_session_cache(texts)
157
+ if session_cache is not None:
158
+ return session_cache
159
+
160
+ arrays: list[Optional[np.ndarray]] = [None] * len(texts)
161
+ lengths: list[int] = [0] * len(texts)
162
+ misses: list[tuple[int, str, str]] = []
163
+
164
+ with self._lock:
165
+ self._load_index()
166
+ for idx, text in enumerate(texts):
167
+ key = self._make_key(text)
168
+ cached = self._mem_get(key)
169
+ if cached is not None:
170
+ arrays[idx] = cached
171
+ lengths[idx] = cached.shape[0]
172
+ self.stats.hits += 1
173
+ continue
174
+
175
+ cached = self._disk_load(key)
176
+ if cached is not None:
177
+ arrays[idx] = cached
178
+ lengths[idx] = cached.shape[0]
179
+ self._mem_put(key, cached)
180
+ self.stats.disk_hits += 1
181
+ continue
182
+
183
+ misses.append((idx, text, key))
184
+ self.stats.misses += 1
185
+
186
+ if misses:
187
+ miss_texts = [text for _, text, _ in misses]
188
+ miss_tensor, miss_lengths = encoder(miss_texts)
189
+ miss_tensor = miss_tensor.detach().cpu()
190
+ miss_tensor_np = miss_tensor.numpy()
191
+
192
+ with self._lock:
193
+ self._load_index()
194
+ for miss_idx, length in enumerate(miss_lengths):
195
+ idx, _text, key = misses[miss_idx]
196
+ arr = miss_tensor_np[miss_idx, :length].copy()
197
+ arrays[idx] = arr
198
+ lengths[idx] = int(length)
199
+ self._mem_put(key, arr)
200
+ self._disk_save(key, arr)
201
+ self._save_index()
202
+
203
+ max_len = max(lengths) if lengths else 0
204
+ feat_dim = arrays[0].shape[-1] if arrays[0] is not None else 0
205
+ dtype = arrays[0].dtype if arrays[0] is not None else np.float32
206
+ padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
207
+ for idx, arr in enumerate(arrays):
208
+ if arr is None:
209
+ continue
210
+ padded[idx, : arr.shape[0]] = arr
211
+
212
+ result = torch.from_numpy(padded)
213
+ self._update_session_cache(texts, result, lengths)
214
+ return result, lengths
215
+
216
+
217
+ class CachedTextEncoder:
218
+ """Wrapper around a text encoder to add disk-backed caching."""
219
+
220
+ def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
221
+ self.encoder = encoder
222
+ self.model_name = model_name
223
+ encoder_id = f"{type(encoder).__name__}"
224
+ self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)
225
+
226
+ def __call__(self, texts):
227
+ return self.cache.get_or_encode(texts, self.encoder)
228
+
229
+ def prewarm(self, texts) -> None:
230
+ if isinstance(texts, str):
231
+ texts = [texts]
232
+ texts = sanitize_texts(list(texts))
233
+ prewarm_key = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
234
+ if self.cache.has_prewarm_marker(prewarm_key):
235
+ return
236
+ self.cache.get_or_encode(texts, self.encoder)
237
+ self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))
238
+
239
+ def to(self, device=None, dtype=None):
240
+ if hasattr(self.encoder, "to"):
241
+ self.encoder.to(device=device, dtype=dtype)
242
+ return self
243
+
244
+ @contextlib.contextmanager
245
+ def session_context(self, session):
246
+ token = _ACTIVE_SESSION.set(session)
247
+ try:
248
+ yield
249
+ finally:
250
+ _ACTIVE_SESSION.reset(token)
251
+
252
+ def __getattr__(self, name):
253
+ return getattr(self.encoder, name)
kimodo/demo/generation.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections import defaultdict
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ import viser
11
+ from kimodo.constraints import (
12
+ TYPE_TO_CLASS,
13
+ FullBodyConstraintSet,
14
+ Root2DConstraintSet,
15
+ )
16
+ from kimodo.exports.mujoco import apply_g1_real_robot_projection
17
+ from kimodo.skeleton import G1Skeleton34, SOMASkeleton30
18
+ from kimodo.tools import seed_everything
19
+
20
+ from .embedding_cache import CachedTextEncoder
21
+ from .state import ClientSession, ModelBundle
22
+
23
+
24
+ def compute_model_constraints_lst(
25
+ session: ClientSession,
26
+ model_bundle: ModelBundle,
27
+ num_frames: int,
28
+ device: str,
29
+ ):
30
+ """Compute the lst of constraints for the model based on the constraints in viser."""
31
+ assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
32
+ if not session.constraints:
33
+ return []
34
+
35
+ model_skeleton = model_bundle.model.skeleton
36
+ # For SOMA, UI uses somaskel77; extract 30-joint subset for the model
37
+ use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
38
+ skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None
39
+
40
+ dense_smooth_root_pos_2d = None
41
+ if session.constraints["2D Root"].dense_path:
42
+ # get the full 2d root
43
+ dense_smooth_root_pos_2d = session.constraints["2D Root"].get_constraint_info(device=device)["root_pos"][
44
+ :, [0, 2]
45
+ ]
46
+
47
+ model_constraints = []
48
+ for track_name, constraint in session.constraints.items():
49
+ constraint_info = constraint.get_constraint_info(device=device)
50
+ frame_idx = constraint_info["frame_idx"]
51
+ # drop any constraints outside the generation range
52
+ valid_info = [(i, fi) for i, fi in enumerate(frame_idx) if fi < num_frames]
53
+ valid_idx = [i for i, _ in valid_info]
54
+ valid_frame_idx = [fi for _, fi in valid_info]
55
+
56
+ if len(valid_frame_idx) == 0:
57
+ continue
58
+
59
+ frame_indices = torch.tensor(valid_frame_idx)
60
+ if track_name == "2D Root":
61
+ smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
62
+ # same as "smooth_root_2d"
63
+ model_constraints.append(
64
+ Root2DConstraintSet(
65
+ model_skeleton,
66
+ frame_indices,
67
+ smooth_root_pos_2d,
68
+ )
69
+ )
70
+ elif track_name == "Full-Body":
71
+ constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
72
+ constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
73
+ if skel_slice is not None:
74
+ constraint_joints_pos = constraint_joints_pos[:, skel_slice]
75
+ constraint_joints_rot = constraint_joints_rot[:, skel_slice]
76
+
77
+ smooth_root_pos_2d = None
78
+ if dense_smooth_root_pos_2d is not None:
79
+ smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]
80
+
81
+ model_constraints.append(
82
+ FullBodyConstraintSet(
83
+ model_skeleton,
84
+ frame_indices,
85
+ constraint_joints_pos,
86
+ constraint_joints_rot,
87
+ smooth_root_2d=smooth_root_pos_2d,
88
+ )
89
+ )
90
+ elif track_name == "End-Effectors":
91
+ constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
92
+ constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
93
+ if skel_slice is not None:
94
+ constraint_joints_pos = constraint_joints_pos[:, skel_slice]
95
+ constraint_joints_rot = constraint_joints_rot[:, skel_slice]
96
+
97
+ end_effector_type_set_lst = [
98
+ end_effector_type_set
99
+ for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
100
+ if i in valid_idx
101
+ ]
102
+
103
+ # regroup the end effector data by type
104
+ cls_idx = defaultdict(list)
105
+ for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
106
+ for end_effector_type in end_effector_type_set:
107
+ cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)
108
+
109
+ for cls, lst_idx in cls_idx.items():
110
+ frame_indices_cls = frame_indices[lst_idx]
111
+ smooth_root_pos_2d = None
112
+ if dense_smooth_root_pos_2d is not None:
113
+ smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]
114
+
115
+ constraint_joints_pos_el = constraint_joints_pos[lst_idx]
116
+ constraint_joints_rot_el = constraint_joints_rot[lst_idx]
117
+
118
+ model_constraints.append(
119
+ cls(
120
+ model_skeleton,
121
+ frame_indices_cls,
122
+ constraint_joints_pos_el,
123
+ constraint_joints_rot_el,
124
+ smooth_root_2d=smooth_root_pos_2d,
125
+ )
126
+ )
127
+ else:
128
+ raise ValueError(f"Unsupported constraint type: {constraint.display_name}")
129
+ return model_constraints
130
+
131
+
132
+ def generate(
133
+ *,
134
+ client: viser.ClientHandle,
135
+ session: ClientSession,
136
+ model_bundle: ModelBundle,
137
+ prompts: list[str],
138
+ num_frames: list[int],
139
+ num_samples: int,
140
+ seed: int,
141
+ diffusion_steps: int,
142
+ cfg_weight: Optional[list[float]] = None,
143
+ cfg_type: Optional[str] = None,
144
+ postprocess_parameters: Optional[dict] = None,
145
+ transitions_parameters: Optional[dict] = None,
146
+ real_robot_rotations: bool = False,
147
+ device: str,
148
+ clear_motions,
149
+ add_character_motion,
150
+ ) -> None:
151
+ client_id = client.client_id
152
+ print(
153
+ f"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})"
154
+ )
155
+
156
+ seed_everything(seed)
157
+
158
+ model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)
159
+ cfg_weight = cfg_weight or [2.0, 2.0]
160
+ postprocess_parameters = postprocess_parameters or {}
161
+ transitions_parameters = transitions_parameters or {}
162
+
163
+ encoder = getattr(model_bundle.model, "text_encoder", None)
164
+ if isinstance(encoder, CachedTextEncoder):
165
+ with encoder.session_context(session):
166
+ pred_joints_output = model_bundle.model(
167
+ prompts,
168
+ num_frames,
169
+ diffusion_steps,
170
+ multi_prompt=True,
171
+ constraint_lst=model_constraints,
172
+ cfg_weight=cfg_weight,
173
+ num_samples=num_samples,
174
+ cfg_type=cfg_type,
175
+ **(postprocess_parameters | transitions_parameters),
176
+ ) # [B, T, motion_rep_dim]
177
+ else:
178
+ pred_joints_output = model_bundle.model(
179
+ prompts,
180
+ num_frames,
181
+ diffusion_steps,
182
+ multi_prompt=True,
183
+ constraint_lst=model_constraints,
184
+ cfg_weight=cfg_weight,
185
+ num_samples=num_samples,
186
+ cfg_type=cfg_type,
187
+ **(postprocess_parameters | transitions_parameters),
188
+ ) # [B, T, motion_rep_dim]
189
+
190
+ joints_pos = pred_joints_output["posed_joints"] # [B, T, J, 3]
191
+ joints_rot = pred_joints_output["global_rot_mats"]
192
+ foot_contacts = pred_joints_output.get("foot_contacts")
193
+
194
+ # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.
195
+ if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
196
+ joints_pos, joints_rot = apply_g1_real_robot_projection(
197
+ session.skeleton,
198
+ pred_joints_output["posed_joints"],
199
+ pred_joints_output["global_rot_mats"],
200
+ clamp_to_limits=True,
201
+ )
202
+
203
+ # Display on characters (callbacks keep this module UI-agnostic).
204
+ clear_motions(client_id)
205
+ # Keep one sample centered at the origin so constraints align.
206
+ spread_factor = 1.0 # meters
207
+ center_idx = num_samples // 2
208
+ x_trans = (np.arange(num_samples) - center_idx) * spread_factor
209
+ for i in range(num_samples):
210
+ cur_joints_pos = joints_pos[i]
211
+ cur_joints_pos[..., 0] += x_trans[i]
212
+ add_character_motion(
213
+ client,
214
+ session.skeleton,
215
+ cur_joints_pos,
216
+ joints_rot[i],
217
+ foot_contacts[i],
218
+ )
kimodo/demo/queue_manager.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """HF mode user queue and session time limit."""
4
+
5
+ import math
6
+ import threading
7
+ import time
8
+ from collections.abc import Callable
9
+ from typing import Any
10
+
11
+ import viser
12
+
13
+ from .config import DEMO_UI_QUICK_START_MODAL_MD, MAX_SESSION_MINUTES
14
+
15
+ # Link for "Duplicate this Space" on Hugging Face (used in queue and expiry modals).
16
+ DUPLICATE_SPACE_URL = "https://huggingface.co/spaces/nvidia/Kimodo?duplicate=true"
17
+ GITHUB_REPO_URL = "https://github.com/nv-tlabs/kimodo"
18
+
19
+ # How often to refresh queue modal content (position, total, estimated wait).
20
+ QUEUE_MODAL_REFRESH_INTERVAL_SEC = 15
21
+
22
+
23
+ class UserQueue:
24
+ """Thread-safe queue: active users (with activation timestamp) and waiting queue."""
25
+
26
+ def __init__(self, max_active: int, max_minutes: float) -> None:
27
+ self._max_active = max_active
28
+ self._max_minutes = max_minutes
29
+ self._max_seconds = max_minutes * 60.0
30
+ self._active: dict[int, float] = {} # client_id -> activation timestamp
31
+ self._queued: list[int] = []
32
+ self._lock = threading.Lock()
33
+
34
+ def try_activate(self, client_id: int) -> bool:
35
+ """If a slot is free, add client as active and return True.
36
+
37
+ Else return False.
38
+ """
39
+ with self._lock:
40
+ if len(self._active) < self._max_active:
41
+ self._active[client_id] = time.time()
42
+ return True
43
+ return False
44
+
45
+ def enqueue(self, client_id: int) -> None:
46
+ with self._lock:
47
+ if client_id not in self._queued:
48
+ self._queued.append(client_id)
49
+
50
+ def remove(self, client_id: int) -> bool:
51
+ """Remove from active or queue.
52
+
53
+ Returns True if was active.
54
+ """
55
+ with self._lock:
56
+ was_active = client_id in self._active
57
+ self._active.pop(client_id, None)
58
+ if client_id in self._queued:
59
+ self._queued.remove(client_id)
60
+ return was_active
61
+
62
+ def promote_next(self) -> int | None:
63
+ """If queue non-empty, pop first, activate them, return their client_id.
64
+
65
+ Else None.
66
+ """
67
+ with self._lock:
68
+ if not self._queued:
69
+ return None
70
+ client_id = self._queued.pop(0)
71
+ self._active[client_id] = time.time()
72
+ return client_id
73
+
74
+ def get_queue_position(self, client_id: int) -> tuple[int, int] | None:
75
+ """(1-based position, total_in_queue) or None if not queued."""
76
+ with self._lock:
77
+ if client_id not in self._queued:
78
+ return None
79
+ pos = self._queued.index(client_id)
80
+ return (pos + 1, len(self._queued))
81
+
82
+ def get_estimated_wait_seconds(self, client_id: int) -> float:
83
+ """Estimated seconds until this queued client gets a slot."""
84
+ with self._lock:
85
+ if client_id not in self._queued:
86
+ return 0.0
87
+ pos = self._queued.index(client_id) + 1 # 1-based
88
+ # Expiry times of active users (when they free a slot)
89
+ now = time.time()
90
+ expiries = sorted(now + self._max_seconds - (now - t) for t in self._active.values())
91
+ if not expiries:
92
+ return 0.0
93
+ # Nth slot to free (1-indexed) wraps over expiries
94
+ idx = (pos - 1) % len(expiries)
95
+ cycles = (pos - 1) // len(expiries)
96
+ slot_free_time = expiries[idx] + cycles * self._max_seconds
97
+ return max(0.0, slot_free_time - now)
98
+
99
+ def is_active(self, client_id: int) -> bool:
100
+ with self._lock:
101
+ return client_id in self._active
102
+
103
+ def was_active(self, client_id: int) -> bool:
104
+ """True if client is currently active (for use when already holding lock)."""
105
+ return client_id in self._active
106
+
107
+
108
+ def _format_wait(seconds: float) -> str:
109
+ if seconds < 60:
110
+ return "less than a minute"
111
+ mins = int(math.ceil(seconds / 60))
112
+ return f"~{mins} minute{'s' if mins != 1 else ''}"
113
+
114
+
115
+ def _queue_modal_markdown(position: int, total: int, estimated_wait_sec: float) -> str:
116
+ wait_str = _format_wait(estimated_wait_sec)
117
+ mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
118
+ return f"""## Kimodo Demo — Please Wait
119
+
120
+ This demo runs with limited capacity.
121
+ Each user gets **{mins} minute{"s" if mins != 1 else ""}** of interactive time.
122
+
123
+ **Your position in queue:** {position} / {total}
124
+
125
+ **Estimated wait:** {wait_str}
126
+
127
+ Please keep this tab open — the demo will start automatically when it's your turn.
128
+
129
+ ---
130
+ *Want unlimited access? [Duplicate this Space]({DUPLICATE_SPACE_URL}) or clone the [GitHub repo]({GITHUB_REPO_URL}) to run locally!*
131
+ """
132
+
133
+
134
+ def _welcome_modal_markdown() -> str:
135
+ mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
136
+ return f"""## Welcome to Kimodo Demo
137
+
138
+ You have been granted a **{mins}-minute** demo session.
139
+ Your session timer has started.
140
+
141
+ Click the button below to begin!
142
+ """
143
+
144
+
145
+ def _expiry_modal_markdown() -> str:
146
+ mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
147
+ return f"""## Session Expired
148
+
149
+ Your {mins}-minute demo session has ended.
150
+ Thank you for trying Kimodo!
151
+
152
+ Refresh this page to rejoin the queue, or [duplicate this Space]({DUPLICATE_SPACE_URL}) for unlimited access.
153
+ """
154
+
155
+
156
+ class QueueManager:
157
+ """Orchestrates HF mode: queue modals, welcome modal, session timer, promotion."""
158
+
159
+ def __init__(
160
+ self,
161
+ queue: UserQueue,
162
+ server: viser.ViserServer,
163
+ setup_demo_for_client: Callable[[viser.ClientHandle], None],
164
+ cleanup_session: Callable[[int], None],
165
+ ) -> None:
166
+ self._queue = queue
167
+ self._server = server
168
+ self._setup_demo_for_client = setup_demo_for_client
169
+ self._cleanup_session = cleanup_session
170
+ self._max_seconds = queue._max_seconds
171
+
172
+ self._queue_modal_handles: dict[int, tuple[Any, Any]] = {}
173
+ self._welcome_modal_handles: dict[int, Any] = {}
174
+ self._expiry_timers: dict[int, threading.Timer] = {}
175
+ self._lock = threading.Lock()
176
+ self._refresh_stop = threading.Event()
177
+ self._refresh_thread = threading.Thread(
178
+ target=self._queue_modal_refresh_loop,
179
+ name="queue-modal-refresh",
180
+ daemon=True,
181
+ )
182
+ self._refresh_thread.start()
183
+
184
+ def _queue_modal_refresh_loop(self) -> None:
185
+ """Periodically refresh queue modals so position, total, and estimated wait stay current."""
186
+ while not self._refresh_stop.wait(timeout=QUEUE_MODAL_REFRESH_INTERVAL_SEC):
187
+ self._update_all_queue_modals()
188
+
189
+ def on_client_connect(self, client: viser.ClientHandle) -> None:
190
+ """Handle new connection: activate if slot free, else enqueue and show queue modal."""
191
+ client_id = client.client_id
192
+ if self._queue.try_activate(client_id):
193
+ try:
194
+ self._setup_demo_for_client(client)
195
+ except RuntimeError as e:
196
+ if "CUDA error" in str(e):
197
+ print(f"CUDA error while setting up client {client_id}: {e}")
198
+ return
199
+ raise
200
+ self._start_session_timer(client_id)
201
+ self._show_welcome_modal(client)
202
+ else:
203
+ self._queue.enqueue(client_id)
204
+ self._show_queue_modal(client)
205
+ self._update_all_queue_modals()
206
+
207
+ def on_client_disconnect(self, client_id: int) -> None:
208
+ """Remove from queue/active, cancel timer, promote next if was active.
209
+
210
+ Session/scene cleanup is done by the demo's on_client_disconnect.
211
+ """
212
+ with self._lock:
213
+ self._expiry_timers.pop(client_id, None)
214
+ self._queue_modal_handles.pop(client_id, None)
215
+ self._welcome_modal_handles.pop(client_id, None)
216
+ was_active = self._queue.remove(client_id)
217
+ if was_active:
218
+ self._promote_next_user()
219
+ else:
220
+ self._update_all_queue_modals()
221
+
222
+ def _show_queue_modal(self, client: viser.ClientHandle) -> None:
223
+ client_id = client.client_id
224
+ pos, total = self._queue.get_queue_position(client_id) or (0, 0)
225
+ wait_sec = self._queue.get_estimated_wait_seconds(client_id)
226
+ md_content = _queue_modal_markdown(pos, total, wait_sec)
227
+
228
+ modal = client.gui.add_modal(
229
+ "Kimodo Demo — Please Wait",
230
+ size="xl",
231
+ show_close_button=False,
232
+ )
233
+ with modal:
234
+ md_handle = client.gui.add_markdown(md_content)
235
+ with self._lock:
236
+ self._queue_modal_handles[client_id] = (modal, md_handle)
237
+
238
+ def _show_quick_start_modal(self, client: viser.ClientHandle) -> None:
239
+ """Show the quick start instructions modal (same as non-HF mode)."""
240
+ with client.gui.add_modal(
241
+ "Welcome — Quick Start",
242
+ size="xl",
243
+ show_close_button=True,
244
+ save_choice="kimodo.demo.quick_start_ack",
245
+ ) as quick_start_modal:
246
+ client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
247
+ client.gui.add_button("Got it (don't remind me again)").on_click(lambda _: quick_start_modal.close())
248
+
249
+ def _show_welcome_modal(self, client: viser.ClientHandle) -> None:
250
+ client_id = client.client_id
251
+
252
+ def _on_start_demo(_: Any) -> None:
253
+ modal.close()
254
+ self._show_quick_start_modal(client)
255
+
256
+ modal = client.gui.add_modal(
257
+ "Welcome to Kimodo Demo",
258
+ size="xl",
259
+ show_close_button=True,
260
+ )
261
+ with modal:
262
+ client.gui.add_markdown(_welcome_modal_markdown())
263
+ client.gui.add_button("Start Demo").on_click(_on_start_demo)
264
+ with self._lock:
265
+ self._welcome_modal_handles[client_id] = modal
266
+
267
+ def _update_all_queue_modals(self) -> None:
268
+ with self._lock:
269
+ handles = list(self._queue_modal_handles.items())
270
+ for client_id, (modal, md_handle) in handles:
271
+ pos_total = self._queue.get_queue_position(client_id)
272
+ if pos_total is None:
273
+ continue
274
+ pos, total = pos_total
275
+ wait_sec = self._queue.get_estimated_wait_seconds(client_id)
276
+ try:
277
+ md_handle.content = _queue_modal_markdown(pos, total, wait_sec)
278
+ except Exception:
279
+ pass
280
+
281
+ def _promote_next_user(self) -> None:
282
+ promoted_id = self._queue.promote_next()
283
+ if promoted_id is None:
284
+ return
285
+ clients = self._server.get_clients()
286
+ client = clients.get(promoted_id)
287
+ if client is None:
288
+ return
289
+ with self._lock:
290
+ old = self._queue_modal_handles.pop(promoted_id, None)
291
+ if old is not None:
292
+ try:
293
+ old[0].close()
294
+ except Exception:
295
+ pass
296
+ try:
297
+ self._setup_demo_for_client(client)
298
+ except RuntimeError as e:
299
+ if "CUDA error" in str(e):
300
+ print(f"CUDA error while setting up client {promoted_id}: {e}")
301
+ return
302
+ raise
303
+ self._start_session_timer(promoted_id)
304
+ self._show_welcome_modal(client)
305
+ self._update_all_queue_modals()
306
+
307
+ def _start_session_timer(self, client_id: int) -> None:
308
+ def on_expiry() -> None:
309
+ self._on_session_expired(client_id)
310
+
311
+ t = threading.Timer(self._max_seconds, on_expiry)
312
+ t.daemon = True
313
+ with self._lock:
314
+ self._expiry_timers[client_id] = t
315
+ t.start()
316
+
317
+ def _on_session_expired(self, client_id: int) -> None:
318
+ with self._lock:
319
+ self._expiry_timers.pop(client_id, None)
320
+ if not self._queue.is_active(client_id):
321
+ return
322
+ self._queue.remove(client_id)
323
+ clients = self._server.get_clients()
324
+ client = clients.get(client_id)
325
+ if client is not None:
326
+ try:
327
+ with client.gui.add_modal(
328
+ "Session Expired",
329
+ size="lg",
330
+ show_close_button=False,
331
+ ) as modal_ctx:
332
+ client.gui.add_markdown(_expiry_modal_markdown())
333
+ except Exception:
334
+ pass
335
+ self._cleanup_session(client_id)
336
+ self._promote_next_user()
kimodo/demo/state.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ import kimodo.viz.viser_utils as viser_utils
10
+ import viser
11
+ from kimodo.skeleton import SkeletonBase
12
+ from kimodo.viz.viser_utils import GuiElements
13
+
14
+ from .config import (
15
+ DEFAULT_CUR_DURATION,
16
+ DEFAULT_MODEL,
17
+ DEFAULT_PLAYBACK_SPEED,
18
+ )
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class ModelBundle:
23
+ model: object
24
+ motion_rep: object
25
+ skeleton: SkeletonBase
26
+ model_fps: float
27
+
28
+
29
+ @dataclass
30
+ class ClientSession:
31
+ """Per-client session data."""
32
+
33
+ client: viser.ClientHandle
34
+ gui_elements: GuiElements
35
+ motions: dict # character_name -> CharacterMotion
36
+ constraints: dict[str, viser_utils.ConstraintSet] = field(default_factory=dict)
37
+ timeline_data: object = None
38
+ frame_idx: int = 0
39
+ playing: bool = False
40
+ playback_speed: float = DEFAULT_PLAYBACK_SPEED
41
+ cur_duration: float = DEFAULT_CUR_DURATION
42
+ max_frame_idx: int = 100 # will be updated based on model_fps
43
+ updating_motions: bool = False
44
+ edit_mode: bool = False
45
+ model_name: str = DEFAULT_MODEL
46
+ model_fps: float = 0.0
47
+ skeleton: SkeletonBase | None = None
48
+ motion_rep: object | None = None
49
+ examples_base_dir: str = ""
50
+ example_dict: dict[str, str] = field(default_factory=dict)
51
+ gui_examples_dropdown: Optional[viser.GuiInputHandle] = None
52
+ gui_save_example_path_text: Optional[viser.GuiInputHandle] = None
53
+ gui_model_selector: Optional[viser.GuiInputHandle] = None
54
+ last_prompt_texts: Optional[list[str]] = None
55
+ last_prompt_embeddings: Optional[torch.Tensor] = None
56
+ last_prompt_lengths: Optional[list[int]] = None
57
+ edit_mode_snapshot: Optional[dict[int, dict[str, object]]] = None
58
+ undo_drag_snapshot: Optional[dict[str, object]] = None
59
+ show_only_current_constraint: bool = False # False = Show All, True = Show only Current