rydlrKE commited on
Commit
2a5255e
·
verified ·
1 Parent(s): 2f71493

Fix CUDA import order - import spaces before torch (commit e28bffd)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ __pycache__/ui.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # ruff: noqa: I001
5
+ import argparse
6
+
7
+ from kimodo.model import DEFAULT_MODEL
8
+ from kimodo.model.registry import resolve_model_name
9
+
10
+ from .app import Demo
11
+
12
+
13
def main() -> None:
    """Parse CLI arguments, resolve the model name, and launch the demo UI."""
    arg_parser = argparse.ArgumentParser(description="Run the kimodo demo UI.")
    arg_parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL,
        help="Default model to load (e.g. Kimodo-SOMA-RP-v1, kimodo-soma-rp, or SOMA).",
    )
    cli_args = arg_parser.parse_args()

    # Accept short/friendly aliases and normalize to the canonical model name.
    model_name = resolve_model_name(cli_args.model, "Kimodo")
    Demo(default_model_name=model_name).run()


if __name__ == "__main__":
    main()
__main__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for `python -m kimodo.demo`."""

from kimodo.demo import main

# Delegate to the package-level CLI entry point so `python -m kimodo.demo`
# behaves identically to the installed console script.
if __name__ == "__main__":
    main()
__pycache__/app.cpython-312.pyc ADDED
Binary file (34.2 kB). View file
 
__pycache__/state.cpython-312.pyc ADDED
Binary file (3.01 kB). View file
 
__pycache__/ui.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:023646c6b51d238f0aedf942347c389dfa7384d452f7b1b99aa34bcef924706e
3
+ size 145667
_qwen_prompts.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Qwen-on-Fireworks helper for auto-generating multi-text-prompt batches."""
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import os
8
+ import re
9
+ import urllib.error
10
+ import urllib.request
11
+
12
+ _MODEL = "accounts/fireworks/models/qwen3p6-27b"
13
+ _BASE = "https://api.fireworks.ai/inference/v1"
14
+
15
+ _SYSTEM = """\
16
+ You are a motion-description writer for a single humanoid character in a 3D animation system.
17
+ Given a scene context and the character's recent motion history, output ONLY a JSON object:
18
+
19
+ {"texts": ["<action phrase 1>", ...], "durations": [<seconds float>, ...]}
20
+
21
+ Rules:
22
+ - Return between 1 and requested_actions short, vivid action phrases that flow naturally from each other.
23
+ - Each phrase describes one distinct physical motion (e.g. "walks forward briskly", "pivots left and crouches").
24
+ - Each duration is between 2.0 and 8.0 seconds.
25
+ - texts and durations must have the same length.
26
+ - Do NOT repeat phrases from history.
27
+ - Return raw JSON only — no markdown, no explanation.
28
+ """
29
+
30
+
31
+ def _call_fireworks(messages: list[dict]) -> str:
32
+ api_key = os.environ.get("FIREWORKS_API_KEY", "").strip()
33
+ if not api_key:
34
+ raise RuntimeError("FIREWORKS_API_KEY is not set")
35
+ body = json.dumps({
36
+ "model": _MODEL,
37
+ "messages": messages,
38
+ "max_tokens": 400,
39
+ "temperature": 0.85,
40
+ }).encode()
41
+ req = urllib.request.Request(
42
+ f"{_BASE}/chat/completions",
43
+ data=body,
44
+ headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
45
+ method="POST",
46
+ )
47
+ try:
48
+ with urllib.request.urlopen(req, timeout=40) as r:
49
+ return json.loads(r.read())["choices"][0]["message"]["content"]
50
+ except urllib.error.HTTPError as e:
51
+ raise RuntimeError(f"Fireworks {e.code}: {e.read().decode(errors='ignore')}") from e
52
+
53
+
54
+ def _parse(raw: str) -> dict:
55
+ text = raw.strip()
56
+ m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
57
+ text = m.group(1) if m else text
58
+ s, e = text.find("{"), text.rfind("}")
59
+ return json.loads(text[s:e + 1])
60
+
61
+
62
+ def _fallback(offset: int) -> dict:
63
+ phrases = [
64
+ "walks forward at a steady pace",
65
+ "turns smoothly to the left",
66
+ "pauses and surveys the surroundings",
67
+ "steps forward and gestures expressively",
68
+ "crouches down then rises back up",
69
+ "sidesteps to the right with purpose",
70
+ ]
71
+ n = len(phrases)
72
+ return {
73
+ "texts": [phrases[(offset + i) % n] for i in range(3)],
74
+ "durations": [3.0, 3.5, 3.0],
75
+ }
76
+
77
+
78
def call_qwen_for_prompts(
    scene: str,
    history: list[str],
    requested_actions: int = 5,
) -> tuple[dict, list[str]]:
    """Call Qwen to produce the next batch of motion prompts.

    Args:
        scene: Free-form scene description; empty falls back to a generic one.
        history: Previously used prompts; the last 12 are sent so the model
            avoids repeats. Not mutated.
        requested_actions: Upper bound on phrases to request (clamped to 1-10).

    Returns (batch_dict, updated_history), where batch_dict has keys "texts"
    and "durations" of equal length.

    Never raises on API failure: any error (missing API key, HTTP failure,
    malformed or empty reply) falls back to a deterministic local batch.
    """
    user_msg = (
        f"Scene: {scene or 'a character moving continuously in 3D space'}\n"
        f"Motion history (do not repeat): {json.dumps(history[-12:])}\n\n"
        f"requested_actions: {max(1, min(10, int(requested_actions)))}\n"
        "Generate the next batch of motion prompts."
    )
    try:
        raw = _call_fireworks([{"role": "system", "content": _SYSTEM}, {"role": "user", "content": user_msg}])
        batch = _parse(raw)
        if not isinstance(batch.get("texts"), list) or not isinstance(batch.get("durations"), list):
            raise ValueError("Missing texts or durations")
        # Trim both lists to their common length; an empty batch is useless,
        # so treat it as a failure and use the fallback instead.
        n = min(len(batch["texts"]), len(batch["durations"]))
        if n == 0:
            raise ValueError("Empty texts/durations batch")
        batch["texts"] = batch["texts"][:n]
        batch["durations"] = batch["durations"][:n]
    except Exception:
        # Offset by history length so repeated fallbacks rotate phrases.
        batch = _fallback(len(history))

    new_history = history + list(batch["texts"])
    return batch, new_history
app.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # ============================================================================
5
+ # CRITICAL: Import spaces FIRST, before any CUDA-related packages.
6
+ # This prevents "CUDA has been initialized before importing spaces" error.
7
+ # ============================================================================
8
+ try:
9
+ import spaces # noqa: F401 - imported early for ZeroGPU compatibility
10
+ except ImportError:
11
+ pass # Not running on HF Spaces with ZeroGPU
12
+
13
+ import base64
14
+ import logging
15
+ import os
16
+ import shutil
17
+ import threading
18
+ import time
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ import viser
25
+ from kimodo.assets import DEMO_ASSETS_ROOT
26
+ from kimodo.model.load_model import load_model
27
+ from kimodo.model.registry import resolve_model_name
28
+ from kimodo.runtime.device import select_runtime_device
29
+ from kimodo.skeleton import SkeletonBase, SOMASkeleton30
30
+ from kimodo.tools import load_json
31
+ from kimodo.viz import viser_utils
32
+ from kimodo.viz.viser_utils import (
33
+ Character,
34
+ CharacterMotion,
35
+ EEJointsKeyframeSet,
36
+ FullbodyKeyframeSet,
37
+ RootKeyframe2DSet,
38
+ )
39
+ from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage
40
+
41
+ from . import generation, ui
42
+ from .config import (
43
+ DARK_THEME,
44
+ DEFAULT_CUR_DURATION,
45
+ DEFAULT_MODEL,
46
+ DEFAULT_PLAYBACK_SPEED,
47
+ DEFAULT_PROMPT,
48
+ DEMO_UI_QUICK_START_MODAL_MD,
49
+ EXAMPLES_ROOT_DIR,
50
+ HF_MODE,
51
+ LIGHT_THEME,
52
+ MAX_ACTIVE_USERS,
53
+ MAX_DURATION,
54
+ MAX_SESSION_MINUTES,
55
+ MIN_DURATION,
56
+ MODEL_EXAMPLES_DIRS,
57
+ MODEL_NAMES,
58
+ SERVER_NAME,
59
+ SERVER_PORT,
60
+ )
61
+ from .embedding_cache import CachedTextEncoder
62
+ from .queue_manager import QueueManager, UserQueue
63
+ from .state import ClientSession, ModelBundle
64
+
65
+ # Hosted runtimes (HF/Cloud Run) often send non-WS probes to the WS endpoint.
66
+ # Suppress noisy stack traces for these expected invalid handshakes.
67
+ logging.getLogger("websockets.server").setLevel(logging.CRITICAL)
68
+ logging.getLogger("websockets.asyncio.server").setLevel(logging.CRITICAL)
69
+
70
+
71
+ class Demo:
72
    def __init__(self, default_model_name: str = DEFAULT_MODEL):
        """Set up device selection, model registry, viser server, and queues.

        Raises ValueError if `default_model_name` does not resolve to a
        known model in MODEL_NAMES.
        """
        # In hosted HF runtimes (including ZeroGPU), touching CUDA too early can
        # crash startup before queue-managed inference starts.
        requested_device = os.getenv("KIMODO_DEVICE")
        running_in_space = bool(os.getenv("SPACE_ID")) or os.getenv("SYSTEM", "").strip().lower() == "spaces"
        if requested_device is None and (HF_MODE or running_in_space):
            requested_device = "cpu"
        self.device = select_runtime_device(requested=requested_device)
        print(f"Using device: {self.device}")
        # Cache of loaded models, keyed by canonical model name.
        self.models: dict[str, ModelBundle] = {}
        resolved = resolve_model_name(default_model_name, "Kimodo")
        if resolved not in MODEL_NAMES:
            raise ValueError(f"Unknown model '{default_model_name}'. Expected one of: {MODEL_NAMES}")
        self.default_model_name = resolved
        # Defaults to deferred loading so the server comes up fast; the first
        # client session triggers the actual (slow) model load.
        self.defer_model_load = os.getenv("KIMODO_DEFER_MODEL_LOAD", "true").strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
        self.ensure_examples_layout()
        if self.defer_model_load:
            print("Deferring model load until first active client session.")
        else:
            self.load_model(self.default_model_name)

        # Serialize GPU-bound generation across all clients
        self._generation_lock = threading.Lock()
        self._cuda_healthy = True

        # Per-client sessions
        self.client_sessions: dict[int, ClientSession] = {}
        self.start_direction_markers: dict[int, viser_utils.WaypointMesh] = {}
        self.grid_handles: dict[int, viser.GridHandle] = {}

        self.server = viser.ViserServer(
            host=SERVER_NAME,
            port=SERVER_PORT,
            label="Kimodo",
            enable_camera_keyboard_controls=False,  # don't move the camera with the arrow keys
        )
        self.server.scene.world_axes.visible = False  # used for debugging
        self.server.scene.set_up_direction("+y")

        # Register callbacks for session handling
        self.server.on_client_connect(self.on_client_connect)
        self.server.on_client_disconnect(self.on_client_disconnect)

        # HF mode: queue and session limit
        if HF_MODE:
            self.user_queue = UserQueue(MAX_ACTIVE_USERS, MAX_SESSION_MINUTES)
            self.queue_manager = QueueManager(
                queue=self.user_queue,
                server=self.server,
                setup_demo_for_client=self._setup_demo_for_client,
                cleanup_session=self._cleanup_session_for_client,
            )
        else:
            self.user_queue = None
            self.queue_manager = None

        # create grid and floor
        self.floor_len = 20.0  # meters
