rydlrKE committed on
Commit
2f71493
·
verified ·
1 Parent(s): 96ea6fb

Fix ZeroGPU spaces import order (before CUDA init; commit e28bffd)

Browse files
Files changed (1) hide show
  1. kimodo/demo/app.py +742 -0
kimodo/demo/app.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # ============================================================================
5
+ # CRITICAL: Import spaces FIRST, before any CUDA-related packages.
6
+ # This prevents "CUDA has been initialized before importing spaces" error.
7
+ # ============================================================================
8
+ try:
9
+ import spaces # noqa: F401 - imported early for ZeroGPU compatibility
10
+ except ImportError:
11
+ pass # Not running on HF Spaces with ZeroGPU
12
+
13
+ import base64
14
+ import logging
15
+ import os
16
+ import shutil
17
+ import threading
18
+ import time
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ import viser
25
+ from kimodo.assets import DEMO_ASSETS_ROOT
26
+ from kimodo.model.load_model import load_model
27
+ from kimodo.model.registry import resolve_model_name
28
+ from kimodo.runtime.device import select_runtime_device
29
+ from kimodo.skeleton import SkeletonBase, SOMASkeleton30
30
+ from kimodo.tools import load_json
31
+ from kimodo.viz import viser_utils
32
+ from kimodo.viz.viser_utils import (
33
+ Character,
34
+ CharacterMotion,
35
+ EEJointsKeyframeSet,
36
+ FullbodyKeyframeSet,
37
+ RootKeyframe2DSet,
38
+ )
39
+ from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage
40
+
41
+ from . import generation, ui
42
+ from .config import (
43
+ DARK_THEME,
44
+ DEFAULT_CUR_DURATION,
45
+ DEFAULT_MODEL,
46
+ DEFAULT_PLAYBACK_SPEED,
47
+ DEFAULT_PROMPT,
48
+ DEMO_UI_QUICK_START_MODAL_MD,
49
+ EXAMPLES_ROOT_DIR,
50
+ HF_MODE,
51
+ LIGHT_THEME,
52
+ MAX_ACTIVE_USERS,
53
+ MAX_DURATION,
54
+ MAX_SESSION_MINUTES,
55
+ MIN_DURATION,
56
+ MODEL_EXAMPLES_DIRS,
57
+ MODEL_NAMES,
58
+ SERVER_NAME,
59
+ SERVER_PORT,
60
+ )
61
+ from .embedding_cache import CachedTextEncoder
62
+ from .queue_manager import QueueManager, UserQueue
63
+ from .state import ClientSession, ModelBundle
64
+
65
# Hosted runtimes (HF/Cloud Run) often send non-WebSocket probes to the WS endpoint.
# Raise the threshold of the websockets server loggers so that these expected
# invalid handshakes do not produce noisy stack traces.
for _ws_logger_name in ("websockets.server", "websockets.asyncio.server"):
    logging.getLogger(_ws_logger_name).setLevel(logging.CRITICAL)