135
+
136
+ def ensure_examples_layout(self) -> None:
137
+ os.makedirs(EXAMPLES_ROOT_DIR, exist_ok=True)
138
+ for model_dir in MODEL_EXAMPLES_DIRS.values():
139
+ os.makedirs(model_dir, exist_ok=True)
140
+
141
+ for entry in os.listdir(EXAMPLES_ROOT_DIR):
142
+ if entry in MODEL_EXAMPLES_DIRS:
143
+ continue
144
+ src = os.path.join(EXAMPLES_ROOT_DIR, entry)
145
+ if not os.path.isdir(src):
146
+ continue
147
+ dst = os.path.join(
148
+ MODEL_EXAMPLES_DIRS.get(DEFAULT_MODEL, next(iter(MODEL_EXAMPLES_DIRS.values()))),
149
+ entry,
150
+ )
151
+ if not os.path.exists(dst):
152
+ shutil.move(src, dst)
153
+
154
    def get_examples_base_dir(self, model_name: str, absolute: bool = True) -> str:
        """Return the examples directory configured for `model_name`.

        Raises KeyError for unknown model names.
        NOTE(review): the `absolute` parameter is currently ignored — the
        configured path is returned as-is; confirm whether relative paths
        are ever needed by callers.
        """
        return MODEL_EXAMPLES_DIRS[model_name]
156
+
157
    def load_model(self, model_name: str) -> ModelBundle:
        """Load the model bundle for `model_name`, caching it for reuse.

        Subsequent calls with the same name return the cached ModelBundle.
        On success the text-embedding cache is prewarmed with this model's
        example prompts. Re-raises any loading error after printing a hint.
        """
        if model_name in self.models:
            return self.models[model_name]

        print(f"Loading model {model_name}...")
        try:
            # Note: calls the module-level load_model imported from
            # kimodo.model.load_model, not this method.
            model = load_model(modelname=model_name, device=self.device)
        except Exception as e:
            print(
                "Error loading model during Kimodo startup. "
                "This often means the text encoder server is not running, the Hugging Face token is missing, "
                "or the gated text encoder model cannot be accessed."
            )
            print(f"Original error: {type(e).__name__}: {e}")
            raise e

        # Wrap the text encoder so repeated prompts hit an embedding cache.
        if hasattr(model, "text_encoder"):
            model.text_encoder = CachedTextEncoder(model.text_encoder, model_name=model_name)

        skeleton = model.motion_rep.skeleton
        # SOMA-30 skeletons are swapped for their 77-joint parent skeleton
        # (presumably for visualization — confirm against viser_utils usage).
        if isinstance(skeleton, SOMASkeleton30):
            skeleton = skeleton.somaskel77.to(model.device)
        bundle = ModelBundle(
            model=model,
            motion_rep=model.motion_rep,
            skeleton=skeleton,
            model_fps=model.motion_rep.fps,
        )
        self.models[model_name] = bundle
        print(f"Model {model_name} loaded successfully")
        self.prewarm_embedding_cache(model_name, bundle.model)
        return bundle
189
+
190
    def prewarm_embedding_cache(self, model_name: str, model: object) -> None:
        """Precompute text embeddings for the default and example prompts.

        Best-effort: malformed examples are skipped and encoder failures are
        logged, but nothing here ever aborts startup.
        """
        encoder = getattr(model, "text_encoder", None)
        if not isinstance(encoder, CachedTextEncoder):
            return

        # Collect the default prompt plus every prompt referenced by the
        # on-disk example metadata for this model.
        prompt_set = set()
        prompt_set.add(DEFAULT_PROMPT)

        examples_dir = MODEL_EXAMPLES_DIRS.get(model_name)
        if examples_dir and os.path.isdir(examples_dir):
            for entry in os.listdir(examples_dir):
                example_dir = os.path.join(examples_dir, entry)
                if not os.path.isdir(example_dir):
                    continue
                meta_path = os.path.join(example_dir, "meta.json")
                if not os.path.exists(meta_path):
                    continue
                try:
                    meta = load_json(meta_path)
                except Exception:
                    # Skip unreadable/malformed metadata rather than failing.
                    continue
                for prompt in meta.get("prompts_text", []):
                    if isinstance(prompt, str):
                        prompt_set.add(prompt)

        if prompt_set:
            try:
                encoder.prewarm(list(prompt_set))
            except Exception as error:
                # Startup should not fail if text encoder is still warming up.
                error_str = str(error)
                if "Encoder initialization failed" in error_str:
                    print(
                        f"⚠️ WARNING: Text encoder failed to initialize: {error}\n"
                        f" This usually means the HuggingFace gated model cannot be accessed.\n"
                        f" To fix: Set HF_TOKEN environment variable with access to Meta-Llama-3-8B.\n"
                        f" Alternatively: Generation will still work but text embeddings may fail."
                    )
                else:
                    print(f"Warning: embedding prewarm skipped: {error}")
230
+
231
+ def build_constraint_tracks(
232
+ self, client: viser.ClientHandle, skeleton: SkeletonBase
233
+ ) -> dict[str, viser_utils.ConstraintSet]:
234
+ return {
235
+ "Full-Body": FullbodyKeyframeSet(
236
+ name="Full-Body",
237
+ server=client,
238
+ skeleton=skeleton,
239
+ ),
240
+ "End-Effectors": EEJointsKeyframeSet(
241
+ name="End-Effectors",
242
+ server=client,
243
+ skeleton=skeleton,
244
+ ),
245
+ "2D Root": RootKeyframe2DSet(
246
+ name="2D Root",
247
+ server=client,
248
+ skeleton=skeleton,
249
+ ),
250
+ }
251
+
252
    def set_timeline_defaults(self, timeline, model_fps: float) -> None:
        """Configure the timeline's default prompt, duration bounds, and zoom.

        Durations configured in seconds (config.py) are converted here to the
        frame counts the timeline widget expects.
        """
        timeline.set_defaults(
            default_text=DEFAULT_PROMPT,
            default_duration=int(DEFAULT_CUR_DURATION * model_fps - 1),
            min_duration=int(MIN_DURATION * model_fps - 1),  # 2 seconds minimum,
            max_duration=int(
                MAX_DURATION * model_fps - 1  # - NB_TRANSITION_FRAMES
            ),  # 10 seconds maximum, minus the transition frames, if needed
            default_num_frames_zoom=int(1.10 * 10 * model_fps),  # a bit more than the max
            max_frames_zoom=1000,
            fps=model_fps,
        )
264
+
265
+ def _apply_constraint_overlay_visibility(self, session: ClientSession) -> None:
266
+ """Apply show-all vs show-only-current-frame to constraint overlays."""
267
+ only_frame = session.frame_idx if session.show_only_current_constraint else None
268
+ for constraint in session.constraints.values():
269
+ constraint.set_overlay_visibility(only_frame)
270
+
271
    def set_constraint_tracks_visible(self, session: ClientSession, visible: bool) -> None:
        """Show or hide all constraint tracks on the client's timeline.

        The track/keyframe/interval data is kept in session.timeline_data, so
        hiding removes the widgets and showing rebuilds them (with their
        original uuids) from that cached state. No-op if already in the
        requested state.
        """
        timeline = session.client.timeline
        timeline_data = session.timeline_data
        if timeline_data.get("constraint_tracks_visible", True) == visible:
            return

        # Lock: rebuild must not interleave with keyframe edits from callbacks.
        with timeline_data["keyframe_update_lock"]:
            if visible:
                # Rebuild order matters: tracks first, then the keyframes and
                # intervals that reference them by track_id.
                for track_id, track_info in timeline_data["tracks"].items():
                    timeline.add_track(
                        track_info["name"],
                        track_type=track_info.get("track_type", "keyframe"),
                        color=track_info.get("color"),
                        height_scale=track_info.get("height_scale", 1.0),
                        uuid=track_id,
                    )

                for keyframe_id, keyframe_data in timeline_data["keyframes"].items():
                    timeline.add_keyframe(
                        track_id=keyframe_data["track_id"],
                        frame=keyframe_data["frame"],
                        value=keyframe_data.get("value"),
                        opacity=keyframe_data.get("opacity", 1.0),
                        locked=keyframe_data.get("locked", False),
                        uuid=keyframe_id,
                    )

                for interval_id, interval_data in timeline_data["intervals"].items():
                    timeline.add_interval(
                        track_id=interval_data["track_id"],
                        start_frame=interval_data["start_frame_idx"],
                        end_frame=interval_data["end_frame_idx"],
                        value=interval_data.get("value"),
                        opacity=interval_data.get("opacity", 1.0),
                        locked=interval_data.get("locked", False),
                        uuid=interval_id,
                    )
            else:
                # Removing a track also removes its contents on the widget
                # side; the cached dicts in timeline_data are kept intact.
                for track_id in list(timeline_data["tracks"].keys()):
                    timeline.remove_track(track_id)

            timeline_data["constraint_tracks_visible"] = visible
313
+
314
+ def _cleanup_session_for_client(self, client_id: int) -> None:
315
+ """Remove session and scene state for a client (e.g. on session expiry)."""
316
+ if client_id in self.client_sessions:
317
+ del self.client_sessions[client_id]
318
+ self.start_direction_markers.pop(client_id, None)
319
+ self.grid_handles.pop(client_id, None)
320
+
321
    def _setup_demo_for_client(self, client: viser.ClientHandle) -> None:
        """Initialize scene, GUI, and session state for a client (no modals)."""
        self.setup_scene(client)

        # May trigger the deferred (first-use) model load; cached afterwards.
        model_bundle = self.load_model(self.default_model_name)

        # Initialize each empty constraint track
        constraint_tracks = self.build_constraint_tracks(client, model_bundle.skeleton)

        # Create GUI elements for this client
        (
            gui_elements,
            timeline_tracks,
            example_dict,
            gui_examples_dropdown,
            gui_save_example_path_text,
            gui_model_selector,
        ) = ui.create_gui(
            demo=self,
            client=client,
            model_name=self.default_model_name,
            model_fps=model_bundle.model_fps,
        )
        # Mutable per-client timeline bookkeeping shared with UI callbacks;
        # guarded by "keyframe_update_lock" where rebuilds occur.
        timeline_data = {
            "tracks": timeline_tracks,
            "tracks_ids": {val["name"]: key for key, val in timeline_tracks.items()},
            "keyframes": {},
            "intervals": {},
            "keyframe_update_lock": threading.Lock(),
            "keyframe_move_timers": {},
            "pending_keyframe_moves": {},  # keyframe_id -> new_frame
            "constraint_tracks_visible": True,
            "dense_path_after_release_timer": None,
        }

        # Initialize session state
        cur_duration = DEFAULT_CUR_DURATION
        max_frame_idx = int(cur_duration * model_bundle.model_fps - 1)

        session = ClientSession(
            client=client,
            gui_elements=gui_elements,
            motions={},
            constraints=constraint_tracks,
            timeline_data=timeline_data,
            frame_idx=0,
            playing=False,
            playback_speed=DEFAULT_PLAYBACK_SPEED,
            cur_duration=cur_duration,
            max_frame_idx=max_frame_idx,
            updating_motions=False,
            edit_mode=False,
            model_name=self.default_model_name,
            model_fps=model_bundle.model_fps,
            skeleton=model_bundle.skeleton,
            motion_rep=model_bundle.motion_rep,
            examples_base_dir=self.get_examples_base_dir(self.default_model_name, absolute=True),
            example_dict=example_dict,
            gui_examples_dropdown=gui_examples_dropdown,
            gui_save_example_path_text=gui_save_example_path_text,
            gui_model_selector=gui_model_selector,
        )

        # Registering the session makes this client "active" for set_frame,
        # generate, and the playback loop in run().
        self.client_sessions[client.client_id] = session

        # Initialize default character for this client
        self.add_character_motion(client, session.skeleton)
388
+
389
    def on_client_connect(self, client: viser.ClientHandle) -> None:
        """Initialize GUI and state for each new client.

        In HF mode the queue manager decides when (and whether) setup runs;
        otherwise the client is set up immediately after a quick-start modal.
        """
        print(f"Client {client.client_id} connected")

        if HF_MODE and self.queue_manager is not None:
            self.queue_manager.on_client_connect(client)
        else:
            # Show quick start popup when a browser client connects (non-HF mode).
            # save_choice persists the "don't remind me" acknowledgement browser-side.
            with client.gui.add_modal(
                "Welcome — Quick Start",
                size="xl",
                show_close_button=True,
                save_choice="kimodo.demo.quick_start_ack",
            ) as modal:
                client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
                client.gui.add_button("Got it (don't remind me again)").on_click(lambda _event: modal.close())
            self._setup_demo_for_client(client)
406
+
407
    def setup_scene(self, client: viser.ClientHandle) -> None:
        """Set up the per-client 3D scene: theme, camera, floor grid, origin marker."""
        self.configure_theme(client)
        client.camera.position = np.array(
            [2.7417358737841426, 1.8790455698853281, 7.675741569777456],
            dtype=np.float64,
        )
        client.camera.look_at = np.array([0.0, 0.0, 0.0], dtype=np.float64)
        client.camera.up_direction = np.array(
            [-1.1102230246251568e-16, 1.0, 1.3596310734468913e-32],
            dtype=np.float64,
        )
        client.camera.fov = np.deg2rad(45.0)
        grid_handle = client.scene.add_grid(
            "/grid",
            width=self.floor_len,
            height=self.floor_len,
            # Grid is authored in the XY plane; rotate -90° about X to lie on
            # the floor of the +Y-up world set in __init__.
            wxyz=viser.transforms.SO3.from_x_radians(-np.pi / 2.0).wxyz,
            position=(0.0, 0.0001, 0.0),  # slight lift to avoid z-fighting
            fade_distance=3 * self.floor_len,
            section_color=LIGHT_THEME["grid"],
            infinite_grid=True,
        )
        self.grid_handles[client.client_id] = grid_handle
        # marker for origin
        origin_waypoint = viser_utils.WaypointMesh(
            "/origin_waypoint",
            client,
            position=np.array([0.0, 0.0, 0.0]),
            heading=np.array([0.0, 1.0]),
            color=(0, 0, 255),
        )
        self.start_direction_markers[client.client_id] = origin_waypoint
439
+
440
+ def on_client_disconnect(self, client: viser.ClientHandle) -> None:
441
+ """Clean up when client disconnects."""
442
+ print(f"Client {client.client_id} disconnected")
443
+ client_id = client.client_id
444
+
445
+ if HF_MODE and self.queue_manager is not None:
446
+ self.queue_manager.on_client_disconnect(client_id)
447
+
448
+ self._cleanup_session_for_client(client_id)
449
+
450
+ def set_start_direction_visible(self, client_id: int, visible: bool) -> None:
451
+ marker = self.start_direction_markers.get(client_id)
452
+ if marker is None:
453
+ return
454
+ marker.set_visible(visible)
455
+
456
+ def client_active(self, client_id: int) -> bool:
457
+ return client_id in self.client_sessions
458
+
459
    def add_character_motion(
        self,
        client: viser.ClientHandle,
        skeleton: SkeletonBase,
        joints_pos: Optional[torch.Tensor] = None,
        joints_rot: Optional[torch.Tensor] = None,
        foot_contacts: Optional[torch.Tensor] = None,
    ) -> None:
        """Create a new character in the client's scene and attach a motion.

        When joints_pos/joints_rot are None the character holds its rest pose
        for the session's full frame range. No-op if the client has no session.
        """
        client_id = client.client_id
        if not self.client_active(client_id):
            return
        session = self.client_sessions[client_id]

        # Characters are named by insertion order: character0, character1, ...
        ci = len(session.motions)
        character_name = f"character{ci}"
        # build character skeleton and skinning mesh
        # Mesh type is inferred from the model name; unknown names are an error.
        if "g1" in session.model_name:
            mesh_mode = "g1_stl"
        elif "smplx" in session.model_name:
            mesh_mode = "smplx_skin"
        elif "soma" in session.model_name:
            if session.gui_elements.gui_use_soma_layer_checkbox.value:
                mesh_mode = "soma_layer_skin"
            else:
                mesh_mode = "soma_skin"
        else:
            raise ValueError("The model name is not recognized for skinning.")

        new_character = Character(
            character_name,
            client,
            skeleton,
            create_skeleton_mesh=True,
            create_skinned_mesh=True,
            visible_skeleton=False,  # don't show immediately
            visible_skinned_mesh=False,  # don't show immediately
            skinned_mesh_opacity=session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value,
            show_foot_contacts=session.gui_elements.gui_viz_foot_contacts_checkbox.value,
            dark_mode=session.gui_elements.gui_dark_mode_checkbox.value,
            mesh_mode=mesh_mode,
            gui_use_soma_layer_checkbox=session.gui_elements.gui_use_soma_layer_checkbox,
        )

        # if no motion given, initialize to character default (rest) pose for one frame
        init_joints_pos, init_joints_rot = new_character.get_pose()
        if joints_pos is None:
            joints_pos = init_joints_pos[None].repeat(session.max_frame_idx + 1, 1, 1)
        if joints_rot is None:
            joints_rot = init_joints_rot[None].repeat(session.max_frame_idx + 1, 1, 1, 1)

        new_motion = CharacterMotion(new_character, joints_pos, joints_rot, foot_contacts)
        # save the motion in our dict
        session.motions[character_name] = new_motion

        # put the character at the right frame
        new_motion.set_frame(session.frame_idx)

        # put them visible with a small delay
        # so that the set_frame function has time to finish
        def _set_visibility():
            new_motion.character.set_skinned_mesh_visibility(session.gui_elements.gui_viz_skinned_mesh_checkbox.value)
            new_motion.character.set_skeleton_visibility(session.gui_elements.gui_viz_skeleton_checkbox.value)

        timer = threading.Timer(
            0.2,  # 0.2s delay
            _set_visibility,
        )
        timer.start()
528
+ def clear_motions(self, client_id: int) -> None:
529
+ if not self.client_active(client_id):
530
+ return
531
+ session = self.client_sessions[client_id]
532
+ for motion in list(session.motions.values()):
533
+ motion.clear()
534
+ session.motions.clear()
535
+
536
    def compute_model_constraints_lst(
        self,
        session: ClientSession,
        model_bundle: ModelBundle,
        num_frames: int,
    ):
        """Thin delegate: assemble model constraints via the generation module,
        supplying this demo's device."""
        return generation.compute_model_constraints_lst(session, model_bundle, num_frames, self.device)