del _ws_logger_name  # keep the module namespace unchanged
69
+
70
+
71
+ class Demo:
72
def __init__(self, default_model_name: str = DEFAULT_MODEL):
    """Prepare models, per-client bookkeeping and the viser server.

    Args:
        default_model_name: Model name passed through ``resolve_model_name``;
            the resolved name must be one of ``MODEL_NAMES``.

    Raises:
        ValueError: If the resolved model name is not in ``MODEL_NAMES``.
    """
    # In hosted HF runtimes (including ZeroGPU), touching CUDA too early can
    # crash startup before queue-managed inference starts.
    requested_device = os.getenv("KIMODO_DEVICE")
    running_in_space = bool(os.getenv("SPACE_ID")) or os.getenv("SYSTEM", "").strip().lower() == "spaces"
    if requested_device is None and (HF_MODE or running_in_space):
        requested_device = "cpu"
    self.device = select_runtime_device(requested=requested_device)
    print(f"Using device: {self.device}")
    self.models: dict[str, ModelBundle] = {}
    resolved = resolve_model_name(default_model_name, "Kimodo")
    if resolved not in MODEL_NAMES:
        raise ValueError(f"Unknown model '{default_model_name}'. Expected one of: {MODEL_NAMES}")
    self.default_model_name = resolved
    # Model loading is deferred unless KIMODO_DEFER_MODEL_LOAD is set to a
    # value outside this truthy set (the variable defaults to "true").
    self.defer_model_load = os.getenv("KIMODO_DEFER_MODEL_LOAD", "true").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    self.ensure_examples_layout()
    if self.defer_model_load:
        print("Deferring model load until first active client session.")
    else:
        self.load_model(self.default_model_name)

    # Serialize GPU-bound generation across all clients
    self._generation_lock = threading.Lock()
    self._cuda_healthy = True

    # Per-client sessions
    self.client_sessions: dict[int, ClientSession] = {}
    self.start_direction_markers: dict[int, viser_utils.WaypointMesh] = {}
    self.grid_handles: dict[int, viser.GridHandle] = {}

    # Floor/grid side length in meters; read by setup_scene for every client,
    # so it must exist before any connect callback can run.
    self.floor_len = 20.0

    self.server = viser.ViserServer(
        host=SERVER_NAME,
        port=SERVER_PORT,
        label="Kimodo",
        enable_camera_keyboard_controls=False,  # don't move the camera with the arrow keys
    )
    self.server.scene.world_axes.visible = False  # used for debugging
    self.server.scene.set_up_direction("+y")

    # HF mode: queue and session limit
    if HF_MODE:
        self.user_queue = UserQueue(MAX_ACTIVE_USERS, MAX_SESSION_MINUTES)
        self.queue_manager = QueueManager(
            queue=self.user_queue,
            server=self.server,
            setup_demo_for_client=self._setup_demo_for_client,
            cleanup_session=self._cleanup_session_for_client,
        )
    else:
        self.user_queue = None
        self.queue_manager = None

    # Register the session callbacks last: they read self.queue_manager and
    # (through setup_scene) self.floor_len, so every attribute they touch has
    # to be assigned before a client connection can invoke them.
    self.server.on_client_connect(self.on_client_connect)
    self.server.on_client_disconnect(self.on_client_disconnect)
135
+
136
def ensure_examples_layout(self) -> None:
    """Create the example folders and relocate stray example directories.

    Ensures ``EXAMPLES_ROOT_DIR`` and every folder in ``MODEL_EXAMPLES_DIRS``
    exist. Any directory directly under the root whose name is not a key of
    ``MODEL_EXAMPLES_DIRS`` is then moved into the folder configured for
    ``DEFAULT_MODEL`` (falling back to the first configured folder), unless
    an entry with the same name already exists there.
    """
    os.makedirs(EXAMPLES_ROOT_DIR, exist_ok=True)
    for per_model_dir in MODEL_EXAMPLES_DIRS.values():
        os.makedirs(per_model_dir, exist_ok=True)

    for name in os.listdir(EXAMPLES_ROOT_DIR):
        stray_path = os.path.join(EXAMPLES_ROOT_DIR, name)
        if name in MODEL_EXAMPLES_DIRS or not os.path.isdir(stray_path):
            continue
        first_configured_dir = next(iter(MODEL_EXAMPLES_DIRS.values()))
        target_root = MODEL_EXAMPLES_DIRS.get(DEFAULT_MODEL, first_configured_dir)
        target_path = os.path.join(target_root, name)
        if os.path.exists(target_path):
            # Never overwrite an example that is already in place.
            continue
        shutil.move(stray_path, target_path)
153
+
154
def get_examples_base_dir(self, model_name: str, absolute: bool = True) -> str:
    """Return the examples folder configured for ``model_name``.

    The value stored in ``MODEL_EXAMPLES_DIRS`` is returned as-is; the
    ``absolute`` flag is accepted but not consulted by this implementation.

    Raises:
        KeyError: If ``model_name`` has no entry in ``MODEL_EXAMPLES_DIRS``.
    """
    examples_dir = MODEL_EXAMPLES_DIRS[model_name]
    return examples_dir