543
+
544
    def check_cuda_health(self) -> bool:
        """Check if CUDA is still functional.

        Trigger auto-restart if corrupted.

        Returns True when the device is healthy (or is CPU), False after a
        detected corruption (the process is already restarting by then).
        """
        # NOTE(review): assumes self.device compares equal to the string "cpu"
        # when running on CPU — confirm select_runtime_device's return type.
        if self.device == "cpu":
            return True
        try:
            # A tiny allocation plus an add forces a device-side operation,
            # surfacing any latent device-side assert.
            torch.tensor([1.0], device=self.device) + torch.tensor([1.0], device=self.device)
            return True
        except RuntimeError as e:
            if "device-side assert" in str(e) or "CUDA error" in str(e):
                # Only report/restart once; subsequent calls return False quietly.
                if self._cuda_healthy:
                    self._cuda_healthy = False
                    print("FATAL: CUDA context is corrupted (device-side assert). " "The process must be restarted.")
                    self._trigger_restart()
                return False
            # Unrelated RuntimeErrors propagate to the caller.
            raise
562
+
563
    def _trigger_restart(self) -> None:
        """Exit the process so the HF Space (or systemd/Docker) can restart it."""
        import sys

        print("Initiating automatic restart due to unrecoverable CUDA error...")
        # Flush both streams explicitly: os._exit skips atexit handlers and
        # buffered-IO flushing, so the message would otherwise be lost.
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(1)
571
+
572
    def generate(
        self,
        client: viser.ClientHandle,
        prompts: list[str],
        num_frames: list[int],
        num_samples: int,
        seed: int,
        diffusion_steps: int,
        cfg_weight: Optional[list[float]] = None,
        cfg_type: Optional[str] = None,
        postprocess_parameters: Optional[dict] = None,
        transitions_parameters: Optional[dict] = None,
        real_robot_rotations: bool = False,
    ) -> None:
        """Run motion generation for one client, serialized on the shared GPU lock.

        Raises RuntimeError immediately when CUDA has been flagged corrupted.
        Blocks (with a client-side notification) while another client's
        generation holds the lock.
        """
        if not self._cuda_healthy:
            raise RuntimeError("CUDA is in a corrupted state. The space is restarting...")

        # Non-blocking acquire first so the "waiting" notification is shown
        # only when another client actually holds the GPU.
        locked = self._generation_lock.acquire(blocking=False)
        if not locked:
            waiting_notif = client.add_notification(
                title="Waiting for GPU...",
                body="Another generation is in progress. Yours will start automatically.",
                loading=True,
                with_close_button=False,
            )
            self._generation_lock.acquire()
            waiting_notif.remove()

        try:
            session = self.client_sessions[client.client_id]
            model_bundle = self.load_model(session.model_name)
            generation.generate(
                client=client,
                session=session,
                model_bundle=model_bundle,
                prompts=prompts,
                num_frames=num_frames,
                num_samples=num_samples,
                seed=seed,
                diffusion_steps=diffusion_steps,
                cfg_weight=cfg_weight,
                cfg_type=cfg_type,
                postprocess_parameters=postprocess_parameters,
                transitions_parameters=transitions_parameters,
                real_robot_rotations=real_robot_rotations,
                device=self.device,
                clear_motions=self.clear_motions,
                add_character_motion=self.add_character_motion,
            )
        finally:
            # Always release, even when generation raised, so other clients
            # are never deadlocked waiting for the GPU.
            self._generation_lock.release()
623
+
624
    def set_frame(self, client_id: int, frame_idx: int, update_timeline: bool = True):
        """Move a client's playhead: update session state, motions, and overlays.

        `update_timeline=False` avoids echoing the frame back to the timeline
        widget (e.g. when the change originated from the widget itself).
        No-op if the client has no active session.
        """
        if not self.client_active(client_id):
            return

        session = self.client_sessions[client_id]

        session.frame_idx = frame_idx
        if update_timeline:
            session.client.timeline.set_current_frame(frame_idx)
        for motion in list(session.motions.values()):
            motion.set_frame(frame_idx)
        self._apply_constraint_overlay_visibility(session)
636
+
637
    def run(self) -> None:
        """Main server loop: advance playback for every session and monitor CUDA.

        Blocks forever; intended as the process's final call. CUDA health is
        probed every 5 seconds and triggers a process restart on corruption.
        """
        last_loop_time = time.perf_counter()
        last_cuda_check_time = 0.0
        while True:
            loop_start_time = time.perf_counter()
            delta_time = loop_start_time - last_loop_time
            last_loop_time = loop_start_time

            if self.models:
                # the max playback speed is 2x the model fps (from gui_playback_speed_buttons)
                playback_fps = max(bundle.model_fps for bundle in self.models.values()) * 2.0
            else:
                # No model loaded yet (deferred load): idle at 60 Hz.
                playback_fps = 60.0

            # update each client session independently
            # copy to a list first to avoid changing size if client disconnects
            for client_id, session in list(self.client_sessions.items()):
                if not session.playing:
                    continue
                if session.model_fps <= 0:
                    continue

                # Time-based stepping keeps playback smooth even if loop cadence jitters.
                session.playback_time_accumulator += max(0.0, delta_time) * max(0.0, session.playback_speed)
                frame_period = 1.0 / session.model_fps
                if session.playback_time_accumulator < frame_period:
                    continue

                # Consume whole frames from the accumulator, keeping the remainder.
                frames_to_advance = int(session.playback_time_accumulator / frame_period)
                session.playback_time_accumulator -= frames_to_advance * frame_period
                frame_count = max(1, session.max_frame_idx + 1)
                # Wrap around for looping playback.
                new_frame_idx = (session.frame_idx + frames_to_advance) % frame_count

                # make sure the client is still active before updating the frame
                if self.client_active(client_id):
                    self.set_frame(client_id, new_frame_idx)

            if loop_start_time - last_cuda_check_time >= 5.0:
                self.check_cuda_health()
                last_cuda_check_time = loop_start_time

            # Sleep out the remainder of the frame budget.
            time_remaining = max(0.0, 1.0 / playback_fps - (time.perf_counter() - loop_start_time))
            time.sleep(time_remaining)
680
+
681
    def configure_theme(
        self,
        client: viser.ClientHandle,
        dark_mode: bool = False,
        titlebar_dark_mode_checkbox_uuid: str | None = None,
    ):
        """Apply the light/dark theme to one client: grid color, titlebar, panel.

        `titlebar_dark_mode_checkbox_uuid` links the titlebar toggle to an
        existing GUI checkbox when provided.
        """
        # Sync grid color with theme (light vs dark)
        theme = DARK_THEME if dark_mode else LIGHT_THEME
        grid_handle = self.grid_handles.get(client.client_id)
        if grid_handle is not None:
            grid_handle.section_color = theme["grid"]

        #
        # setup theme
        #
        buttons = (
            TitlebarButton(
                text="Documentation",
                icon="Description",
                href="https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html",
            ),
            TitlebarButton(
                text="Project Page",
                icon=None,
                href="https://research.nvidia.com/labs/sil/projects/kimodo/",
            ),
            TitlebarButton(
                text="Github",
                icon="GitHub",
                href="https://github.com/nv-tlabs/kimodo",
            ),
        )
        # Embed the logo(s) as base64 data URLs; skip entirely when the light
        # logo asset is missing.
        assets_dir = DEMO_ASSETS_ROOT
        logo_light_path = assets_dir / "nvidia_logo.png"
        logo_dark_path = assets_dir / "nvidia_logo_dark.png"
        if logo_light_path.exists():
            light_b64 = base64.standard_b64encode(logo_light_path.read_bytes()).decode("ascii")
            dark_b64 = (
                base64.standard_b64encode(logo_dark_path.read_bytes()).decode("ascii")
                if logo_dark_path.exists()
                else None
            )
            image = TitlebarImage(
                image_url_light=f"data:image/png;base64,{light_b64}",
                image_url_dark=(f"data:image/png;base64,{dark_b64}" if dark_b64 else None),
                image_alt="NVIDIA",
                href="https://www.nvidia.com/",
            )
        else:
            image = None
        titlebar_theme = TitlebarConfig(buttons=buttons, image=image, title_text="Movimento")
        client.gui.set_panel_label("Movimento")
        client.gui.configure_theme(
            titlebar_content=titlebar_theme,
            control_layout="floating",  # "floating", # ['floating', 'collapsible', 'fixed']
            control_width="large",  # ['small', 'medium', 'large']
            dark_mode=dark_mode,
            show_logo=False,  # hide viser logo on bottom left corner
            show_share_button=False,
            titlebar_dark_mode_checkbox_uuid=titlebar_dark_mode_checkbox_uuid,
            brand_color=(152, 189, 255),  # (60, 131, 0), # (R, G, B) tuple
        )
config.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+
6
+ from kimodo.assets import DEMO_EXAMPLES_ROOT
7
+ from kimodo.model.registry import (
8
+ AVAILABLE_MODELS,
9
+ DEFAULT_MODEL,
10
+ FRIENDLY_NAMES,
11
+ get_datasets,
12
+ get_model_info,
13
+ get_models_for_dataset_skeleton,
14
+ get_short_key_from_display_name,
15
+ get_skeleton_display_name,
16
+ get_skeleton_display_names_for_dataset,
17
+ get_skeleton_key_from_display_name,
18
+ get_skeletons_for_dataset,
19
+ get_versions_for_dataset_skeleton,
20
+ resolve_to_short_key,
21
+ )
22
+
23
# Bind address/port for the demo server. Hugging Face Spaces injects PORT, so
# fall back to it when SERVER_PORT is unset (default 7860).
SERVER_NAME = os.environ.get("SERVER_NAME", "0.0.0.0")
SERVER_PORT = int(os.environ.get("SERVER_PORT") or os.environ.get("PORT", "7860"))
25
+
26
+
27
def _env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    "1", "true", "yes" and "on" (case-insensitive, surrounding whitespace
    ignored) are truthy; any other value is falsy. Returns *default* when
    the variable is unset.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    normalized = str(value).strip().lower()
    return normalized in {"1", "true", "yes", "on"}
32
+
33
+
34
# True when running in the constrained Hugging Face Spaces deployment
# (enables the user queue and session time limit below).
HF_MODE = _env_bool("HF_MODE", False)

# HF mode: user queue and session limit (override via env in Spaces)
MAX_ACTIVE_USERS = int(os.environ.get("MAX_ACTIVE_USERS", "5"))
MAX_SESSION_MINUTES = float(os.environ.get("MAX_SESSION_MINUTES", "5.0"))

DEFAULT_PLAYBACK_SPEED = 1.0
# default start duration is 6.0 sec, but model can handle up to 10 sec
DEFAULT_CUR_DURATION = 6.0
DEFAULT_PROMPT = "A person walks forward."
# Duration bounds (seconds) for a prompt segment.
MIN_DURATION = 2.0
MAX_DURATION = 10.0

SHOW_TRANSITION_PARAMS = False  # hide transition-tuning controls by default
INIT_POSTPROCESSING = True  # initial state of the post-processing toggle
NB_TRANSITION_FRAMES = 5  # frame count for transitions — presumably between prompts; confirm in model code

# Floor/grid RGB colors (0-255) per color scheme.
LIGHT_THEME = dict(
    floor=(220, 220, 220),
    grid=(180, 180, 180),
)

# Dark theme: slightly lighter grid and floor for better visibility and less flat black
DARK_THEME = dict(
    floor=(48, 48, 52),
    grid=(105, 105, 110),
)

EXAMPLES_ROOT_DIR = str(DEMO_EXAMPLES_ROOT)

# Model list and paths from kimodo registry (all models: Kimodo + TMR)
MODEL_NAMES = tuple(AVAILABLE_MODELS)
MODEL_EXAMPLES_DIRS = {name: os.path.join(EXAMPLES_ROOT_DIR, name) for name in MODEL_NAMES}
# Display labels for backward compatibility (short_key -> display name)
MODEL_LABELS = {name: FRIENDLY_NAMES.get(name, f"Model ({name})") for name in MODEL_NAMES}
MODEL_LABEL_TO_NAME = {label: name for name, label in MODEL_LABELS.items()}
70
+
71
# -----------------------------------------------------------------------------
# Demo UI copy
# -----------------------------------------------------------------------------

# Core quick-start text, reused by both the welcome modal and the
# Instructions tab. (Fixes minor wording typos: "should be used",
# "each prompt", "load later".)
DEMO_UI_QUICK_START_CORE_MD = """
### Camera
- **Left-drag**: rotate
- **Right-drag**: pan
- **Scroll**: zoom

### Playback
- **Space** to play/pause
- **←/→** to step frames, or click the frame number.
- **Scroll up/down** in the timeline: move left/right
- **Shift + scroll** in the timeline: zoom in/out

### Prompts
- **Double-click** a text prompt to edit it.
- **Click and drag** the right edge of a prompt box to extend/shorten it.
- **Click empty space** to add a prompt.
- **Right-click** a prompt to delete it.

### Generate
- Go to the **Generate** tab to modify options
- It is also possible to **load** examples
- Click **Generate** to generate a motion

### Constraints
- This is **optional**: should be used after a first generation
- **Click** in the timeline tracks (Full-Body / 2D root etc) to add a constraint.
- **Right-click** on a constraint to delete it.
- To **edit** a constraint:
    - Move playback to the target frame
    - Click **Enter Editing Mode** in the Constraints tab.
"""

# Variant shown in the first-visit modal: core text plus a pointer to the tab.
DEMO_UI_QUICK_START_MODAL_MD = (
    DEMO_UI_QUICK_START_CORE_MD
    + """

See the **Instructions** tab for the full user manual.
"""
)

# Full user manual rendered in the Instructions tab.
DEMO_UI_INSTRUCTIONS_TAB_MD = (
    """
## How to Use This Demo

"""
    + DEMO_UI_QUICK_START_CORE_MD
    + """

---

### Generating Motion (step-by-step)

1. **Edit the text prompts** in the timeline (e.g., "A person walks forward.")
2. **Modify the duration** by moving the right edge of each prompt (2–10 seconds)
3. **Add constraints** (optional) to control the motion:
    - Click **Enter Editing Mode** to adjust the character pose
    - Use the timeline to place keyframes or intervals in constraint tracks (see below)
4. **Click Generate** to create the motion
5. If generating multiple samples, **click on a mesh** to select which one to keep

### Timeline Editing

**Adding Constraints:**
1. Click anywhere on the timeline to add a keyframe at that frame. The keyframe is created based on the current character motion.
2. Ctrl/Cmd+click+drag to add an interval constraint, or expand a keyframe into an interval
3. Enter editing mode with the **Enter Editing Mode** button to adjust character pose before/after adding constraints.

**Constraint Types:**
- **Full-Body**: constrains the entire character pose
- **2D Root**: constrains the character's path on the ground plane
    - Enable **Densify** to create a continuous path
- **End-Effectors**: constrains hands and feet positions
    - Use separate tracks for Left/Right Hand/Foot


**Moving & Deleting:**
- **Drag keyframes/intervals** to move them to different frames
- **Right-click** a keyframe or interval to delete it
- Use **Clear All Constraints** to remove everything

**Tips:**
- The posing skeleton becomes visible in editing mode for precise positioning
- Use **Snap to constraint** to align the current frame to a constraint

### Saving & Loading

You can save the current constraints or current motion to load later from the Load/Save menu.
Saving an **Example** will save the full constraints, motion, and generation metadata.

### Visualization Options

Switch to the **Visualize** tab to:
- Toggle mesh and skeleton visibility
- Adjust mesh opacity
- Show/hide foot contact indicators
- Switch between light and dark modes
"""
)
embedding_cache.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import contextlib
5
+ import contextvars
6
+ import hashlib
7
+ import json
8
+ import os
9
+ import threading
10
+ import time
11
+ from collections import OrderedDict
12
+ from dataclasses import dataclass
13
+ from typing import Iterable, Optional
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from kimodo.sanitize import sanitize_texts
19
+
20
# Context variable holding the session whose per-session embedding cache the
# encoder should consult; set/reset via CachedTextEncoder.session_context().
_ACTIVE_SESSION = contextvars.ContextVar("kimodo_demo_active_session", default=None)
21
+
22
+
23
@dataclass
class CacheStats:
    """Lookup counters for EmbeddingCache."""

    # Served straight from the in-memory LRU.
    hits: int = 0
    # Not found anywhere; required a fresh encoder call.
    misses: int = 0
    # Loaded from the on-disk .npy store (then promoted into the LRU).
    disk_hits: int = 0