156
+
157
def load_model(self, model_name: str) -> ModelBundle:
    """Return the bundle for ``model_name``, loading and caching it on first use.

    On a cache miss the model is loaded onto ``self.device``, its text encoder
    (when present) is wrapped in ``CachedTextEncoder``, the bundle is stored in
    ``self.models`` and the embedding cache is prewarmed.

    Raises:
        Exception: Whatever the underlying loader raised, after printing a hint.
    """
    try:
        return self.models[model_name]
    except KeyError:
        pass  # not loaded yet

    print(f"Loading model {model_name}...")
    try:
        model = load_model(modelname=model_name, device=self.device)
    except Exception as load_error:
        print(
            "Error loading model during Kimodo startup. "
            "This often means the text encoder server is not running, the Hugging Face token is missing, "
            "or the gated text encoder model cannot be accessed."
        )
        print(f"Original error: {type(load_error).__name__}: {load_error}")
        raise

    if hasattr(model, "text_encoder"):
        model.text_encoder = CachedTextEncoder(model.text_encoder, model_name=model_name)

    skeleton = model.motion_rep.skeleton
    if isinstance(skeleton, SOMASkeleton30):
        # SOMASkeleton30 is swapped for its ``somaskel77`` counterpart, moved to the model device.
        skeleton = skeleton.somaskel77.to(model.device)

    bundle = ModelBundle(
        model=model,
        motion_rep=model.motion_rep,
        skeleton=skeleton,
        model_fps=model.motion_rep.fps,
    )
    self.models[model_name] = bundle
    print(f"Model {model_name} loaded successfully")
    self.prewarm_embedding_cache(model_name, bundle.model)
    return bundle
189
+
190
def prewarm_embedding_cache(self, model_name: str, model: object) -> None:
    """Best-effort prewarm of the cached text encoder with known prompts.

    Collects ``DEFAULT_PROMPT`` plus every string found under ``prompts_text``
    in the ``meta.json`` of each example folder of ``model_name`` and passes
    them to ``encoder.prewarm``. Does nothing when the model's text encoder is
    not a ``CachedTextEncoder``. Malformed metadata and prewarm failures are
    tolerated so that model loading never fails because of this step.

    Args:
        model_name: Key into ``MODEL_EXAMPLES_DIRS``.
        model: Loaded model; only its ``text_encoder`` attribute is read.
    """
    encoder = getattr(model, "text_encoder", None)
    if not isinstance(encoder, CachedTextEncoder):
        return

    prompt_set = {DEFAULT_PROMPT}

    examples_dir = MODEL_EXAMPLES_DIRS.get(model_name)
    if examples_dir and os.path.isdir(examples_dir):
        for entry in os.listdir(examples_dir):
            example_dir = os.path.join(examples_dir, entry)
            if not os.path.isdir(example_dir):
                continue
            meta_path = os.path.join(example_dir, "meta.json")
            if not os.path.exists(meta_path):
                continue
            try:
                meta = load_json(meta_path)
            except Exception:
                # Unreadable metadata only means this example is not prewarmed.
                continue
            if not isinstance(meta, dict):
                # A valid JSON file that is not an object has no prompts to offer.
                continue
            prompts = meta.get("prompts_text") or []
            if isinstance(prompts, str):
                # A bare string is one prompt, not a sequence of characters.
                prompts = [prompts]
            elif not isinstance(prompts, (list, tuple)):
                continue
            prompt_set.update(prompt for prompt in prompts if isinstance(prompt, str))

    try:
        encoder.prewarm(list(prompt_set))
    except Exception as error:
        # Startup should not fail if text encoder is still warming up.
        error_str = str(error)
        if "Encoder initialization failed" in error_str:
            print(
                f"⚠️ WARNING: Text encoder failed to initialize: {error}\n"
                f" This usually means the HuggingFace gated model cannot be accessed.\n"
                f" To fix: Set HF_TOKEN environment variable with access to Meta-Llama-3-8B.\n"
                f" Alternatively: Generation will still work but text embeddings may fail."
            )
        else:
            print(f"Warning: embedding prewarm skipped: {error}")
230
+
231
def build_constraint_tracks(
    self, client: viser.ClientHandle, skeleton: SkeletonBase
) -> dict[str, viser_utils.ConstraintSet]:
    """Create one empty constraint set per constraint track for ``client``.

    Returns:
        A dict keyed by track name ("Full-Body", "End-Effectors", "2D Root"),
        in that order, each value built with the given client and skeleton.
    """
    track_classes = (
        ("Full-Body", FullbodyKeyframeSet),
        ("End-Effectors", EEJointsKeyframeSet),
        ("2D Root", RootKeyframe2DSet),
    )
    return {
        track_name: track_cls(name=track_name, server=client, skeleton=skeleton)
        for track_name, track_cls in track_classes
    }