28
+
29
+
30
class EmbeddingCache:
    """Disk-backed text embedding cache with a small in-memory LRU.

    Entries are keyed by a SHA-256 of ``model_name|encoder_id|text`` and
    stored one ``.npy`` file per prompt under ``<base_dir>/<model_name>/``,
    with a JSON index of per-entry metadata alongside. The most recently
    used ``max_mem_entries`` arrays are additionally kept in an in-memory
    LRU. A single lock guards the LRU and index, so one instance may be
    shared across threads.
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_id: str,
        base_dir: Optional[str] = None,
        max_mem_entries: int = 128,
    ) -> None:
        """Create a cache rooted at *base_dir*.

        Resolution order for the root: explicit argument, the
        ``kimodo_EMBED_CACHE_DIR`` environment variable, then
        ``~/.cache/kimodo_demo/embeddings``.
        """
        cache_root = base_dir or os.environ.get(
            "kimodo_EMBED_CACHE_DIR",
            os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
        )
        self.base_dir = os.path.expanduser(cache_root)
        self.model_name = model_name
        self.encoder_id = encoder_id
        self.max_mem_entries = max_mem_entries
        self.stats = CacheStats()

        self._lock = threading.Lock()
        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._index: dict = {}
        # Index is loaded lazily on first lookup so construction never touches disk.
        self._index_loaded = False

    # --- paths ---------------------------------------------------------

    def _model_dir(self) -> str:
        return os.path.join(self.base_dir, self.model_name)

    def _index_path(self) -> str:
        return os.path.join(self._model_dir(), "index.json")

    def _prewarm_marker_path(self, key: str) -> str:
        return os.path.join(self._model_dir(), f"prewarm_{key}.json")

    # --- prewarm markers ----------------------------------------------

    def has_prewarm_marker(self, key: str) -> bool:
        """True if a prewarm pass identified by *key* already completed."""
        return os.path.exists(self._prewarm_marker_path(key))

    def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
        """Atomically record that prewarm pass *key* covered *prompt_count* prompts."""
        os.makedirs(self._model_dir(), exist_ok=True)
        payload = {"prompt_count": prompt_count, "updated_at": time.time()}
        # Write to a temp file then os.replace so readers never see a partial file.
        tmp_path = f"{self._prewarm_marker_path(key)}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        os.replace(tmp_path, self._prewarm_marker_path(key))

    # --- index ---------------------------------------------------------

    def _load_index(self) -> None:
        """Load the on-disk JSON index once; a corrupt index degrades to empty."""
        if self._index_loaded:
            return
        index_path = self._index_path()
        if os.path.exists(index_path):
            try:
                with open(index_path, "r", encoding="utf-8") as f:
                    self._index = json.load(f)
            except json.JSONDecodeError:
                self._index = {}
        self._index_loaded = True

    def _save_index(self) -> None:
        """Atomically persist the index (temp file + rename)."""
        os.makedirs(self._model_dir(), exist_ok=True)
        tmp_path = f"{self._index_path()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f)
        os.replace(tmp_path, self._index_path())

    # --- entry storage -------------------------------------------------

    def _make_key(self, text: str) -> str:
        # Key includes model and encoder identity so different encoders never
        # share cached embeddings for the same prompt.
        key_src = f"{self.model_name}|{self.encoder_id}|{text}"
        return hashlib.sha256(key_src.encode("utf-8")).hexdigest()

    def _entry_path(self, key: str) -> str:
        return os.path.join(self._model_dir(), f"{key}.npy")

    def _mem_get(self, key: str) -> Optional[np.ndarray]:
        """LRU lookup; a hit is moved to the most-recently-used position."""
        if key in self._mem_cache:
            self._mem_cache.move_to_end(key)
            return self._mem_cache[key]
        return None

    def _mem_put(self, key: str, value: np.ndarray) -> None:
        """Insert into the LRU, evicting least-recently-used entries over capacity."""
        self._mem_cache[key] = value
        self._mem_cache.move_to_end(key)
        while len(self._mem_cache) > self.max_mem_entries:
            self._mem_cache.popitem(last=False)

    def _disk_load(self, key: str) -> Optional[np.ndarray]:
        """Load an entry from disk; any read failure is treated as a miss."""
        path = self._entry_path(key)
        if not os.path.exists(path):
            return None
        try:
            return np.load(path)
        except Exception:
            return None

    def _disk_save(self, key: str, value: np.ndarray) -> None:
        """Write an entry's array and record its metadata in the (in-memory) index.

        The caller is responsible for calling _save_index() afterwards.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        np.save(self._entry_path(key), value)
        self._index[key] = {
            "length": int(value.shape[0]),
            "dtype": str(value.dtype),
            "updated_at": time.time(),
        }

    # --- per-session fast path ----------------------------------------

    def _maybe_use_session_cache(self, texts: list[str]):
        """Return the active session's cached (embeddings, lengths) on an exact prompt match."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return None
        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
            return session.last_prompt_embeddings, session.last_prompt_lengths
        return None

    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
        """Remember the latest prompts/embeddings on the active session, if any."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return
        session.last_prompt_texts = texts
        session.last_prompt_embeddings = tensor
        session.last_prompt_lengths = lengths

    # --- main entry point ----------------------------------------------

    def get_or_encode(self, texts: Iterable[str], encoder):
        """Return ``(embeddings, lengths)`` for *texts*, encoding only cache misses.

        ``embeddings`` is a tensor of shape ``[len(texts), max_len, feat_dim]``
        zero-padded per row; ``lengths`` holds each row's true length.
        *encoder* is called once with the list of missing texts and must
        return a padded tensor plus per-text lengths.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        if len(texts) == 0:
            # BUG FIX: torch.empty() requires a size argument and would raise
            # a TypeError here; return an explicit 0-element tensor instead.
            return torch.empty(0), []

        session_cache = self._maybe_use_session_cache(texts)
        if session_cache is not None:
            return session_cache

        arrays: list[Optional[np.ndarray]] = [None] * len(texts)
        lengths: list[int] = [0] * len(texts)
        misses: list[tuple[int, str, str]] = []

        with self._lock:
            self._load_index()
            for idx, text in enumerate(texts):
                key = self._make_key(text)
                cached = self._mem_get(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self.stats.hits += 1
                    continue

                cached = self._disk_load(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self._mem_put(key, cached)
                    self.stats.disk_hits += 1
                    continue

                misses.append((idx, text, key))
                self.stats.misses += 1

        if misses:
            # Encode all misses in one batch; the encoder may run on GPU, so
            # detach and copy to CPU before converting to numpy for storage.
            miss_texts = [text for _, text, _ in misses]
            miss_tensor, miss_lengths = encoder(miss_texts)
            miss_tensor = miss_tensor.detach().cpu()
            miss_tensor_np = miss_tensor.numpy()

            with self._lock:
                self._load_index()
                for miss_idx, length in enumerate(miss_lengths):
                    idx, _text, key = misses[miss_idx]
                    # Trim the encoder's padding down to the true length.
                    arr = miss_tensor_np[miss_idx, :length].copy()
                    arrays[idx] = arr
                    lengths[idx] = int(length)
                    self._mem_put(key, arr)
                    self._disk_save(key, arr)
                self._save_index()

        # Re-pad everything to the batch-wide maximum length.
        max_len = max(lengths) if lengths else 0
        feat_dim = arrays[0].shape[-1] if arrays[0] is not None else 0
        dtype = arrays[0].dtype if arrays[0] is not None else np.float32
        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
        for idx, arr in enumerate(arrays):
            if arr is None:
                continue
            padded[idx, : arr.shape[0]] = arr

        result = torch.from_numpy(padded)
        self._update_session_cache(texts, result, lengths)
        return result, lengths
215
+
216
+
217
class CachedTextEncoder:
    """Wrapper around a text encoder to add disk-backed caching.

    Behaves like the wrapped encoder (callable, ``to()``, attribute access is
    delegated) but routes every call through an EmbeddingCache.
    """

    def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
        self.encoder = encoder
        self.model_name = model_name
        # Partition cache entries by encoder class so two different encoder
        # types never share cached embeddings.
        encoder_id = f"{type(encoder).__name__}"
        self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)

    def __call__(self, texts):
        # Same contract as the wrapped encoder: returns (embeddings, lengths).
        return self.cache.get_or_encode(texts, self.encoder)

    def prewarm(self, texts) -> None:
        """Encode *texts* once to populate the disk cache.

        A marker file keyed by the full prompt list makes repeated prewarm
        calls with the same prompts a no-op.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        prewarm_key = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
        if self.cache.has_prewarm_marker(prewarm_key):
            return
        self.cache.get_or_encode(texts, self.encoder)
        self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))

    def to(self, device=None, dtype=None):
        """Forward a device/dtype move to the wrapped encoder; returns self for chaining."""
        if hasattr(self.encoder, "to"):
            self.encoder.to(device=device, dtype=dtype)
        return self

    @contextlib.contextmanager
    def session_context(self, session):
        """Bind *session* as the active session so repeated identical prompts
        can be served from the session's last embeddings (see EmbeddingCache)."""
        token = _ACTIVE_SESSION.set(session)
        try:
            yield
        finally:
            # Always restore the previous context value, even on error.
            _ACTIVE_SESSION.reset(token)

    def __getattr__(self, name):
        # Any attribute not defined here is delegated to the wrapped encoder
        # (only reached when normal attribute lookup fails).
        return getattr(self.encoder, name)
generation.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections import defaultdict
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ import viser
11
+ from kimodo.constraints import (
12
+ TYPE_TO_CLASS,
13
+ FullBodyConstraintSet,
14
+ Root2DConstraintSet,
15
+ )
16
+ from kimodo.exports.mujoco import apply_g1_real_robot_projection
17
+ from kimodo.skeleton import G1Skeleton34, SOMASkeleton30
18
+ from kimodo.tools import seed_everything
19
+
20
+ from .embedding_cache import CachedTextEncoder
21
+ from .state import ClientSession, ModelBundle
22
+
23
+
24
+ def compute_model_constraints_lst(
25
+ session: ClientSession,
26
+ model_bundle: ModelBundle,
27
+ num_frames: int,
28
+ device: str,
29
+ ):
30
+ """Compute the lst of constraints for the model based on the constraints in viser."""
31
+ assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
32
+ if not session.constraints:
33
+ return []
34
+
35
+ model_skeleton = model_bundle.model.skeleton
36
+ # For SOMA, UI uses somaskel77; extract 30-joint subset for the model
37
+ use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
38
+ skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None
39
+
40
+ dense_smooth_root_pos_2d = None
41
+ if session.constraints["2D Root"].dense_path:
42
+ # get the full 2d root
43
+ dense_smooth_root_pos_2d = session.constraints["2D Root"].get_constraint_info(device=device)["root_pos"][
44
+ :, [0, 2]
45
+ ]
46
+
47
+ model_constraints = []
48
+ for track_name, constraint in session.constraints.items():
49
+ constraint_info = constraint.get_constraint_info(device=device)
50
+ frame_idx = constraint_info["frame_idx"]
51
+ # drop any constraints outside the generation range
52
+ valid_info = [(i, fi) for i, fi in enumerate(frame_idx) if fi < num_frames]
53
+ valid_idx = [i for i, _ in valid_info]
54
+ valid_frame_idx = [fi for _, fi in valid_info]
55
+
56
+ if len(valid_frame_idx) == 0:
57
+ continue
58
+
59
+ frame_indices = torch.tensor(valid_frame_idx)
60
+ if track_name == "2D Root":
61
+ smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
62
+ # same as "smooth_root_2d"
63
+ model_constraints.append(
64
+ Root2DConstraintSet(
65
+ model_skeleton,
66
+ frame_indices,
67
+ smooth_root_pos_2d,
68
+ )
69
+ )
70
+ elif track_name == "Full-Body":
71
+ constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
72
+ constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
73
+ if skel_slice is not None:
74
+ constraint_joints_pos = constraint_joints_pos[:, skel_slice]
75
+ constraint_joints_rot = constraint_joints_rot[:, skel_slice]
76
+
77
+ smooth_root_pos_2d = None
78
+ if dense_smooth_root_pos_2d is not None:
79
+ smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]
80
+
81
+ model_constraints.append(
82
+ FullBodyConstraintSet(
83
+ model_skeleton,
84
+ frame_indices,
85
+ constraint_joints_pos,
86
+ constraint_joints_rot,
87
+ smooth_root_2d=smooth_root_pos_2d,
88
+ )
89
+ )
90
+ elif track_name == "End-Effectors":
91
+ constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
92
+ constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
93
+ if skel_slice is not None:
94
+ constraint_joints_pos = constraint_joints_pos[:, skel_slice]
95
+ constraint_joints_rot = constraint_joints_rot[:, skel_slice]
96
+
97
+ end_effector_type_set_lst = [
98
+ end_effector_type_set
99
+ for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
100
+ if i in valid_idx
101
+ ]
102
+
103
+ # regroup the end effector data by type
104
+ cls_idx = defaultdict(list)
105
+ for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
106
+ for end_effector_type in end_effector_type_set:
107
+ cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)
108
+
109
+ for cls, lst_idx in cls_idx.items():
110
+ frame_indices_cls = frame_indices[lst_idx]
111
+ smooth_root_pos_2d = None
112
+ if dense_smooth_root_pos_2d is not None:
113
+ smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]
114
+
115
+ constraint_joints_pos_el = constraint_joints_pos[lst_idx]
116
+ constraint_joints_rot_el = constraint_joints_rot[lst_idx]
117
+
118
+ model_constraints.append(
119
+ cls(
120
+ model_skeleton,
121
+ frame_indices_cls,
122
+ constraint_joints_pos_el,
123
+ constraint_joints_rot_el,
124
+ smooth_root_2d=smooth_root_pos_2d,
125
+ )
126
+ )
127
+ else:
128
+ raise ValueError(f"Unsupported constraint type: {constraint.display_name}")
129
+ return model_constraints
130
+
131
+
132
+ def generate(
133
+ *,
134
+ client: viser.ClientHandle,
135
+ session: ClientSession,
136
+ model_bundle: ModelBundle,
137
+ prompts: list[str],
138
+ num_frames: list[int],
139
+ num_samples: int,
140
+ seed: int,
141
+ diffusion_steps: int,
142
+ cfg_weight: Optional[list[float]] = None,
143
+ cfg_type: Optional[str] = None,
144
+ postprocess_parameters: Optional[dict] = None,
145
+ transitions_parameters: Optional[dict] = None,
146
+ real_robot_rotations: bool = False,
147
+ device: str,
148
+ clear_motions,
149
+ add_character_motion,
150
+ ) -> None:
151
+ client_id = client.client_id
152
+ print(
153
+ f"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})"
154
+ )
155
+
156
+ seed_everything(seed)
157
+
158
+ model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)
159
+ cfg_weight = cfg_weight or [2.0, 2.0]
160
+ postprocess_parameters = postprocess_parameters or {}
161
+ transitions_parameters = transitions_parameters or {}
162
+
163
+ encoder = getattr(model_bundle.model, "text_encoder", None)
164
+ if isinstance(encoder, CachedTextEncoder):
165
+ with encoder.session_context(session):
166
+ pred_joints_output = model_bundle.model(
167
+ prompts,
168
+ num_frames,
169
+ diffusion_steps,
170
+ multi_prompt=True,
171
+ constraint_lst=model_constraints,
172
+ cfg_weight=cfg_weight,
173
+ num_samples=num_samples,
174
+ cfg_type=cfg_type,
175
+ **(postprocess_parameters | transitions_parameters),
176
+ ) # [B, T, motion_rep_dim]
177
+ else:
178
+ pred_joints_output = model_bundle.model(
179
+ prompts,
180
+ num_frames,
181
+ diffusion_steps,
182
+ multi_prompt=True,
183
+ constraint_lst=model_constraints,
184
+ cfg_weight=cfg_weight,
185
+ num_samples=num_samples,
186
+ cfg_type=cfg_type,
187
+ **(postprocess_parameters | transitions_parameters),
188
+ ) # [B, T, motion_rep_dim]
189
+
190
+ joints_pos = pred_joints_output["posed_joints"] # [B, T, J, 3]
191
+ joints_rot = pred_joints_output["global_rot_mats"]
192
+ foot_contacts = pred_joints_output.get("foot_contacts")
193
+
194
+ # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.
195
+ if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
196
+ joints_pos, joints_rot = apply_g1_real_robot_projection(
197
+ session.skeleton,
198
+ pred_joints_output["posed_joints"],
199
+ pred_joints_output["global_rot_mats"],
200
+ clamp_to_limits=True,
201
+ )
202
+
203
+ # Display on characters (callbacks keep this module UI-agnostic).
204
+ clear_motions(client_id)
205
+ # Keep one sample centered at the origin so constraints align.
206
+ spread_factor = 1.0 # meters
207
+ center_idx = num_samples // 2
208
+ x_trans = (np.arange(num_samples) - center_idx) * spread_factor
209
+ for i in range(num_samples):
210
+ cur_joints_pos = joints_pos[i]
211
+ cur_joints_pos[..., 0] += x_trans[i]
212
+ add_character_motion(
213
+ client,
214
+ session.skeleton,
215
+ cur_joints_pos,
216
+ joints_rot[i],
217
+ foot_contacts[i],
218
+ )
queue_manager.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """HF mode user queue and session time limit."""
4
+
5
+ import math
6
+ import threading
7
+ import time
8
+ from collections.abc import Callable
9
+ from typing import Any
10
+
11
+ import viser
12
+
13
+ from .config import DEMO_UI_QUICK_START_MODAL_MD, MAX_SESSION_MINUTES
14
+
15
+ # Link for "Duplicate this Space" on Hugging Face (used in queue and expiry modals).
16
+ DUPLICATE_SPACE_URL = "https://huggingface.co/spaces/nvidia/Kimodo?duplicate=true"
17
+ GITHUB_REPO_URL = "https://github.com/nv-tlabs/kimodo"
18
+
19
+ # How often to refresh queue modal content (position, total, estimated wait).
20
+ QUEUE_MODAL_REFRESH_INTERVAL_SEC = 15
21
+
22
+
23
+ class UserQueue:
24
+ """Thread-safe queue: active users (with activation timestamp) and waiting queue."""
25
+
26
+ def __init__(self, max_active: int, max_minutes: float) -> None:
27
+ self._max_active = max_active
28
+ self._max_minutes = max_minutes
29
+ self._max_seconds = max_minutes * 60.0
30
+ self._active: dict[int, float] = {} # client_id -> activation timestamp
31
+ self._queued: list[int] = []
32
+ self._lock = threading.Lock()
33
+
34
+ def try_activate(self, client_id: int) -> bool:
35
+ """If a slot is free, add client as active and return True.
36
+
37
+ Else return False.
38
+ """
39
+ with self._lock:
40
+ if len(self._active) < self._max_active:
41
+ self._active[client_id] = time.time()
42
+ return True
43
+ return False
44
+
45
+ def enqueue(self, client_id: int) -> None:
46
+ with self._lock:
47
+ if client_id not in self._queued:
48
+ self._queued.append(client_id)
49
+
50
+ def remove(self, client_id: int) -> bool:
51
+ """Remove from active or queue.
52
+
53
+ Returns True if was active.
54
+ """
55
+ with self._lock:
56
+ was_active = client_id in self._active
57
+ self._active.pop(client_id, None)
58
+ if client_id in self._queued:
59
+ self._queued.remove(client_id)
60
+ return was_active
61
+
62
+ def promote_next(self) -> int | None:
63
+ """If queue non-empty, pop first, activate them, return their client_id.
64
+
65
+ Else None.
66
+ """
67
+ with self._lock:
68
+ if not self._queued:
69
+ return None
70
+ client_id = self._queued.pop(0)
71
+ self._active[client_id] = time.time()
72
+ return client_id
73
+
74
+ def get_queue_position(self, client_id: int) -> tuple[int, int] | None:
75
+ """(1-based position, total_in_queue) or None if not queued."""
76
+ with self._lock:
77
+ if client_id not in self._queued:
78
+ return None
79
+ pos = self._queued.index(client_id)
80
+ return (pos + 1, len(self._queued))
81
+
82
+ def get_estimated_wait_seconds(self, client_id: int) -> float:
83
+ """Estimated seconds until this queued client gets a slot."""
84
+ with self._lock:
85
+ if client_id not in self._queued:
86
+ return 0.0
87
+ pos = self._queued.index(client_id) + 1 # 1-based
88
+ # Expiry times of active users (when they free a slot)
89
+ now = time.time()
90
+ expiries = sorted(now + self._max_seconds - (now - t) for t in self._active.values())
91
+ if not expiries:
92
+ return 0.0
93
+ # Nth slot to free (1-indexed) wraps over expiries
94
+ idx = (pos - 1) % len(expiries)
95
+ cycles = (pos - 1) // len(expiries)
96
+ slot_free_time = expiries[idx] + cycles * self._max_seconds
97
+ return max(0.0, slot_free_time - now)
98
+
99
+ def is_active(self, client_id: int) -> bool:
100
+ with self._lock:
101
+ return client_id in self._active
102
+
103
+ def was_active(self, client_id: int) -> bool:
104
+ """True if client is currently active (for use when already holding lock)."""
105
+ return client_id in self._active
106
+
107
+
108
def _format_wait(seconds: float) -> str:
    """Render a wait estimate as 'less than a minute' or '~N minute(s)'."""
    if seconds < 60:
        return "less than a minute"
    minutes = int(math.ceil(seconds / 60))
    suffix = "s" if minutes != 1 else ""
    return f"~{minutes} minute{suffix}"
113
+
114
+
115
def _queue_modal_markdown(position: int, total: int, estimated_wait_sec: float) -> str:
    """Markdown body for the waiting-room modal (queue position + ETA)."""
    wait_str = _format_wait(estimated_wait_sec)
    # Render "5" rather than "5.0" when the limit is a whole number of minutes.
    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
    return f"""## Kimodo Demo — Please Wait

This demo runs with limited capacity.
Each user gets **{mins} minute{"s" if mins != 1 else ""}** of interactive time.

**Your position in queue:** {position} / {total}

**Estimated wait:** {wait_str}

Please keep this tab open — the demo will start automatically when it's your turn.

---
*Want unlimited access? [Duplicate this Space]({DUPLICATE_SPACE_URL}) or clone the [GitHub repo]({GITHUB_REPO_URL}) to run locally!*
"""
132
+
133
+
134
def _welcome_modal_markdown() -> str:
    """Markdown body for the modal shown when a user is granted a session."""
    # Render "5" rather than "5.0" when the limit is a whole number of minutes.
    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
    return f"""## Welcome to Kimodo Demo

You have been granted a **{mins}-minute** demo session.
Your session timer has started.

Click the button below to begin!
"""
143
+
144
+
145
def _expiry_modal_markdown() -> str:
    """Markdown body for the modal shown when a user's session time runs out."""
    # Render "5" rather than "5.0" when the limit is a whole number of minutes.
    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES
    return f"""## Session Expired

Your {mins}-minute demo session has ended.
Thank you for trying Kimodo!

Refresh this page to rejoin the queue, or [duplicate this Space]({DUPLICATE_SPACE_URL}) for unlimited access.
"""
154
+
155
+
156
+ class QueueManager:
157
+ """Orchestrates HF mode: queue modals, welcome modal, session timer, promotion."""
158
+
159
    def __init__(
        self,
        queue: UserQueue,
        server: viser.ViserServer,
        setup_demo_for_client: Callable[[viser.ClientHandle], None],
        cleanup_session: Callable[[int], None],
    ) -> None:
        """Wire the queue to the viser server and start the modal-refresh thread.

        Args:
            queue: Shared UserQueue tracking active and waiting clients.
            server: Viser server the clients connect to.
            setup_demo_for_client: Callback that builds the demo UI for an
                activated client.
            cleanup_session: Callback that tears down per-client session state.
        """
        self._queue = queue
        self._server = server
        self._setup_demo_for_client = setup_demo_for_client
        self._cleanup_session = cleanup_session
        self._max_seconds = queue._max_seconds

        # Per-client UI handles and expiry timers, guarded by _lock.
        self._queue_modal_handles: dict[int, tuple[Any, Any]] = {}
        self._welcome_modal_handles: dict[int, Any] = {}
        self._expiry_timers: dict[int, threading.Timer] = {}
        self._lock = threading.Lock()
        # Daemon thread keeps queued users' position/wait estimates fresh.
        self._refresh_stop = threading.Event()
        self._refresh_thread = threading.Thread(
            target=self._queue_modal_refresh_loop,
            name="queue-modal-refresh",
            daemon=True,
        )
        self._refresh_thread.start()