251
+
252
def set_timeline_defaults(self, timeline, model_fps: float) -> None:
    """Push the default prompt, duration limits and zoom settings to ``timeline``.

    Durations configured in seconds are converted to a frame value via
    ``int(seconds * model_fps - 1)``.
    """

    def _seconds_to_frames(seconds):
        return int(seconds * model_fps - 1)

    # NOTE(review): the default zoom hard-codes 10 seconds ("a bit more than the
    # max"); confirm whether it is meant to follow MAX_DURATION.
    zoom_frames = int(1.10 * 10 * model_fps)

    timeline.set_defaults(
        default_text=DEFAULT_PROMPT,
        default_duration=_seconds_to_frames(DEFAULT_CUR_DURATION),
        min_duration=_seconds_to_frames(MIN_DURATION),
        # No transition frames are subtracted from the maximum here.
        max_duration=_seconds_to_frames(MAX_DURATION),
        default_num_frames_zoom=zoom_frames,
        max_frames_zoom=1000,
        fps=model_fps,
    )
264
+
265
def _apply_constraint_overlay_visibility(self, session: ClientSession) -> None:
    """Apply show-all vs show-only-current-frame to constraint overlays.

    Every constraint set of the session receives the current frame index when
    ``session.show_only_current_constraint`` is truthy, and ``None`` otherwise.
    """
    if session.show_only_current_constraint:
        only_frame = session.frame_idx
    else:
        only_frame = None
    for constraint_set in session.constraints.values():
        constraint_set.set_overlay_visibility(only_frame)
270
+
271
def set_constraint_tracks_visible(self, session: ClientSession, visible: bool) -> None:
    """Show or hide the constraint tracks on the client's timeline.

    Hiding removes every track listed in ``session.timeline_data``. Showing
    re-adds the tracks, keyframes and intervals recorded there, reusing their
    stored ids. Nothing happens when the requested state is already recorded.
    """
    data = session.timeline_data
    if data.get("constraint_tracks_visible", True) == visible:
        return

    timeline = session.client.timeline
    with data["keyframe_update_lock"]:
        if not visible:
            # Snapshot the ids before removing tracks.
            for track_id in list(data["tracks"]):
                timeline.remove_track(track_id)
        else:
            for track_id, track in data["tracks"].items():
                timeline.add_track(
                    track["name"],
                    track_type=track.get("track_type", "keyframe"),
                    color=track.get("color"),
                    height_scale=track.get("height_scale", 1.0),
                    uuid=track_id,
                )

            for keyframe_id, keyframe in data["keyframes"].items():
                timeline.add_keyframe(
                    track_id=keyframe["track_id"],
                    frame=keyframe["frame"],
                    value=keyframe.get("value"),
                    opacity=keyframe.get("opacity", 1.0),
                    locked=keyframe.get("locked", False),
                    uuid=keyframe_id,
                )

            for interval_id, interval in data["intervals"].items():
                timeline.add_interval(
                    track_id=interval["track_id"],
                    start_frame=interval["start_frame_idx"],
                    end_frame=interval["end_frame_idx"],
                    value=interval.get("value"),
                    opacity=interval.get("opacity", 1.0),
                    locked=interval.get("locked", False),
                    uuid=interval_id,
                )

        data["constraint_tracks_visible"] = visible
313
+
314
+ def _cleanup_session_for_client(self, client_id: int) -> None:
315
+ """Remove session and scene state for a client (e.g. on session expiry)."""
316
+ if client_id in self.client_sessions:
317
+ del self.client_sessions[client_id]
318
+ self.start_direction_markers.pop(client_id, None)
319
+ self.grid_handles.pop(client_id, None)
320
+
321
def _setup_demo_for_client(self, client: viser.ClientHandle) -> None:
    """Initialize scene, GUI, and session state for a client (no modals).

    Builds the scene, loads the default model if needed, creates the
    constraint tracks and GUI, stores a fresh ``ClientSession`` under the
    client's id and finally adds the default character.
    """
    self.setup_scene(client)

    model_name = self.default_model_name
    bundle = self.load_model(model_name)
    fps = bundle.model_fps

    # One empty constraint set per track
    constraints = self.build_constraint_tracks(client, bundle.skeleton)

    # GUI elements for this client
    gui_parts = ui.create_gui(
        demo=self,
        client=client,
        model_name=model_name,
        model_fps=fps,
    )
    (
        gui_elements,
        timeline_tracks,
        example_dict,
        gui_examples_dropdown,
        gui_save_example_path_text,
        gui_model_selector,
    ) = gui_parts

    tracks_ids = {track["name"]: track_id for track_id, track in timeline_tracks.items()}
    timeline_data = {
        "tracks": timeline_tracks,
        "tracks_ids": tracks_ids,
        "keyframes": {},
        "intervals": {},
        "keyframe_update_lock": threading.Lock(),
        "keyframe_move_timers": {},
        "pending_keyframe_moves": {},  # keyframe_id -> new_frame
        "constraint_tracks_visible": True,
        "dense_path_after_release_timer": None,
    }

    # Initial playback range derived from the default duration
    duration = DEFAULT_CUR_DURATION
    last_frame_idx = int(duration * fps - 1)
    examples_base_dir = self.get_examples_base_dir(model_name, absolute=True)

    session = ClientSession(
        client=client,
        gui_elements=gui_elements,
        motions={},
        constraints=constraints,
        timeline_data=timeline_data,
        frame_idx=0,
        playing=False,
        playback_speed=DEFAULT_PLAYBACK_SPEED,
        cur_duration=duration,
        max_frame_idx=last_frame_idx,
        updating_motions=False,
        edit_mode=False,
        model_name=model_name,
        model_fps=fps,
        skeleton=bundle.skeleton,
        motion_rep=bundle.motion_rep,
        examples_base_dir=examples_base_dir,
        example_dict=example_dict,
        gui_examples_dropdown=gui_examples_dropdown,
        gui_save_example_path_text=gui_save_example_path_text,
        gui_model_selector=gui_model_selector,
    )
    self.client_sessions[client.client_id] = session

    # Default character for this client
    self.add_character_motion(client, session.skeleton)
388
+
389
def on_client_connect(self, client: viser.ClientHandle) -> None:
    """Initialize GUI and state for each new client.

    In HF mode the connection is handed to the queue manager. Otherwise a
    quick-start modal is shown and the demo is set up immediately.
    """
    print(f"Client {client.client_id} connected")

    if HF_MODE and self.queue_manager is not None:
        self.queue_manager.on_client_connect(client)
        return

    # Show quick start popup when a browser client connects (non-HF mode).
    quick_start_modal = client.gui.add_modal(
        "Welcome — Quick Start",
        size="xl",
        show_close_button=True,
        save_choice="kimodo.demo.quick_start_ack",
    )
    with quick_start_modal as modal:
        client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
        ack_button = client.gui.add_button("Got it (don't remind me again)")
        ack_button.on_click(lambda _event: modal.close())
    self._setup_demo_for_client(client)
406
+
407
def setup_scene(self, client: viser.ClientHandle) -> None:
    """Configure theme, camera, floor grid and origin marker for ``client``.

    The created grid handle and origin waypoint are recorded per client id in
    ``self.grid_handles`` and ``self.start_direction_markers``.
    """
    self.configure_theme(client)

    # Fixed initial viewpoint looking at the world origin
    camera = client.camera
    camera.position = np.array(
        [2.7417358737841426, 1.8790455698853281, 7.675741569777456],
        dtype=np.float64,
    )
    camera.look_at = np.array([0.0, 0.0, 0.0], dtype=np.float64)
    camera.up_direction = np.array(
        [-1.1102230246251568e-16, 1.0, 1.3596310734468913e-32],
        dtype=np.float64,
    )
    camera.fov = np.deg2rad(45.0)

    floor_len = self.floor_len
    grid_rotation = viser.transforms.SO3.from_x_radians(-np.pi / 2.0)
    self.grid_handles[client.client_id] = client.scene.add_grid(
        "/grid",
        width=floor_len,
        height=floor_len,
        wxyz=grid_rotation.wxyz,
        position=(0.0, 0.0001, 0.0),
        fade_distance=3 * floor_len,
        section_color=LIGHT_THEME["grid"],
        infinite_grid=True,
    )

    # marker for origin
    self.start_direction_markers[client.client_id] = viser_utils.WaypointMesh(
        "/origin_waypoint",
        client,
        position=np.array([0.0, 0.0, 0.0]),
        heading=np.array([0.0, 1.0]),
        color=(0, 0, 255),
    )