183
+
184
    def _queue_modal_refresh_loop(self) -> None:
        """Periodically refresh queue modals so position, total, and estimated wait stay current."""
        # Event.wait doubles as the sleep; setting _refresh_stop ends the loop promptly.
        while not self._refresh_stop.wait(timeout=QUEUE_MODAL_REFRESH_INTERVAL_SEC):
            self._update_all_queue_modals()
188
+
189
    def on_client_connect(self, client: viser.ClientHandle) -> None:
        """Handle new connection: activate if slot free, else enqueue and show queue modal."""
        client_id = client.client_id
        if self._queue.try_activate(client_id):
            try:
                self._setup_demo_for_client(client)
            except RuntimeError as e:
                # Contain CUDA failures to this client so one bad setup does not
                # crash the server; note the client still holds its slot until
                # disconnect. Anything else is re-raised.
                if "CUDA error" in str(e):
                    print(f"CUDA error while setting up client {client_id}: {e}")
                    return
                raise
            self._start_session_timer(client_id)
            self._show_welcome_modal(client)
        else:
            self._queue.enqueue(client_id)
            self._show_queue_modal(client)
            # Queue length changed: refresh positions/estimates for everyone waiting.
            self._update_all_queue_modals()
206
+
207
+ def on_client_disconnect(self, client_id: int) -> None:
208
+ """Remove from queue/active, cancel timer, promote next if was active.
209
+
210
+ Session/scene cleanup is done by the demo's on_client_disconnect.
211
+ """
212
+ with self._lock:
213
+ self._expiry_timers.pop(client_id, None)
214
+ self._queue_modal_handles.pop(client_id, None)
215
+ self._welcome_modal_handles.pop(client_id, None)
216
+ was_active = self._queue.remove(client_id)
217
+ if was_active:
218
+ self._promote_next_user()
219
+ else:
220
+ self._update_all_queue_modals()
221
+
222
+ def _show_queue_modal(self, client: viser.ClientHandle) -> None:
223
+ client_id = client.client_id
224
+ pos, total = self._queue.get_queue_position(client_id) or (0, 0)
225
+ wait_sec = self._queue.get_estimated_wait_seconds(client_id)
226
+ md_content = _queue_modal_markdown(pos, total, wait_sec)
227
+
228
+ modal = client.gui.add_modal(
229
+ "Kimodo Demo — Please Wait",
230
+ size="xl",
231
+ show_close_button=False,
232
+ )
233
+ with modal:
234
+ md_handle = client.gui.add_markdown(md_content)
235
+ with self._lock:
236
+ self._queue_modal_handles[client_id] = (modal, md_handle)
237
+
238
+ def _show_quick_start_modal(self, client: viser.ClientHandle) -> None:
239
+ """Show the quick start instructions modal (same as non-HF mode)."""
240
+ with client.gui.add_modal(
241
+ "Welcome — Quick Start",
242
+ size="xl",
243
+ show_close_button=True,
244
+ save_choice="kimodo.demo.quick_start_ack",
245
+ ) as quick_start_modal:
246
+ client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
247
+ client.gui.add_button("Got it (don't remind me again)").on_click(lambda _: quick_start_modal.close())
248
+
249
+ def _show_welcome_modal(self, client: viser.ClientHandle) -> None:
250
+ client_id = client.client_id
251
+
252
+ def _on_start_demo(_: Any) -> None:
253
+ modal.close()
254
+ self._show_quick_start_modal(client)
255
+
256
+ modal = client.gui.add_modal(
257
+ "Welcome to Kimodo Demo",
258
+ size="xl",
259
+ show_close_button=True,
260
+ )
261
+ with modal:
262
+ client.gui.add_markdown(_welcome_modal_markdown())
263
+ client.gui.add_button("Start Demo").on_click(_on_start_demo)
264
+ with self._lock:
265
+ self._welcome_modal_handles[client_id] = modal
266
+
267
+ def _update_all_queue_modals(self) -> None:
268
+ with self._lock:
269
+ handles = list(self._queue_modal_handles.items())
270
+ for client_id, (modal, md_handle) in handles:
271
+ pos_total = self._queue.get_queue_position(client_id)
272
+ if pos_total is None:
273
+ continue
274
+ pos, total = pos_total
275
+ wait_sec = self._queue.get_estimated_wait_seconds(client_id)
276
+ try:
277
+ md_handle.content = _queue_modal_markdown(pos, total, wait_sec)
278
+ except Exception:
279
+ pass
280
+
281
+ def _promote_next_user(self) -> None:
282
+ promoted_id = self._queue.promote_next()
283
+ if promoted_id is None:
284
+ return
285
+ clients = self._server.get_clients()
286
+ client = clients.get(promoted_id)
287
+ if client is None:
288
+ return
289
+ with self._lock:
290
+ old = self._queue_modal_handles.pop(promoted_id, None)
291
+ if old is not None:
292
+ try:
293
+ old[0].close()
294
+ except Exception:
295
+ pass
296
+ try:
297
+ self._setup_demo_for_client(client)
298
+ except RuntimeError as e:
299
+ if "CUDA error" in str(e):
300
+ print(f"CUDA error while setting up client {promoted_id}: {e}")
301
+ return
302
+ raise
303
+ self._start_session_timer(promoted_id)
304
+ self._show_welcome_modal(client)
305
+ self._update_all_queue_modals()
306
+
307
+ def _start_session_timer(self, client_id: int) -> None:
308
+ def on_expiry() -> None:
309
+ self._on_session_expired(client_id)
310
+
311
+ t = threading.Timer(self._max_seconds, on_expiry)
312
+ t.daemon = True
313
+ with self._lock:
314
+ self._expiry_timers[client_id] = t
315
+ t.start()
316
+
317
+ def _on_session_expired(self, client_id: int) -> None:
318
+ with self._lock:
319
+ self._expiry_timers.pop(client_id, None)
320
+ if not self._queue.is_active(client_id):
321
+ return
322
+ self._queue.remove(client_id)
323
+ clients = self._server.get_clients()
324
+ client = clients.get(client_id)
325
+ if client is not None:
326
+ try:
327
+ with client.gui.add_modal(
328
+ "Session Expired",
329
+ size="lg",
330
+ show_close_button=False,
331
+ ) as modal_ctx:
332
+ client.gui.add_markdown(_expiry_modal_markdown())
333
+ except Exception:
334
+ pass
335
+ self._cleanup_session(client_id)
336
+ self._promote_next_user()
state.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ import kimodo.viz.viser_utils as viser_utils
10
+ import viser
11
+ from kimodo.skeleton import SkeletonBase
12
+ from kimodo.viz.viser_utils import GuiElements
13
+
14
+ from .config import (
15
+ DEFAULT_CUR_DURATION,
16
+ DEFAULT_MODEL,
17
+ DEFAULT_PLAYBACK_SPEED,
18
+ )
19
+
20
+
21
@dataclass(frozen=True)
class ModelBundle:
    """Immutable bundle pairing a loaded model with its motion representation,
    skeleton, and native frame rate."""

    model: object  # loaded generation model (project type; kept opaque here)
    motion_rep: object  # motion representation object (project type) — presumably encodes/decodes poses; confirm
    skeleton: SkeletonBase
    model_fps: float  # frame rate associated with the model's output
27
+
28
+
29
@dataclass
class ClientSession:
    """Per-client session data."""

    client: viser.ClientHandle
    gui_elements: GuiElements
    motions: dict  # character_name -> CharacterMotion
    constraints: dict[str, viser_utils.ConstraintSet] = field(default_factory=dict)  # character_name -> constraint set
    timeline_data: object = None  # opaque timeline payload (project type) — TODO confirm shape
    frame_idx: int = 0  # current playback frame
    playing: bool = False  # whether playback is running
    playback_speed: float = DEFAULT_PLAYBACK_SPEED
    playback_time_accumulator: float = 0.0  # presumably carries fractional frame time between ticks — confirm
    last_space_toggle_time: float = 0.0  # presumably debounces spacebar play/pause — confirm
    cur_duration: float = DEFAULT_CUR_DURATION  # current clip duration (seconds)
    max_frame_idx: int = 100  # will be updated based on model_fps
    updating_motions: bool = False  # guard flag while motions are being regenerated
    edit_mode: bool = False  # whether constraint-editing mode is on
    model_name: str = DEFAULT_MODEL
    model_fps: float = 0.0  # set once a model is loaded
    skeleton: SkeletonBase | None = None
    motion_rep: object | None = None
    examples_base_dir: str = ""  # directory that example files are loaded from
    example_dict: dict[str, str] = field(default_factory=dict)  # example display name -> path
    gui_examples_dropdown: Optional[viser.GuiInputHandle] = None
    gui_save_example_path_text: Optional[viser.GuiInputHandle] = None
    gui_model_selector: Optional[viser.GuiInputHandle] = None
    last_prompt_texts: Optional[list[str]] = None  # cached alongside embeddings/lengths below
    last_prompt_embeddings: Optional[torch.Tensor] = None
    last_prompt_lengths: Optional[list[int]] = None
    edit_mode_snapshot: Optional[dict[int, dict[str, object]]] = None  # state captured on entering edit mode — presumably for revert; confirm
    undo_drag_snapshot: Optional[dict[str, object]] = None  # state captured before a drag — presumably for undo; confirm
    show_only_current_constraint: bool = False  # False = Show All, True = Show only Current
ui.py ADDED
The diff for this file is too large to render. See raw diff