439
+
440
def on_client_disconnect(self, client: viser.ClientHandle) -> None:
    """Clean up when client disconnects.

    Notifies the queue manager first (HF mode only), then drops the client's
    session and scene bookkeeping.
    """
    client_id = client.client_id
    print(f"Client {client_id} disconnected")

    queue_manager = self.queue_manager
    if HF_MODE and queue_manager is not None:
        queue_manager.on_client_disconnect(client_id)

    self._cleanup_session_for_client(client_id)
449
+
450
def set_start_direction_visible(self, client_id: int, visible: bool) -> None:
    """Toggle the start-direction marker of ``client_id``; no-op when it has none."""
    marker = self.start_direction_markers.get(client_id)
    if marker is not None:
        marker.set_visible(visible)
455
+
456
def client_active(self, client_id: int) -> bool:
    """Return True when a session is currently stored for ``client_id``."""
    sessions = self.client_sessions
    return client_id in sessions
458
+
459
def add_character_motion(
    self,
    client: viser.ClientHandle,
    skeleton: SkeletonBase,
    joints_pos: Optional[torch.Tensor] = None,
    joints_rot: Optional[torch.Tensor] = None,
    foot_contacts: Optional[torch.Tensor] = None,
) -> None:
    """Add a new character (and its motion) to the client's session.

    Does nothing when the client has no active session. When ``joints_pos`` or
    ``joints_rot`` is omitted, the character's initial pose is repeated for
    ``session.max_frame_idx + 1`` frames. The character is created hidden and
    made visible shortly afterwards according to the GUI checkboxes.

    Raises:
        ValueError: If the session's model name matches no known mesh mode.
    """
    client_id = client.client_id
    if not self.client_active(client_id):
        return
    session = self.client_sessions[client_id]
    gui = session.gui_elements

    character_name = f"character{len(session.motions)}"

    # Pick the mesh mode from the model name
    model_name = session.model_name
    if "g1" in model_name:
        mesh_mode = "g1_stl"
    elif "smplx" in model_name:
        mesh_mode = "smplx_skin"
    elif "soma" in model_name:
        mesh_mode = "soma_layer_skin" if gui.gui_use_soma_layer_checkbox.value else "soma_skin"
    else:
        raise ValueError("The model name is not recognized for skinning.")

    # Build the character skeleton and skinning mesh; both start hidden.
    character = Character(
        character_name,
        client,
        skeleton,
        create_skeleton_mesh=True,
        create_skinned_mesh=True,
        visible_skeleton=False,
        visible_skinned_mesh=False,
        skinned_mesh_opacity=gui.gui_viz_skinned_mesh_opacity_slider.value,
        show_foot_contacts=gui.gui_viz_foot_contacts_checkbox.value,
        dark_mode=gui.gui_dark_mode_checkbox.value,
        mesh_mode=mesh_mode,
        gui_use_soma_layer_checkbox=gui.gui_use_soma_layer_checkbox,
    )

    # Without a given motion, hold the character's initial pose on every frame.
    rest_pos, rest_rot = character.get_pose()
    frame_count = session.max_frame_idx + 1
    if joints_pos is None:
        joints_pos = rest_pos[None].repeat(frame_count, 1, 1)
    if joints_rot is None:
        joints_rot = rest_rot[None].repeat(frame_count, 1, 1, 1)

    motion = CharacterMotion(character, joints_pos, joints_rot, foot_contacts)
    session.motions[character_name] = motion

    # Put the character at the current frame
    motion.set_frame(session.frame_idx)

    def _reveal():
        # Runs after a short delay so that set_frame has time to finish.
        motion.character.set_skinned_mesh_visibility(session.gui_elements.gui_viz_skinned_mesh_checkbox.value)
        motion.character.set_skeleton_visibility(session.gui_elements.gui_viz_skeleton_checkbox.value)

    threading.Timer(0.2, _reveal).start()
527
+
528
def clear_motions(self, client_id: int) -> None:
    """Clear and forget every motion of ``client_id``; no-op without an active session."""
    if not self.client_active(client_id):
        return
    motions = self.client_sessions[client_id].motions
    # Iterate over a snapshot so that clearing cannot disturb the iteration.
    snapshot = list(motions.values())
    for motion in snapshot:
        motion.clear()
    motions.clear()
535
+
536
def compute_model_constraints_lst(
    self,
    session: ClientSession,
    model_bundle: ModelBundle,
    num_frames: int,
):
    """Delegate to ``generation.compute_model_constraints_lst`` using ``self.device``."""
    constraints_lst = generation.compute_model_constraints_lst(session, model_bundle, num_frames, self.device)
    return constraints_lst
543
+
544
def check_cuda_health(self) -> bool:
    """Check if CUDA is still functional and trigger a restart if it is not.

    Returns:
        True when the device is a CPU device or a tiny tensor addition on
        ``self.device`` succeeds; False when a CUDA failure was detected.

    Raises:
        RuntimeError: Any runtime error that does not look like a CUDA failure.
    """
    # Compare on the device *type* so that "cpu", "cpu:0" and torch.device("cpu")
    # are all recognized; a plain string comparison misses the latter two.
    if str(self.device).split(":", 1)[0] == "cpu":
        return True
    try:
        torch.tensor([1.0], device=self.device) + torch.tensor([1.0], device=self.device)
        return True
    except RuntimeError as e:
        message = str(e)
        if "device-side assert" in message or "CUDA error" in message:
            if self._cuda_healthy:
                # Only the first detection triggers the restart.
                self._cuda_healthy = False
                print("FATAL: CUDA context is corrupted (device-side assert). The process must be restarted.")
                print(f"Original error: {type(e).__name__}: {e}")
                self._trigger_restart()
            return False
        raise
562
+
563
def _trigger_restart(self) -> None:
    """Exit the process so the HF Space (or systemd/Docker) can restart it.

    Flushes stdout and stderr, then terminates immediately with exit status 1
    through ``os._exit``.
    """
    import sys

    print("Initiating automatic restart due to unrecoverable CUDA error...")
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    os._exit(1)
571
+
572
def generate(
    self,
    client: viser.ClientHandle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
) -> None:
    """Run one generation request for ``client``, serialized across all clients.

    If another generation holds the lock, a "Waiting for GPU..." notification
    is shown to the client until the lock is obtained. The request is dropped
    silently when the client's session no longer exists once the lock is held.
    All generation arguments are forwarded unchanged to ``generation.generate``.

    Raises:
        RuntimeError: If CUDA was previously flagged as corrupted.
    """
    if not self._cuda_healthy:
        raise RuntimeError("CUDA is in a corrupted state. The space is restarting...")

    waiting_notif = None
    if not self._generation_lock.acquire(blocking=False):
        waiting_notif = client.add_notification(
            title="Waiting for GPU...",
            body="Another generation is in progress. Yours will start automatically.",
            loading=True,
            with_close_button=False,
        )
        self._generation_lock.acquire()

    # From here on the lock is held: everything, including dismissing the
    # notification, runs under try/finally so a failure can never leave the
    # lock acquired and block every later generation.
    try:
        if waiting_notif is not None:
            waiting_notif.remove()

        session = self.client_sessions.get(client.client_id)
        if session is None:
            # The session was removed while this request waited for the lock.
            return

        model_bundle = self.load_model(session.model_name)
        generation.generate(
            client=client,
            session=session,
            model_bundle=model_bundle,
            prompts=prompts,
            num_frames=num_frames,
            num_samples=num_samples,
            seed=seed,
            diffusion_steps=diffusion_steps,
            cfg_weight=cfg_weight,
            cfg_type=cfg_type,
            postprocess_parameters=postprocess_parameters,
            transitions_parameters=transitions_parameters,
            real_robot_rotations=real_robot_rotations,
            device=self.device,
            clear_motions=self.clear_motions,
            add_character_motion=self.add_character_motion,
        )
    finally:
        self._generation_lock.release()
623
+
624
def set_frame(self, client_id: int, frame_idx: int, update_timeline: bool = True):
    """Move the session of ``client_id`` to ``frame_idx``.

    Stores the frame index, optionally pushes it to the client's timeline,
    applies it to every motion and refreshes the constraint overlays.
    No-op when the client has no active session.
    """
    if not self.client_active(client_id):
        return

    session = self.client_sessions[client_id]
    session.frame_idx = frame_idx

    if update_timeline:
        session.client.timeline.set_current_frame(frame_idx)

    # Iterate over a snapshot of the motions.
    current_motions = list(session.motions.values())
    for motion in current_motions:
        motion.set_frame(frame_idx)

    self._apply_constraint_overlay_visibility(session)
636
+
637
def run(self) -> None:
    """Drive playback for all sessions forever.

    Each iteration advances every playing session by the number of frames
    that fit in the elapsed time (scaled by its playback speed), wrapping at
    the end of its frame range. CUDA health is checked at most every 5
    seconds, and the loop sleeps to target twice the highest loaded model
    fps (60 Hz when no model is loaded).
    """
    previous_tick = time.perf_counter()
    last_cuda_check_time = 0.0
    while True:
        tick = time.perf_counter()
        elapsed = tick - previous_tick
        previous_tick = tick

        if not self.models:
            playback_fps = 60.0
        else:
            # the max playback speed is 2x the model fps (from gui_playback_speed_buttons)
            playback_fps = max(bundle.model_fps for bundle in self.models.values()) * 2.0

        # Work on a snapshot: sessions may be added or removed while the loop runs.
        sessions_snapshot = list(self.client_sessions.items())
        for client_id, session in sessions_snapshot:
            if not session.playing or session.model_fps <= 0:
                continue

            # Time-based stepping keeps playback smooth even if loop cadence jitters.
            session.playback_time_accumulator += max(0.0, elapsed) * max(0.0, session.playback_speed)
            frame_period = 1.0 / session.model_fps
            if session.playback_time_accumulator < frame_period:
                continue

            step = int(session.playback_time_accumulator / frame_period)
            session.playback_time_accumulator -= step * frame_period
            total_frames = max(1, session.max_frame_idx + 1)
            next_frame_idx = (session.frame_idx + step) % total_frames

            # The client may have gone away in the meantime.
            if self.client_active(client_id):
                self.set_frame(client_id, next_frame_idx)

        if tick - last_cuda_check_time >= 5.0:
            self.check_cuda_health()
            last_cuda_check_time = tick

        spent = time.perf_counter() - tick
        time.sleep(max(0.0, 1.0 / playback_fps - spent))
680
+
681
def configure_theme(
    self,
    client: viser.ClientHandle,
    dark_mode: bool = False,
    titlebar_dark_mode_checkbox_uuid: str | None = None,
):
    """Apply the light or dark theme, titlebar and panel label for ``client``.

    Also recolors the client's grid (when one is registered) to match the
    theme. The titlebar logo is embedded as a base64 data URL when the light
    logo file exists under ``DEMO_ASSETS_ROOT``; otherwise no image is set.
    """
    # Sync grid color with theme (light vs dark)
    palette = DARK_THEME if dark_mode else LIGHT_THEME
    grid = self.grid_handles.get(client.client_id)
    if grid is not None:
        grid.section_color = palette["grid"]

    # Titlebar links
    docs_button = TitlebarButton(
        text="Documentation",
        icon="Description",
        href="https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html",
    )
    project_button = TitlebarButton(
        text="Project Page",
        icon=None,
        href="https://research.nvidia.com/labs/sil/projects/kimodo/",
    )
    github_button = TitlebarButton(
        text="Github",
        icon="GitHub",
        href="https://github.com/nv-tlabs/kimodo",
    )
    buttons = (docs_button, project_button, github_button)

    def _b64(path):
        return base64.standard_b64encode(path.read_bytes()).decode("ascii")

    # Titlebar logo (optional)
    logo_light_path = DEMO_ASSETS_ROOT / "nvidia_logo.png"
    logo_dark_path = DEMO_ASSETS_ROOT / "nvidia_logo_dark.png"
    image = None
    if logo_light_path.exists():
        light_b64 = _b64(logo_light_path)
        dark_b64 = _b64(logo_dark_path) if logo_dark_path.exists() else None
        if dark_b64:
            dark_url = f"data:image/png;base64,{dark_b64}"
        else:
            dark_url = None
        image = TitlebarImage(
            image_url_light=f"data:image/png;base64,{light_b64}",
            image_url_dark=dark_url,
            image_alt="NVIDIA",
            href="https://www.nvidia.com/",
        )

    titlebar = TitlebarConfig(buttons=buttons, image=image, title_text="Movimento")
    client.gui.set_panel_label("Movimento")
    client.gui.configure_theme(
        titlebar_content=titlebar,
        control_layout="floating",  # one of 'floating', 'collapsible', 'fixed'
        control_width="large",  # one of 'small', 'medium', 'large'
        dark_mode=dark_mode,
        show_logo=False,  # hide viser logo on bottom left corner
        show_share_button=False,
        titlebar_dark_mode_checkbox_uuid=titlebar_dark_mode_checkbox_uuid,
        brand_color=(152, 189, 255),  # (R, G, B) tuple
    )