Fix CUDA import order - import spaces before torch (commit e28bffd)
Browse files- .gitattributes +1 -0
- __init__.py +29 -0
- __main__.py +8 -0
- __pycache__/app.cpython-312.pyc +0 -0
- __pycache__/state.cpython-312.pyc +0 -0
- __pycache__/ui.cpython-312.pyc +3 -0
- _qwen_prompts.py +107 -0
- app.py +742 -0
- config.py +172 -0
- embedding_cache.py +253 -0
- generation.py +218 -0
- queue_manager.py +336 -0
- state.py +61 -0
- ui.py +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
__pycache__/ui.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# ruff: noqa: I001
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from kimodo.model import DEFAULT_MODEL
|
| 8 |
+
from kimodo.model.registry import resolve_model_name
|
| 9 |
+
|
| 10 |
+
from .app import Demo
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def main() -> None:
    """Parse CLI arguments and launch the kimodo demo UI."""
    arg_parser = argparse.ArgumentParser(description="Run the kimodo demo UI.")
    arg_parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL,
        help="Default model to load (e.g. Kimodo-SOMA-RP-v1, kimodo-soma-rp, or SOMA).",
    )
    cli_args = arg_parser.parse_args()

    # Normalize aliases/short names to the canonical registry name before
    # constructing the demo server.
    canonical = resolve_model_name(cli_args.model, "Kimodo")
    Demo(default_model_name=canonical).run()


if __name__ == "__main__":
    main()
|
__main__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
"""Entry point for `python -m kimodo.demo`."""
|
| 4 |
+
|
| 5 |
+
from kimodo.demo import main
|
| 6 |
+
|
| 7 |
+
if __name__ == "__main__":
|
| 8 |
+
main()
|
__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (34.2 kB). View file
|
|
|
__pycache__/state.cpython-312.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
__pycache__/ui.cpython-312.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:023646c6b51d238f0aedf942347c389dfa7384d452f7b1b99aa34bcef924706e
|
| 3 |
+
size 145667
|
_qwen_prompts.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
"""Qwen-on-Fireworks helper for auto-generating multi-text-prompt batches."""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import urllib.error
|
| 10 |
+
import urllib.request
|
| 11 |
+
|
| 12 |
+
_MODEL = "accounts/fireworks/models/qwen3p6-27b"
|
| 13 |
+
_BASE = "https://api.fireworks.ai/inference/v1"
|
| 14 |
+
|
| 15 |
+
_SYSTEM = """\
|
| 16 |
+
You are a motion-description writer for a single humanoid character in a 3D animation system.
|
| 17 |
+
Given a scene context and the character's recent motion history, output ONLY a JSON object:
|
| 18 |
+
|
| 19 |
+
{"texts": ["<action phrase 1>", ...], "durations": [<seconds float>, ...]}
|
| 20 |
+
|
| 21 |
+
Rules:
|
| 22 |
+
- Return between 1 and requested_actions short, vivid action phrases that flow naturally from each other.
|
| 23 |
+
- Each phrase describes one distinct physical motion (e.g. "walks forward briskly", "pivots left and crouches").
|
| 24 |
+
- Each duration is between 2.0 and 8.0 seconds.
|
| 25 |
+
- texts and durations must have the same length.
|
| 26 |
+
- Do NOT repeat phrases from history.
|
| 27 |
+
- Return raw JSON only — no markdown, no explanation.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _call_fireworks(messages: list[dict]) -> str:
    """POST a chat-completions request to Fireworks and return the reply text.

    Raises RuntimeError when FIREWORKS_API_KEY is unset or the API
    responds with an HTTP error.
    """
    api_key = os.environ.get("FIREWORKS_API_KEY", "").strip()
    if not api_key:
        raise RuntimeError("FIREWORKS_API_KEY is not set")

    request_payload = {
        "model": _MODEL,
        "messages": messages,
        "max_tokens": 400,
        "temperature": 0.85,
    }
    http_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    request = urllib.request.Request(
        f"{_BASE}/chat/completions",
        data=json.dumps(request_payload).encode(),
        headers=http_headers,
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=40) as response:
            payload = json.loads(response.read())
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"Fireworks {e.code}: {e.read().decode(errors='ignore')}") from e
    return payload["choices"][0]["message"]["content"]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _parse(raw: str) -> dict:
|
| 55 |
+
text = raw.strip()
|
| 56 |
+
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
| 57 |
+
text = m.group(1) if m else text
|
| 58 |
+
s, e = text.find("{"), text.rfind("}")
|
| 59 |
+
return json.loads(text[s:e + 1])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _fallback(offset: int) -> dict:
|
| 63 |
+
phrases = [
|
| 64 |
+
"walks forward at a steady pace",
|
| 65 |
+
"turns smoothly to the left",
|
| 66 |
+
"pauses and surveys the surroundings",
|
| 67 |
+
"steps forward and gestures expressively",
|
| 68 |
+
"crouches down then rises back up",
|
| 69 |
+
"sidesteps to the right with purpose",
|
| 70 |
+
]
|
| 71 |
+
n = len(phrases)
|
| 72 |
+
return {
|
| 73 |
+
"texts": [phrases[(offset + i) % n] for i in range(3)],
|
| 74 |
+
"durations": [3.0, 3.5, 3.0],
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def call_qwen_for_prompts(
    scene: str,
    history: list[str],
    requested_actions: int = 5,
) -> tuple[dict, list[str]]:
    """Call Qwen to produce the next batch of motion prompts.

    Returns (batch_dict, updated_history), where batch_dict has keys
    "texts" and "durations" of equal length.

    Never raises on API failure: any error (network, bad JSON, malformed
    or empty payload) silently falls back to a deterministic canned
    batch from _fallback().  (The old docstring claimed a RuntimeError
    was raised; it never was — the except below swallows everything.)
    """
    user_msg = (
        f"Scene: {scene or 'a character moving continuously in 3D space'}\n"
        f"Motion history (do not repeat): {json.dumps(history[-12:])}\n\n"
        f"requested_actions: {max(1, min(10, int(requested_actions)))}\n"
        "Generate the next batch of motion prompts."
    )
    try:
        raw = _call_fireworks([{"role": "system", "content": _SYSTEM}, {"role": "user", "content": user_msg}])
        batch = _parse(raw)
        if not isinstance(batch.get("texts"), list) or not isinstance(batch.get("durations"), list):
            raise ValueError("Missing texts or durations")
        # Truncate to the shorter list so texts/durations stay aligned.
        n = min(len(batch["texts"]), len(batch["durations"]))
        if n == 0:
            # An empty batch would stall the prompt pipeline downstream;
            # treat it as a failure and use the fallback instead.
            raise ValueError("Empty texts/durations")
        batch["texts"] = batch["texts"][:n]
        batch["durations"] = batch["durations"][:n]
    except Exception:
        batch = _fallback(len(history))

    new_history = history + list(batch["texts"])
    return batch, new_history
|
app.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# ============================================================================
|
| 5 |
+
# CRITICAL: Import spaces FIRST, before any CUDA-related packages.
|
| 6 |
+
# This prevents "CUDA has been initialized before importing spaces" error.
|
| 7 |
+
# ============================================================================
|
| 8 |
+
try:
|
| 9 |
+
import spaces # noqa: F401 - imported early for ZeroGPU compatibility
|
| 10 |
+
except ImportError:
|
| 11 |
+
pass # Not running on HF Spaces with ZeroGPU
|
| 12 |
+
|
| 13 |
+
import base64
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import threading
|
| 18 |
+
import time
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
import viser
|
| 25 |
+
from kimodo.assets import DEMO_ASSETS_ROOT
|
| 26 |
+
from kimodo.model.load_model import load_model
|
| 27 |
+
from kimodo.model.registry import resolve_model_name
|
| 28 |
+
from kimodo.runtime.device import select_runtime_device
|
| 29 |
+
from kimodo.skeleton import SkeletonBase, SOMASkeleton30
|
| 30 |
+
from kimodo.tools import load_json
|
| 31 |
+
from kimodo.viz import viser_utils
|
| 32 |
+
from kimodo.viz.viser_utils import (
|
| 33 |
+
Character,
|
| 34 |
+
CharacterMotion,
|
| 35 |
+
EEJointsKeyframeSet,
|
| 36 |
+
FullbodyKeyframeSet,
|
| 37 |
+
RootKeyframe2DSet,
|
| 38 |
+
)
|
| 39 |
+
from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage
|
| 40 |
+
|
| 41 |
+
from . import generation, ui
|
| 42 |
+
from .config import (
|
| 43 |
+
DARK_THEME,
|
| 44 |
+
DEFAULT_CUR_DURATION,
|
| 45 |
+
DEFAULT_MODEL,
|
| 46 |
+
DEFAULT_PLAYBACK_SPEED,
|
| 47 |
+
DEFAULT_PROMPT,
|
| 48 |
+
DEMO_UI_QUICK_START_MODAL_MD,
|
| 49 |
+
EXAMPLES_ROOT_DIR,
|
| 50 |
+
HF_MODE,
|
| 51 |
+
LIGHT_THEME,
|
| 52 |
+
MAX_ACTIVE_USERS,
|
| 53 |
+
MAX_DURATION,
|
| 54 |
+
MAX_SESSION_MINUTES,
|
| 55 |
+
MIN_DURATION,
|
| 56 |
+
MODEL_EXAMPLES_DIRS,
|
| 57 |
+
MODEL_NAMES,
|
| 58 |
+
SERVER_NAME,
|
| 59 |
+
SERVER_PORT,
|
| 60 |
+
)
|
| 61 |
+
from .embedding_cache import CachedTextEncoder
|
| 62 |
+
from .queue_manager import QueueManager, UserQueue
|
| 63 |
+
from .state import ClientSession, ModelBundle
|
| 64 |
+
|
| 65 |
+
# Hosted runtimes (HF/Cloud Run) often send non-WS probes to the WS endpoint.
|
| 66 |
+
# Suppress noisy stack traces for these expected invalid handshakes.
|
| 67 |
+
logging.getLogger("websockets.server").setLevel(logging.CRITICAL)
|
| 68 |
+
logging.getLogger("websockets.asyncio.server").setLevel(logging.CRITICAL)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Demo:
|
| 72 |
+
def __init__(self, default_model_name: str = DEFAULT_MODEL):
    """Build the viser demo server.

    Resolves and validates the default model name, picks a runtime
    device, optionally defers model loading, and starts the ViserServer
    with per-client session bookkeeping (plus a user queue in HF mode).

    Raises ValueError if the resolved model name is not in MODEL_NAMES.
    """
    # In hosted HF runtimes (including ZeroGPU), touching CUDA too early can
    # crash startup before queue-managed inference starts.
    requested_device = os.getenv("KIMODO_DEVICE")
    running_in_space = bool(os.getenv("SPACE_ID")) or os.getenv("SYSTEM", "").strip().lower() == "spaces"
    if requested_device is None and (HF_MODE or running_in_space):
        # Force CPU at startup in hosted Spaces unless explicitly overridden.
        requested_device = "cpu"
    self.device = select_runtime_device(requested=requested_device)
    print(f"Using device: {self.device}")
    # Cache of loaded model bundles, keyed by canonical model name.
    self.models: dict[str, ModelBundle] = {}
    resolved = resolve_model_name(default_model_name, "Kimodo")
    if resolved not in MODEL_NAMES:
        raise ValueError(f"Unknown model '{default_model_name}'. Expected one of: {MODEL_NAMES}")
    self.default_model_name = resolved
    # Defaults to deferred loading: the model is only loaded when the
    # first client session actually starts (see _setup_demo_for_client).
    self.defer_model_load = os.getenv("KIMODO_DEFER_MODEL_LOAD", "true").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    self.ensure_examples_layout()
    if self.defer_model_load:
        print("Deferring model load until first active client session.")
    else:
        self.load_model(self.default_model_name)

    # Serialize GPU-bound generation across all clients
    self._generation_lock = threading.Lock()
    self._cuda_healthy = True

    # Per-client sessions (all keyed by viser client_id)
    self.client_sessions: dict[int, ClientSession] = {}
    self.start_direction_markers: dict[int, viser_utils.WaypointMesh] = {}
    self.grid_handles: dict[int, viser.GridHandle] = {}

    self.server = viser.ViserServer(
        host=SERVER_NAME,
        port=SERVER_PORT,
        label="Kimodo",
        enable_camera_keyboard_controls=False,  # don't move the camera with the arrow keys
    )
    self.server.scene.world_axes.visible = False  # used for debugging
    self.server.scene.set_up_direction("+y")

    # Register callbacks for session handling
    self.server.on_client_connect(self.on_client_connect)
    self.server.on_client_disconnect(self.on_client_disconnect)

    # HF mode: queue and session limit
    if HF_MODE:
        self.user_queue = UserQueue(MAX_ACTIVE_USERS, MAX_SESSION_MINUTES)
        self.queue_manager = QueueManager(
            queue=self.user_queue,
            server=self.server,
            setup_demo_for_client=self._setup_demo_for_client,
            cleanup_session=self._cleanup_session_for_client,
        )
    else:
        self.user_queue = None
        self.queue_manager = None

    # create grid and floor
    self.floor_len = 20.0  # meters
|
| 135 |
+
|
| 136 |
+
def ensure_examples_layout(self) -> None:
    """Create the examples directory tree and adopt legacy example folders.

    Any directory found directly under the examples root (instead of
    inside a per-model directory) is moved into the default model's
    examples directory, unless a folder of the same name already exists
    there.
    """
    os.makedirs(EXAMPLES_ROOT_DIR, exist_ok=True)
    for per_model_dir in MODEL_EXAMPLES_DIRS.values():
        os.makedirs(per_model_dir, exist_ok=True)

    # Destination for legacy entries: the default model's dir, or the
    # first configured dir when the default has none.
    adopt_dir = MODEL_EXAMPLES_DIRS.get(DEFAULT_MODEL, next(iter(MODEL_EXAMPLES_DIRS.values())))
    for name in os.listdir(EXAMPLES_ROOT_DIR):
        if name in MODEL_EXAMPLES_DIRS:
            continue  # already a per-model directory
        candidate = os.path.join(EXAMPLES_ROOT_DIR, name)
        if not os.path.isdir(candidate):
            continue  # stray files are left alone
        target = os.path.join(adopt_dir, name)
        if not os.path.exists(target):
            shutil.move(candidate, target)
|
| 153 |
+
|
| 154 |
+
def get_examples_base_dir(self, model_name: str, absolute: bool = True) -> str:
    """Return the examples directory configured for *model_name*.

    Note: *absolute* is currently unused — the configured path is
    returned as-is regardless of its value. Raises KeyError for an
    unknown model name.
    """
    return MODEL_EXAMPLES_DIRS[model_name]
|
| 156 |
+
|
| 157 |
+
def load_model(self, model_name: str) -> ModelBundle:
    """Load *model_name* (or return its cached bundle).

    Wraps the model's text encoder in CachedTextEncoder, resolves the
    skeleton, caches the resulting ModelBundle in self.models, and
    prewarms the embedding cache. Re-raises any load failure after
    printing a diagnostic hint.
    """
    # Loading is expensive; reuse a bundle loaded earlier in this process.
    if model_name in self.models:
        return self.models[model_name]

    print(f"Loading model {model_name}...")
    try:
        model = load_model(modelname=model_name, device=self.device)
    except Exception as e:
        print(
            "Error loading model during Kimodo startup. "
            "This often means the text encoder server is not running, the Hugging Face token is missing, "
            "or the gated text encoder model cannot be accessed."
        )
        print(f"Original error: {type(e).__name__}: {e}")
        raise e

    # Wrap the encoder so repeated prompts hit the embedding cache.
    if hasattr(model, "text_encoder"):
        model.text_encoder = CachedTextEncoder(model.text_encoder, model_name=model_name)

    skeleton = model.motion_rep.skeleton
    if isinstance(skeleton, SOMASkeleton30):
        # NOTE(review): swaps the 30-joint SOMA skeleton for its 77-joint
        # variant on the model device — presumably needed for skinning;
        # confirm against the skeleton module.
        skeleton = skeleton.somaskel77.to(model.device)
    bundle = ModelBundle(
        model=model,
        motion_rep=model.motion_rep,
        skeleton=skeleton,
        model_fps=model.motion_rep.fps,
    )
    self.models[model_name] = bundle
    print(f"Model {model_name} loaded successfully")
    self.prewarm_embedding_cache(model_name, bundle.model)
    return bundle
|
| 189 |
+
|
| 190 |
+
def prewarm_embedding_cache(self, model_name: str, model: object) -> None:
    """Precompute text embeddings for the default prompt and example prompts.

    Best-effort: scans the model's examples directory for meta.json
    files and prewarms the cached encoder with every prompt found.
    Failures are logged and swallowed so startup never aborts just
    because the text encoder is still warming up.
    """
    encoder = getattr(model, "text_encoder", None)
    if not isinstance(encoder, CachedTextEncoder):
        # Nothing to prewarm unless the encoder was wrapped in load_model().
        return

    prompt_set = set()
    prompt_set.add(DEFAULT_PROMPT)

    # Collect every prompt referenced by the bundled example scenes.
    examples_dir = MODEL_EXAMPLES_DIRS.get(model_name)
    if examples_dir and os.path.isdir(examples_dir):
        for entry in os.listdir(examples_dir):
            example_dir = os.path.join(examples_dir, entry)
            if not os.path.isdir(example_dir):
                continue
            meta_path = os.path.join(example_dir, "meta.json")
            if not os.path.exists(meta_path):
                continue
            try:
                meta = load_json(meta_path)
            except Exception:
                # Malformed example metadata is not fatal; skip it.
                continue
            for prompt in meta.get("prompts_text", []):
                if isinstance(prompt, str):
                    prompt_set.add(prompt)

    if prompt_set:
        try:
            encoder.prewarm(list(prompt_set))
        except Exception as error:
            # Startup should not fail if text encoder is still warming up.
            error_str = str(error)
            if "Encoder initialization failed" in error_str:
                print(
                    f"⚠️ WARNING: Text encoder failed to initialize: {error}\n"
                    f"   This usually means the HuggingFace gated model cannot be accessed.\n"
                    f"   To fix: Set HF_TOKEN environment variable with access to Meta-Llama-3-8B.\n"
                    f"   Alternatively: Generation will still work but text embeddings may fail."
                )
            else:
                print(f"Warning: embedding prewarm skipped: {error}")
|
| 230 |
+
|
| 231 |
+
def build_constraint_tracks(
    self, client: viser.ClientHandle, skeleton: SkeletonBase
) -> dict[str, viser_utils.ConstraintSet]:
    """Create the three empty constraint sets for a client.

    Keys ("Full-Body", "End-Effectors", "2D Root") double as the display
    names used in the timeline UI.
    """
    return {
        "Full-Body": FullbodyKeyframeSet(
            name="Full-Body",
            server=client,
            skeleton=skeleton,
        ),
        "End-Effectors": EEJointsKeyframeSet(
            name="End-Effectors",
            server=client,
            skeleton=skeleton,
        ),
        "2D Root": RootKeyframe2DSet(
            name="2D Root",
            server=client,
            skeleton=skeleton,
        ),
    }
|
| 251 |
+
|
| 252 |
+
def set_timeline_defaults(self, timeline, model_fps: float) -> None:
    """Configure timeline widget defaults derived from the model's FPS.

    Durations are expressed in frames; the `- 1` accounts for frame 0
    being included in the span.
    """
    timeline.set_defaults(
        default_text=DEFAULT_PROMPT,
        default_duration=int(DEFAULT_CUR_DURATION * model_fps - 1),
        min_duration=int(MIN_DURATION * model_fps - 1),  # 2 seconds minimum,
        max_duration=int(
            MAX_DURATION * model_fps - 1  # - NB_TRANSITION_FRAMES
        ),  # 10 seconds maximum, minus the transition frames, if needed
        default_num_frames_zoom=int(1.10 * 10 * model_fps),  # a bit more than the max
        max_frames_zoom=1000,
        fps=model_fps,
    )
|
| 264 |
+
|
| 265 |
+
def _apply_constraint_overlay_visibility(self, session: ClientSession) -> None:
    """Sync constraint overlays with the session's visibility mode.

    When show_only_current_constraint is set, overlays are limited to
    the session's current frame; otherwise all overlays are shown
    (frame filter of None).
    """
    if session.show_only_current_constraint:
        frame_filter = session.frame_idx
    else:
        frame_filter = None
    for track in session.constraints.values():
        track.set_overlay_visibility(frame_filter)
|
| 270 |
+
|
| 271 |
+
def set_constraint_tracks_visible(self, session: ClientSession, visible: bool) -> None:
    """Show or hide all constraint tracks in the timeline for *session*.

    Hiding removes the tracks from the timeline widget; showing re-adds
    tracks, keyframes, and intervals from the cached timeline_data, so
    the constraint data itself is never lost. No-op when the timeline is
    already in the requested state.
    """
    timeline = session.client.timeline
    timeline_data = session.timeline_data
    if timeline_data.get("constraint_tracks_visible", True) == visible:
        return

    # The lock guards timeline_data against concurrent keyframe updates.
    with timeline_data["keyframe_update_lock"]:
        if visible:
            # Re-add tracks first: keyframes and intervals reference
            # them by track_id.
            for track_id, track_info in timeline_data["tracks"].items():
                timeline.add_track(
                    track_info["name"],
                    track_type=track_info.get("track_type", "keyframe"),
                    color=track_info.get("color"),
                    height_scale=track_info.get("height_scale", 1.0),
                    uuid=track_id,
                )

            for keyframe_id, keyframe_data in timeline_data["keyframes"].items():
                timeline.add_keyframe(
                    track_id=keyframe_data["track_id"],
                    frame=keyframe_data["frame"],
                    value=keyframe_data.get("value"),
                    opacity=keyframe_data.get("opacity", 1.0),
                    locked=keyframe_data.get("locked", False),
                    uuid=keyframe_id,
                )

            for interval_id, interval_data in timeline_data["intervals"].items():
                timeline.add_interval(
                    track_id=interval_data["track_id"],
                    start_frame=interval_data["start_frame_idx"],
                    end_frame=interval_data["end_frame_idx"],
                    value=interval_data.get("value"),
                    opacity=interval_data.get("opacity", 1.0),
                    locked=interval_data.get("locked", False),
                    uuid=interval_id,
                )
        else:
            # Removing a track is assumed to drop its keyframes/intervals
            # from the widget; our caches above retain them for re-adding.
            for track_id in list(timeline_data["tracks"].keys()):
                timeline.remove_track(track_id)

    timeline_data["constraint_tracks_visible"] = visible
|
| 313 |
+
|
| 314 |
+
def _cleanup_session_for_client(self, client_id: int) -> None:
|
| 315 |
+
"""Remove session and scene state for a client (e.g. on session expiry)."""
|
| 316 |
+
if client_id in self.client_sessions:
|
| 317 |
+
del self.client_sessions[client_id]
|
| 318 |
+
self.start_direction_markers.pop(client_id, None)
|
| 319 |
+
self.grid_handles.pop(client_id, None)
|
| 320 |
+
|
| 321 |
+
def _setup_demo_for_client(self, client: viser.ClientHandle) -> None:
    """Initialize scene, GUI, and session state for a client (no modals).

    Builds the 3D scene, lazily loads the default model (first client
    pays the cost when loading was deferred), creates the GUI and
    timeline bookkeeping, registers a ClientSession, and spawns the
    default character.
    """
    self.setup_scene(client)

    # No-op if already loaded; otherwise this is where deferred loading
    # actually happens.
    model_bundle = self.load_model(self.default_model_name)

    # Initialize each empty constraint track
    constraint_tracks = self.build_constraint_tracks(client, model_bundle.skeleton)

    # Create GUI elements for this client
    (
        gui_elements,
        timeline_tracks,
        example_dict,
        gui_examples_dropdown,
        gui_save_example_path_text,
        gui_model_selector,
    ) = ui.create_gui(
        demo=self,
        client=client,
        model_name=self.default_model_name,
        model_fps=model_bundle.model_fps,
    )
    # Per-client timeline bookkeeping shared by the timeline callbacks.
    timeline_data = {
        "tracks": timeline_tracks,
        "tracks_ids": {val["name"]: key for key, val in timeline_tracks.items()},
        "keyframes": {},
        "intervals": {},
        "keyframe_update_lock": threading.Lock(),
        "keyframe_move_timers": {},
        "pending_keyframe_moves": {},  # keyframe_id -> new_frame
        "constraint_tracks_visible": True,
        "dense_path_after_release_timer": None,
    }

    # Initialize session state
    cur_duration = DEFAULT_CUR_DURATION
    # Frame count is duration * fps; -1 because frame 0 counts.
    max_frame_idx = int(cur_duration * model_bundle.model_fps - 1)

    session = ClientSession(
        client=client,
        gui_elements=gui_elements,
        motions={},
        constraints=constraint_tracks,
        timeline_data=timeline_data,
        frame_idx=0,
        playing=False,
        playback_speed=DEFAULT_PLAYBACK_SPEED,
        cur_duration=cur_duration,
        max_frame_idx=max_frame_idx,
        updating_motions=False,
        edit_mode=False,
        model_name=self.default_model_name,
        model_fps=model_bundle.model_fps,
        skeleton=model_bundle.skeleton,
        motion_rep=model_bundle.motion_rep,
        examples_base_dir=self.get_examples_base_dir(self.default_model_name, absolute=True),
        example_dict=example_dict,
        gui_examples_dropdown=gui_examples_dropdown,
        gui_save_example_path_text=gui_save_example_path_text,
        gui_model_selector=gui_model_selector,
    )

    self.client_sessions[client.client_id] = session

    # Initialize default character for this client
    self.add_character_motion(client, session.skeleton)
|
| 388 |
+
|
| 389 |
+
def on_client_connect(self, client: viser.ClientHandle) -> None:
    """Initialize GUI and state for each new client.

    In HF mode the queue manager decides when (and whether) the client
    gets a live session; otherwise a quick-start modal is shown and the
    session is set up immediately.
    """
    print(f"Client {client.client_id} connected")

    if HF_MODE and self.queue_manager is not None:
        self.queue_manager.on_client_connect(client)
    else:
        # Show quick start popup when a browser client connects (non-HF mode).
        # save_choice persists the acknowledgement client-side so returning
        # users aren't shown the modal again.
        with client.gui.add_modal(
            "Welcome — Quick Start",
            size="xl",
            show_close_button=True,
            save_choice="kimodo.demo.quick_start_ack",
        ) as modal:
            client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
            client.gui.add_button("Got it (don't remind me again)").on_click(lambda _event: modal.close())
        self._setup_demo_for_client(client)
|
| 406 |
+
|
| 407 |
+
def setup_scene(self, client: viser.ClientHandle) -> None:
    """Set up the per-client 3D scene: theme, camera, floor grid, origin marker."""
    self.configure_theme(client)
    # NOTE(review): hard-coded camera pose — presumably a hand-tuned
    # default viewpoint saved from a live session; confirm if it should
    # be a named constant.
    client.camera.position = np.array(
        [2.7417358737841426, 1.8790455698853281, 7.675741569777456],
        dtype=np.float64,
    )
    client.camera.look_at = np.array([0.0, 0.0, 0.0], dtype=np.float64)
    client.camera.up_direction = np.array(
        [-1.1102230246251568e-16, 1.0, 1.3596310734468913e-32],
        dtype=np.float64,
    )
    client.camera.fov = np.deg2rad(45.0)
    # Grid lies in the XZ plane (rotated from its default XY orientation)
    # and sits just above y=0 to avoid z-fighting with the floor.
    grid_handle = client.scene.add_grid(
        "/grid",
        width=self.floor_len,
        height=self.floor_len,
        wxyz=viser.transforms.SO3.from_x_radians(-np.pi / 2.0).wxyz,
        position=(0.0, 0.0001, 0.0),
        fade_distance=3 * self.floor_len,
        section_color=LIGHT_THEME["grid"],
        infinite_grid=True,
    )
    self.grid_handles[client.client_id] = grid_handle
    # marker for origin
    origin_waypoint = viser_utils.WaypointMesh(
        "/origin_waypoint",
        client,
        position=np.array([0.0, 0.0, 0.0]),
        heading=np.array([0.0, 1.0]),
        color=(0, 0, 255),
    )
    self.start_direction_markers[client.client_id] = origin_waypoint
|
| 439 |
+
|
| 440 |
+
def on_client_disconnect(self, client: viser.ClientHandle) -> None:
    """Tear down per-client state when a browser client disconnects."""
    cid = client.client_id
    print(f"Client {cid} disconnected")
    if HF_MODE and self.queue_manager is not None:
        # Let the queue manager release the slot / promote waiting users.
        self.queue_manager.on_client_disconnect(cid)
    self._cleanup_session_for_client(cid)
|
| 449 |
+
|
| 450 |
+
def set_start_direction_visible(self, client_id: int, visible: bool) -> None:
    """Show or hide the origin/start-direction marker for one client.

    Silently does nothing when the client has no marker registered.
    """
    marker = self.start_direction_markers.get(client_id)
    if marker is not None:
        marker.set_visible(visible)
|
| 456 |
+
def client_active(self, client_id: int) -> bool:
    """Return whether the given client currently has an active session."""
    active_sessions = self.client_sessions
    return client_id in active_sessions
| 459 |
+
def add_character_motion(
    self,
    client: viser.ClientHandle,
    skeleton: SkeletonBase,
    joints_pos: Optional[torch.Tensor] = None,
    joints_rot: Optional[torch.Tensor] = None,
    foot_contacts: Optional[torch.Tensor] = None,
) -> None:
    """Create a new character in the client's scene and attach a motion.

    When ``joints_pos``/``joints_rot`` are omitted, the character is held
    in its rest pose for every frame of the timeline.  The character is
    only made visible after a short delay so the initial ``set_frame``
    call has time to finish.

    Raises:
        ValueError: if the session's model name matches no known skeleton
            family (g1 / smplx / soma).
    """
    client_id = client.client_id
    if not self.client_active(client_id):
        return
    session = self.client_sessions[client_id]

    # Name characters sequentially: character0, character1, ...
    ci = len(session.motions)
    character_name = f"character{ci}"
    # build character skeleton and skinning mesh
    # Mesh mode is inferred from the model name; SOMA additionally honors
    # the "use soma layer" GUI checkbox.
    if "g1" in session.model_name:
        mesh_mode = "g1_stl"
    elif "smplx" in session.model_name:
        mesh_mode = "smplx_skin"
    elif "soma" in session.model_name:
        if session.gui_elements.gui_use_soma_layer_checkbox.value:
            mesh_mode = "soma_layer_skin"
        else:
            mesh_mode = "soma_skin"
    else:
        raise ValueError("The model name is not recognized for skinning.")

    new_character = Character(
        character_name,
        client,
        skeleton,
        create_skeleton_mesh=True,
        create_skinned_mesh=True,
        visible_skeleton=False,  # don't show immediately
        visible_skinned_mesh=False,  # don't show immediately
        skinned_mesh_opacity=session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value,
        show_foot_contacts=session.gui_elements.gui_viz_foot_contacts_checkbox.value,
        dark_mode=session.gui_elements.gui_dark_mode_checkbox.value,
        mesh_mode=mesh_mode,
        gui_use_soma_layer_checkbox=session.gui_elements.gui_use_soma_layer_checkbox,
    )

    # if no motion given, initialize to character default (rest) pose for one frame
    init_joints_pos, init_joints_rot = new_character.get_pose()
    if joints_pos is None:
        joints_pos = init_joints_pos[None].repeat(session.max_frame_idx + 1, 1, 1)
    if joints_rot is None:
        joints_rot = init_joints_rot[None].repeat(session.max_frame_idx + 1, 1, 1, 1)

    new_motion = CharacterMotion(new_character, joints_pos, joints_rot, foot_contacts)
    # save the motion in our dict
    session.motions[character_name] = new_motion

    # put the character at the right frame
    new_motion.set_frame(session.frame_idx)

    # put them visible with a small delay
    # so that the set_frame function has time to finish
    def _set_visibility():
        new_motion.character.set_skinned_mesh_visibility(session.gui_elements.gui_viz_skinned_mesh_checkbox.value)
        new_motion.character.set_skeleton_visibility(session.gui_elements.gui_viz_skeleton_checkbox.value)

    timer = threading.Timer(
        0.2,  # 0.2s delay
        _set_visibility,
    )
    timer.start()
|
| 528 |
+
def clear_motions(self, client_id: int) -> None:
    """Dispose of every character motion in the client's session.

    No-op when the client has no active session.
    """
    if not self.client_active(client_id):
        return
    session = self.client_sessions[client_id]
    # Snapshot the values first: motion.clear() may touch the scene while
    # we are iterating.
    motions = list(session.motions.values())
    for motion in motions:
        motion.clear()
    session.motions.clear()
|
| 536 |
+
def compute_model_constraints_lst(
    self,
    session: ClientSession,
    model_bundle: ModelBundle,
    num_frames: int,
):
    """Build the per-model constraint list for this session.

    Thin wrapper that delegates to the ``generation`` module, supplying
    this app's compute device.
    """
    return generation.compute_model_constraints_lst(session, model_bundle, num_frames, self.device)
|
| 544 |
+
def check_cuda_health(self) -> bool:
    """Check if CUDA is still functional.

    Trigger auto-restart if corrupted.

    Returns True when the device works (or is CPU); False once the CUDA
    context is known to be corrupted. Unrelated RuntimeErrors propagate.
    """
    if self.device == "cpu":
        return True
    try:
        # Tiny on-device op: it fails if the CUDA context is corrupted.
        torch.tensor([1.0], device=self.device) + torch.tensor([1.0], device=self.device)
        return True
    except RuntimeError as e:
        if "device-side assert" in str(e) or "CUDA error" in str(e):
            # Only report and restart once per process.
            if self._cuda_healthy:
                self._cuda_healthy = False
                print("FATAL: CUDA context is corrupted (device-side assert). " "The process must be restarted.")
                self._trigger_restart()
            return False
        # Any other RuntimeError is not a context-health failure.
        raise
|
| 563 |
+
def _trigger_restart(self) -> None:
    """Exit the process so the HF Space (or systemd/Docker) can restart it."""
    import sys

    print("Initiating automatic restart due to unrecoverable CUDA error...")
    sys.stdout.flush()
    sys.stderr.flush()
    # os._exit skips atexit handlers and finalizers: once the CUDA context
    # is corrupted, nothing can be trusted to shut down cleanly.
    os._exit(1)
|
| 572 |
+
def generate(
    self,
    client: viser.ClientHandle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
) -> None:
    """Run motion generation for a client, serializing GPU access.

    Only one generation runs at a time; other clients block on the lock
    and see a "Waiting for GPU..." notification until it is their turn.

    Raises:
        RuntimeError: if the CUDA context was previously found corrupted.
    """
    if not self._cuda_healthy:
        raise RuntimeError("CUDA is in a corrupted state. The space is restarting...")

    # Non-blocking attempt first so the waiting notification is only shown
    # when another generation is actually in progress.
    locked = self._generation_lock.acquire(blocking=False)
    if not locked:
        waiting_notif = client.add_notification(
            title="Waiting for GPU...",
            body="Another generation is in progress. Yours will start automatically.",
            loading=True,
            with_close_button=False,
        )
        # Blocking acquire: returns once the running generation finishes.
        self._generation_lock.acquire()
        waiting_notif.remove()

    try:
        session = self.client_sessions[client.client_id]
        model_bundle = self.load_model(session.model_name)
        generation.generate(
            client=client,
            session=session,
            model_bundle=model_bundle,
            prompts=prompts,
            num_frames=num_frames,
            num_samples=num_samples,
            seed=seed,
            diffusion_steps=diffusion_steps,
            cfg_weight=cfg_weight,
            cfg_type=cfg_type,
            postprocess_parameters=postprocess_parameters,
            transitions_parameters=transitions_parameters,
            real_robot_rotations=real_robot_rotations,
            device=self.device,
            clear_motions=self.clear_motions,
            add_character_motion=self.add_character_motion,
        )
    finally:
        # Always release, even if generation raised.
        self._generation_lock.release()
|
| 624 |
+
def set_frame(self, client_id: int, frame_idx: int, update_timeline: bool = True):
    """Move a client's playback to ``frame_idx`` and update its characters.

    ``update_timeline=False`` skips echoing the change to the timeline
    widget (presumably used when the timeline itself initiated the change
    — confirm against callers).
    """
    if not self.client_active(client_id):
        return

    session = self.client_sessions[client_id]

    session.frame_idx = frame_idx
    if update_timeline:
        session.client.timeline.set_current_frame(frame_idx)
    # Snapshot values: motions may be mutated concurrently (e.g. on clear).
    for motion in list(session.motions.values()):
        motion.set_frame(frame_idx)
    self._apply_constraint_overlay_visibility(session)
|
| 637 |
+
def run(self) -> None:
    """Main playback loop: advance animation frames for every session.

    Runs forever. Each iteration advances playing sessions by elapsed
    wall-clock time, probes CUDA health every 5 seconds, and sleeps the
    remainder of the frame budget.
    """
    last_loop_time = time.perf_counter()
    last_cuda_check_time = 0.0
    while True:
        loop_start_time = time.perf_counter()
        delta_time = loop_start_time - last_loop_time
        last_loop_time = loop_start_time

        if self.models:
            # the max playback speed is 2x the model fps (from gui_playback_speed_buttons)
            playback_fps = max(bundle.model_fps for bundle in self.models.values()) * 2.0
        else:
            # No models loaded yet: idle at a fixed loop rate.
            playback_fps = 60.0

        # update each client session independently
        # copy to a list first to avoid changing size if client disconnects
        for client_id, session in list(self.client_sessions.items()):
            if not session.playing:
                continue
            if session.model_fps <= 0:
                continue

            # Time-based stepping keeps playback smooth even if loop cadence jitters.
            session.playback_time_accumulator += max(0.0, delta_time) * max(0.0, session.playback_speed)
            frame_period = 1.0 / session.model_fps
            if session.playback_time_accumulator < frame_period:
                continue

            # Consume whole frame periods; keep the remainder accumulated.
            frames_to_advance = int(session.playback_time_accumulator / frame_period)
            session.playback_time_accumulator -= frames_to_advance * frame_period
            frame_count = max(1, session.max_frame_idx + 1)
            # Wrap around so playback loops.
            new_frame_idx = (session.frame_idx + frames_to_advance) % frame_count

            # make sure the client is still active before updating the frame
            if self.client_active(client_id):
                self.set_frame(client_id, new_frame_idx)

        # Periodic CUDA health probe; triggers a restart when corrupted.
        if loop_start_time - last_cuda_check_time >= 5.0:
            self.check_cuda_health()
            last_cuda_check_time = loop_start_time

        # Sleep the rest of the frame budget to cap the loop rate.
        time_remaining = max(0.0, 1.0 / playback_fps - (time.perf_counter() - loop_start_time))
        time.sleep(time_remaining)
|
| 681 |
+
def configure_theme(
    self,
    client: viser.ClientHandle,
    dark_mode: bool = False,
    titlebar_dark_mode_checkbox_uuid: str | None = None,
):
    """Apply the viser UI theme (titlebar, panel, colors) for one client.

    Also recolors the client's ground grid to match light/dark mode.
    """
    # Sync grid color with theme (light vs dark)
    theme = DARK_THEME if dark_mode else LIGHT_THEME
    grid_handle = self.grid_handles.get(client.client_id)
    if grid_handle is not None:
        grid_handle.section_color = theme["grid"]

    #
    # setup theme
    #
    buttons = (
        TitlebarButton(
            text="Documentation",
            icon="Description",
            href="https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html",
        ),
        TitlebarButton(
            text="Project Page",
            icon=None,
            href="https://research.nvidia.com/labs/sil/projects/kimodo/",
        ),
        TitlebarButton(
            text="Github",
            icon="GitHub",
            href="https://github.com/nv-tlabs/kimodo",
        ),
    )
    # Embed the NVIDIA logo as a base64 data URL, when the asset exists.
    assets_dir = DEMO_ASSETS_ROOT
    logo_light_path = assets_dir / "nvidia_logo.png"
    logo_dark_path = assets_dir / "nvidia_logo_dark.png"
    if logo_light_path.exists():
        light_b64 = base64.standard_b64encode(logo_light_path.read_bytes()).decode("ascii")
        dark_b64 = (
            base64.standard_b64encode(logo_dark_path.read_bytes()).decode("ascii")
            if logo_dark_path.exists()
            else None
        )
        image = TitlebarImage(
            image_url_light=f"data:image/png;base64,{light_b64}",
            image_url_dark=(f"data:image/png;base64,{dark_b64}" if dark_b64 else None),
            image_alt="NVIDIA",
            href="https://www.nvidia.com/",
        )
    else:
        image = None
    titlebar_theme = TitlebarConfig(buttons=buttons, image=image, title_text="Movimento")
    client.gui.set_panel_label("Movimento")
    client.gui.configure_theme(
        titlebar_content=titlebar_theme,
        control_layout="floating",  # "floating", # ['floating', 'collapsible', 'fixed']
        control_width="large",  # ['small', 'medium', 'large']
        dark_mode=dark_mode,
        show_logo=False,  # hide viser logo on bottom left corner
        show_share_button=False,
        titlebar_dark_mode_checkbox_uuid=titlebar_dark_mode_checkbox_uuid,
        brand_color=(152, 189, 255),  # (60, 131, 0), # (R, G, B) tuple
    )
config.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from kimodo.assets import DEMO_EXAMPLES_ROOT
|
| 7 |
+
from kimodo.model.registry import (
|
| 8 |
+
AVAILABLE_MODELS,
|
| 9 |
+
DEFAULT_MODEL,
|
| 10 |
+
FRIENDLY_NAMES,
|
| 11 |
+
get_datasets,
|
| 12 |
+
get_model_info,
|
| 13 |
+
get_models_for_dataset_skeleton,
|
| 14 |
+
get_short_key_from_display_name,
|
| 15 |
+
get_skeleton_display_name,
|
| 16 |
+
get_skeleton_display_names_for_dataset,
|
| 17 |
+
get_skeleton_key_from_display_name,
|
| 18 |
+
get_skeletons_for_dataset,
|
| 19 |
+
get_versions_for_dataset_skeleton,
|
| 20 |
+
resolve_to_short_key,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# Bind address and port for the demo server.  PORT is the variable that
# Hugging Face Spaces sets; SERVER_PORT takes precedence when present.
SERVER_NAME = os.environ.get("SERVER_NAME", "0.0.0.0")
SERVER_PORT = int(os.environ.get("SERVER_PORT") or os.environ.get("PORT", "7860"))
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _env_bool(name: str, default: bool = False) -> bool:
|
| 28 |
+
raw = os.environ.get(name)
|
| 29 |
+
if raw is None:
|
| 30 |
+
return default
|
| 31 |
+
return str(raw).strip().lower() in ("1", "true", "yes", "on")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# When set, run with Hugging Face Space behavior (user queue, time limits).
HF_MODE = _env_bool("HF_MODE", False)

# HF mode: user queue and session limit (override via env in Spaces)
MAX_ACTIVE_USERS = int(os.environ.get("MAX_ACTIVE_USERS", "5"))
MAX_SESSION_MINUTES = float(os.environ.get("MAX_SESSION_MINUTES", "5.0"))

DEFAULT_PLAYBACK_SPEED = 1.0
# default start duration is 6.0 sec, but model can handle up to 10 sec
DEFAULT_CUR_DURATION = 6.0
DEFAULT_PROMPT = "A person walks forward."
# Per-prompt duration bounds, in seconds.
MIN_DURATION = 2.0
MAX_DURATION = 10.0

SHOW_TRANSITION_PARAMS = False
INIT_POSTPROCESSING = True
NB_TRANSITION_FRAMES = 5

# Theme palettes: RGB tuples used for the floor and ground grid.
LIGHT_THEME = dict(
    floor=(220, 220, 220),
    grid=(180, 180, 180),
)

# Dark theme: slightly lighter grid and floor for better visibility and less flat black
DARK_THEME = dict(
    floor=(48, 48, 52),
    grid=(105, 105, 110),
)

EXAMPLES_ROOT_DIR = str(DEMO_EXAMPLES_ROOT)

# Model list and paths from kimodo registry (all models: Kimodo + TMR)
MODEL_NAMES = tuple(AVAILABLE_MODELS)
MODEL_EXAMPLES_DIRS = {name: os.path.join(EXAMPLES_ROOT_DIR, name) for name in MODEL_NAMES}
# Display labels for backward compatibility (short_key -> display name)
MODEL_LABELS = {name: FRIENDLY_NAMES.get(name, f"Model ({name})") for name in MODEL_NAMES}
MODEL_LABEL_TO_NAME = {label: name for name, label in MODEL_LABELS.items()}
|
| 71 |
+
# -----------------------------------------------------------------------------
|
| 72 |
+
# Demo UI copy
|
| 73 |
+
# -----------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
DEMO_UI_QUICK_START_CORE_MD = """
|
| 76 |
+
### Camera
|
| 77 |
+
- **Left-drag**: rotate
|
| 78 |
+
- **Right-drag**: pan
|
| 79 |
+
- **Scroll**: zoom
|
| 80 |
+
|
| 81 |
+
### Playback
|
| 82 |
+
- **Space** to play/pause
|
| 83 |
+
- **←/→** to step frames, or click the frame number.
|
| 84 |
+
- **Scroll up/down** in the timeline: move left/right
|
| 85 |
+
- **Shift + scroll** in the timeline: zoom in/out
|
| 86 |
+
|
| 87 |
+
### Prompts
|
| 88 |
+
- **Double-click** a text prompt to edit it.
|
| 89 |
+
- **Click and drag** the right edge of a prompt box to extend/shorten it.
|
| 90 |
+
- **Click empty space** to add a prompt.
|
| 91 |
+
- **Right-click** a prompt to delete it.
|
| 92 |
+
|
| 93 |
+
### Generate
|
| 94 |
+
- Go to the **Generate** tab to modify options
|
| 95 |
+
- It is also possible to **load** examples
|
| 96 |
+
- Click **Generate** to generate a motion
|
| 97 |
+
|
| 98 |
+
### Constraints
|
| 99 |
+
- This is **optional**: should be use after a first generation
|
| 100 |
+
- **Click** in the timeline tracks (Full-Body / 2D root etc) to add a constraint.
|
| 101 |
+
- **Right-click** on a constraint to delete it.
|
| 102 |
+
- To **edit** a constraint:
|
| 103 |
+
- Move playback to the target frame
|
| 104 |
+
- Click **Enter Editing Mode** in the Constraints tab.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
DEMO_UI_QUICK_START_MODAL_MD = (
|
| 108 |
+
DEMO_UI_QUICK_START_CORE_MD
|
| 109 |
+
+ """
|
| 110 |
+
|
| 111 |
+
See the **Instructions** tab for the full user manual.
|
| 112 |
+
"""
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
DEMO_UI_INSTRUCTIONS_TAB_MD = (
|
| 116 |
+
"""
|
| 117 |
+
## How to Use This Demo
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
+ DEMO_UI_QUICK_START_CORE_MD
|
| 121 |
+
+ """
|
| 122 |
+
|
| 123 |
+
---
|
| 124 |
+
|
| 125 |
+
### Generating Motion (step-by-step)
|
| 126 |
+
|
| 127 |
+
1. **Edit the text prompts** in the timeline (e.g., "A person walks forward.")
|
| 128 |
+
2. **Modify the duration** by moving the right edge of each prompts (2–10 seconds)
|
| 129 |
+
3. **Add constraints** (optional) to control the motion:
|
| 130 |
+
- Click **Enter Editing Mode** to adjust the character pose
|
| 131 |
+
- Use the timeline to place keyframes or intervals in constraint tracks (see below)
|
| 132 |
+
4. **Click Generate** to create the motion
|
| 133 |
+
5. If generating multiple samples, **click on a mesh** to select which one to keep
|
| 134 |
+
|
| 135 |
+
### Timeline Editing
|
| 136 |
+
|
| 137 |
+
**Adding Constraints:**
|
| 138 |
+
1. Click anywhere on the timeline to add a keyframe at that frame. The keyframe is created based on the current character motion.
|
| 139 |
+
2. Ctrl/Cmd+click+drag to add an interval constraint, or expand a keyframe into an interval
|
| 140 |
+
3. Enter editing mode with the **Enter Editing Mode** button to adjust character pose before/after adding constraints.
|
| 141 |
+
|
| 142 |
+
**Constraint Types:**
|
| 143 |
+
- **Full-Body**: constrains the entire character pose
|
| 144 |
+
- **2D Root**: constrains the character's path on the ground plane
|
| 145 |
+
- Enable **Densify** to create a continuous path
|
| 146 |
+
- **End-Effectors**: constrains hands and feet positions
|
| 147 |
+
- Use separate tracks for Left/Right Hand/Foot
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
**Moving & Deleting:**
|
| 151 |
+
- **Drag keyframes/intervals** to move them to different frames
|
| 152 |
+
- **Right-click** a keyframe or interval to delete it
|
| 153 |
+
- Use **Clear All Constraints** to remove everything
|
| 154 |
+
|
| 155 |
+
**Tips:**
|
| 156 |
+
- The posing skeleton becomes visible in editing mode for precise positioning
|
| 157 |
+
- Use **Snap to constraint** to align the current frame to a constraint
|
| 158 |
+
|
| 159 |
+
### Saving & Loading
|
| 160 |
+
|
| 161 |
+
You can save the current constraints or current motion to load in later from the Load/Save menu.
|
| 162 |
+
Saving an **Example** will save the full constraints, motion, and generation metadata.
|
| 163 |
+
|
| 164 |
+
### Visualization Options
|
| 165 |
+
|
| 166 |
+
Switch to the **Visualize** tab to:
|
| 167 |
+
- Toggle mesh and skeleton visibility
|
| 168 |
+
- Adjust mesh opacity
|
| 169 |
+
- Show/hide foot contact indicators
|
| 170 |
+
- Switch between light and dark modes
|
| 171 |
+
"""
|
| 172 |
+
)
|
embedding_cache.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import contextvars
|
| 6 |
+
import hashlib
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import threading
|
| 10 |
+
import time
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Iterable, Optional
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from kimodo.sanitize import sanitize_texts
|
| 19 |
+
|
| 20 |
+
# Context variable holding the session whose last-prompt embeddings may be
# reused across calls; set via a context manager on the encoder wrapper.
_ACTIVE_SESSION = contextvars.ContextVar("kimodo_demo_active_session", default=None)
| 22 |
+
|
| 23 |
+
@dataclass
class CacheStats:
    """Counters for embedding-cache lookups."""

    hits: int = 0  # served from the in-memory LRU
    misses: int = 0  # required a fresh encoder call
    disk_hits: int = 0  # loaded from the on-disk .npy store
+
|
| 29 |
+
|
| 30 |
+
class EmbeddingCache:
    """Disk-backed text embedding cache with a small in-memory LRU.

    Embeddings are stored per-model as ``.npy`` files keyed by the SHA-256
    of ``(model_name, encoder_id, text)``, with a JSON index file alongside.
    A bounded in-memory ``OrderedDict`` (LRU) fronts the disk store.  All
    cache state is guarded by a single lock, so lookups are thread-safe.
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_id: str,
        base_dir: Optional[str] = None,
        max_mem_entries: int = 128,
    ) -> None:
        """Create a cache rooted at ``base_dir``.

        Args:
            model_name: Namespace for entries; one subdirectory per model.
            encoder_id: Identifier of the encoder, mixed into every key so
                different encoder types never share embeddings.
            base_dir: Cache root; defaults to ``$kimodo_EMBED_CACHE_DIR`` or
                ``~/.cache/kimodo_demo/embeddings``.
            max_mem_entries: Capacity of the in-memory LRU.
        """
        cache_root = base_dir or os.environ.get(
            "kimodo_EMBED_CACHE_DIR",
            os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
        )
        self.base_dir = os.path.expanduser(cache_root)
        self.model_name = model_name
        self.encoder_id = encoder_id
        self.max_mem_entries = max_mem_entries
        self.stats = CacheStats()

        self._lock = threading.Lock()
        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._index: dict = {}
        self._index_loaded = False

    def _model_dir(self) -> str:
        """Directory holding this model's cache entries."""
        return os.path.join(self.base_dir, self.model_name)

    def _index_path(self) -> str:
        """Path of the JSON index describing on-disk entries."""
        return os.path.join(self._model_dir(), "index.json")

    def _prewarm_marker_path(self, key: str) -> str:
        """Path of the marker file recording a completed prewarm run."""
        return os.path.join(self._model_dir(), f"prewarm_{key}.json")

    def has_prewarm_marker(self, key: str) -> bool:
        """Return True if a prewarm marker for ``key`` exists on disk."""
        return os.path.exists(self._prewarm_marker_path(key))

    def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
        """Atomically write a prewarm marker (write tmp file, then rename)."""
        os.makedirs(self._model_dir(), exist_ok=True)
        payload = {"prompt_count": prompt_count, "updated_at": time.time()}
        tmp_path = f"{self._prewarm_marker_path(key)}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        os.replace(tmp_path, self._prewarm_marker_path(key))

    def _load_index(self) -> None:
        """Lazily read the on-disk index once; a corrupt file resets it."""
        if self._index_loaded:
            return
        index_path = self._index_path()
        if os.path.exists(index_path):
            try:
                with open(index_path, "r", encoding="utf-8") as f:
                    self._index = json.load(f)
            except json.JSONDecodeError:
                self._index = {}
        self._index_loaded = True

    def _save_index(self) -> None:
        """Atomically persist the index (write tmp file, then rename)."""
        os.makedirs(self._model_dir(), exist_ok=True)
        tmp_path = f"{self._index_path()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f)
        os.replace(tmp_path, self._index_path())

    def _make_key(self, text: str) -> str:
        """SHA-256 hex digest over (model, encoder, text) — the cache key."""
        key_src = f"{self.model_name}|{self.encoder_id}|{text}"
        return hashlib.sha256(key_src.encode("utf-8")).hexdigest()

    def _entry_path(self, key: str) -> str:
        """On-disk ``.npy`` path for a cache key."""
        return os.path.join(self._model_dir(), f"{key}.npy")

    def _mem_get(self, key: str) -> Optional[np.ndarray]:
        """LRU lookup: a hit is moved to the most-recently-used position."""
        if key in self._mem_cache:
            self._mem_cache.move_to_end(key)
            return self._mem_cache[key]
        return None

    def _mem_put(self, key: str, value: np.ndarray) -> None:
        """Insert into the LRU, evicting least-recent entries over capacity."""
        self._mem_cache[key] = value
        self._mem_cache.move_to_end(key)
        while len(self._mem_cache) > self.max_mem_entries:
            self._mem_cache.popitem(last=False)

    def _disk_load(self, key: str) -> Optional[np.ndarray]:
        """Load an entry from disk; any failure is treated as a miss."""
        path = self._entry_path(key)
        if not os.path.exists(path):
            return None
        try:
            return np.load(path)
        except Exception:
            # A truncated/corrupt file just falls through to re-encoding.
            return None

    def _disk_save(self, key: str, value: np.ndarray) -> None:
        """Write an entry to disk and record it in the in-memory index."""
        os.makedirs(self._model_dir(), exist_ok=True)
        np.save(self._entry_path(key), value)
        self._index[key] = {
            "length": int(value.shape[0]),
            "dtype": str(value.dtype),
            "updated_at": time.time(),
        }

    def _maybe_use_session_cache(self, texts: list[str]):
        """Return the active session's cached result for exactly these texts."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return None
        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
            return session.last_prompt_embeddings, session.last_prompt_lengths
        return None

    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
        """Remember the latest result on the active session (if any)."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return
        session.last_prompt_texts = texts
        session.last_prompt_embeddings = tensor
        session.last_prompt_lengths = lengths

    def get_or_encode(self, texts: Iterable[str], encoder):
        """Return padded embeddings and per-text lengths for ``texts``.

        Cached entries (memory first, then disk) are reused; only misses
        are sent to ``encoder``, which must return ``(tensor, lengths)``.

        Returns:
            tuple: a zero-padded ``(B, max_len, D)`` float tensor and a
            list of the unpadded length of each row.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        if len(texts) == 0:
            # BUGFIX: torch.empty() requires a size argument and raised
            # TypeError here; return an explicitly empty tensor instead.
            empty = torch.empty(0)
            return empty, []

        session_cache = self._maybe_use_session_cache(texts)
        if session_cache is not None:
            return session_cache

        arrays: list[Optional[np.ndarray]] = [None] * len(texts)
        lengths: list[int] = [0] * len(texts)
        misses: list[tuple[int, str, str]] = []

        with self._lock:
            self._load_index()
            for idx, text in enumerate(texts):
                key = self._make_key(text)
                cached = self._mem_get(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self.stats.hits += 1
                    continue

                cached = self._disk_load(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self._mem_put(key, cached)
                    self.stats.disk_hits += 1
                    continue

                misses.append((idx, text, key))
                self.stats.misses += 1

        if misses:
            # Encode all misses in one batch, then persist each row
            # (trimmed to its true length) to memory and disk.
            miss_texts = [text for _, text, _ in misses]
            miss_tensor, miss_lengths = encoder(miss_texts)
            miss_tensor = miss_tensor.detach().cpu()
            miss_tensor_np = miss_tensor.numpy()

            with self._lock:
                self._load_index()
                for miss_idx, length in enumerate(miss_lengths):
                    idx, _text, key = misses[miss_idx]
                    arr = miss_tensor_np[miss_idx, :length].copy()
                    arrays[idx] = arr
                    lengths[idx] = int(length)
                    self._mem_put(key, arr)
                    self._disk_save(key, arr)
                self._save_index()

        # Pad every row out to the longest one with zeros.
        max_len = max(lengths) if lengths else 0
        feat_dim = arrays[0].shape[-1] if arrays[0] is not None else 0
        dtype = arrays[0].dtype if arrays[0] is not None else np.float32
        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
        for idx, arr in enumerate(arrays):
            if arr is None:
                continue
            padded[idx, : arr.shape[0]] = arr

        result = torch.from_numpy(padded)
        self._update_session_cache(texts, result, lengths)
        return result, lengths
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class CachedTextEncoder:
    """Wrapper around a text encoder to add disk-backed caching."""

    def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
        # Cache entries are keyed on (model_name, encoder class name, text),
        # so distinct encoder types never share embeddings.
        self.encoder = encoder
        self.model_name = model_name
        encoder_id = f"{type(encoder).__name__}"
        self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)

    def __call__(self, texts):
        """Encode ``texts``, serving cached embeddings where available."""
        return self.cache.get_or_encode(texts, self.encoder)

    def prewarm(self, texts) -> None:
        """Encode and cache ``texts`` once; repeat calls are no-ops.

        A marker file keyed on the joined prompt list records that this
        exact set was already warmed.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        prewarm_key = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
        if self.cache.has_prewarm_marker(prewarm_key):
            return
        self.cache.get_or_encode(texts, self.encoder)
        self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))

    def to(self, device=None, dtype=None):
        """Forward device/dtype moves to the wrapped encoder, if supported."""
        if hasattr(self.encoder, "to"):
            self.encoder.to(device=device, dtype=dtype)
        return self

    @contextlib.contextmanager
    def session_context(self, session):
        """Make ``session`` the active session for per-session embedding reuse."""
        token = _ACTIVE_SESSION.set(session)
        try:
            yield
        finally:
            _ACTIVE_SESSION.reset(token)

    def __getattr__(self, name):
        # Delegate any other attribute access to the wrapped encoder.
        return getattr(self.encoder, name)
generation.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
import viser
|
| 11 |
+
from kimodo.constraints import (
|
| 12 |
+
TYPE_TO_CLASS,
|
| 13 |
+
FullBodyConstraintSet,
|
| 14 |
+
Root2DConstraintSet,
|
| 15 |
+
)
|
| 16 |
+
from kimodo.exports.mujoco import apply_g1_real_robot_projection
|
| 17 |
+
from kimodo.skeleton import G1Skeleton34, SOMASkeleton30
|
| 18 |
+
from kimodo.tools import seed_everything
|
| 19 |
+
|
| 20 |
+
from .embedding_cache import CachedTextEncoder
|
| 21 |
+
from .state import ClientSession, ModelBundle
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_model_constraints_lst(
    session: ClientSession,
    model_bundle: ModelBundle,
    num_frames: int,
    device: str,
):
    """Build the list of model constraint sets from the session's viser constraints.

    Walks ``session.constraints`` (keyed by track name: "2D Root", "Full-Body",
    "End-Effectors"), drops constraint frames outside ``[0, num_frames)``, and
    converts each track into the corresponding constraint-set object understood
    by the model.

    Args:
        session: Per-client session holding the viser constraint tracks.
        model_bundle: Loaded model bundle; ``model_bundle.model.skeleton`` is the
            skeleton the constraints are expressed against.
        num_frames: Total number of frames to be generated; constraints at or
            beyond this index are silently discarded.
        device: Torch device string the constraint tensors are moved to.

    Returns:
        A list of constraint-set objects (possibly empty).
    """
    assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
    if not session.constraints:
        return []

    model_skeleton = model_bundle.model.skeleton
    # For SOMA, UI uses somaskel77; extract 30-joint subset for the model
    use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
    skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None

    # Dense 2D root trajectory shared with the other constraint types below.
    # NOTE(review): assumes a "2D Root" entry always exists in session.constraints
    # (KeyError otherwise) — verify against the UI that creates the tracks.
    dense_smooth_root_pos_2d = None
    if session.constraints["2D Root"].dense_path:
        # get the full 2d root; columns [0, 2] select the horizontal (x, z) plane
        dense_smooth_root_pos_2d = session.constraints["2D Root"].get_constraint_info(device=device)["root_pos"][
            :, [0, 2]
        ]

    model_constraints = []
    for track_name, constraint in session.constraints.items():
        constraint_info = constraint.get_constraint_info(device=device)
        frame_idx = constraint_info["frame_idx"]
        # drop any constraints outside the generation range
        valid_info = [(i, fi) for i, fi in enumerate(frame_idx) if fi < num_frames]
        valid_idx = [i for i, _ in valid_info]
        valid_frame_idx = [fi for _, fi in valid_info]

        if len(valid_frame_idx) == 0:
            # Nothing from this track survives the range filter.
            continue

        frame_indices = torch.tensor(valid_frame_idx)
        if track_name == "2D Root":
            # Sparse 2D root positions at the surviving frames only.
            smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
            # same as "smooth_root_2d"
            model_constraints.append(
                Root2DConstraintSet(
                    model_skeleton,
                    frame_indices,
                    smooth_root_pos_2d,
                )
            )
        elif track_name == "Full-Body":
            constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
            constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
            if skel_slice is not None:
                # Map UI-skeleton joints down to the model's joint subset.
                constraint_joints_pos = constraint_joints_pos[:, skel_slice]
                constraint_joints_rot = constraint_joints_rot[:, skel_slice]

            smooth_root_pos_2d = None
            if dense_smooth_root_pos_2d is not None:
                smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]

            model_constraints.append(
                FullBodyConstraintSet(
                    model_skeleton,
                    frame_indices,
                    constraint_joints_pos,
                    constraint_joints_rot,
                    smooth_root_2d=smooth_root_pos_2d,
                )
            )
        elif track_name == "End-Effectors":
            constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
            constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
            if skel_slice is not None:
                constraint_joints_pos = constraint_joints_pos[:, skel_slice]
                constraint_joints_rot = constraint_joints_rot[:, skel_slice]

            # Keep only the end-effector type sets whose frame survived the filter.
            end_effector_type_set_lst = [
                end_effector_type_set
                for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
                if i in valid_idx
            ]

            # regroup the end effector data by type: one constraint-set class may
            # cover several end-effector types (TYPE_TO_CLASS mapping).
            cls_idx = defaultdict(list)
            for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
                for end_effector_type in end_effector_type_set:
                    cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)

            for cls, lst_idx in cls_idx.items():
                frame_indices_cls = frame_indices[lst_idx]
                smooth_root_pos_2d = None
                if dense_smooth_root_pos_2d is not None:
                    smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]

                constraint_joints_pos_el = constraint_joints_pos[lst_idx]
                constraint_joints_rot_el = constraint_joints_rot[lst_idx]

                model_constraints.append(
                    cls(
                        model_skeleton,
                        frame_indices_cls,
                        constraint_joints_pos_el,
                        constraint_joints_rot_el,
                        smooth_root_2d=smooth_root_pos_2d,
                    )
                )
        else:
            raise ValueError(f"Unsupported constraint type: {constraint.display_name}")
    return model_constraints
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def generate(
    *,
    client: viser.ClientHandle,
    session: ClientSession,
    model_bundle: ModelBundle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
    device: str,
    clear_motions,
    add_character_motion,
) -> None:
    """Run the model for the given prompts and display the samples on characters.

    Args:
        client: Viser client the result is rendered for.
        session: Per-client session (constraints, skeleton, caches).
        model_bundle: Loaded model bundle; ``model_bundle.model`` is called here.
        prompts: One text prompt per segment.
        num_frames: Frame count per segment (same length as ``prompts``).
        num_samples: Number of independent samples to generate and display.
        seed: RNG seed applied via ``seed_everything``.
        diffusion_steps: Number of diffusion denoising steps.
        cfg_weight: Classifier-free-guidance weights; defaults to ``[2.0, 2.0]``.
        cfg_type: Optional CFG variant name, forwarded to the model.
        postprocess_parameters: Extra kwargs merged into the model call.
        transitions_parameters: Extra kwargs merged into the model call.
        real_robot_rotations: If True and the session uses a G1 skeleton, project
            joints to real-robot 1-DoF rotations (clamped) for display.
        device: Torch device string for constraint tensors.
        clear_motions: Callback ``(client_id) -> None`` that clears old motions.
        add_character_motion: Callback that attaches one sample to a character
            (keeps this module UI-agnostic).
    """
    import contextlib  # local import: avoids touching the module-level import block

    client_id = client.client_id
    print(
        f"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})"
    )

    seed_everything(seed)

    model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)
    cfg_weight = cfg_weight or [2.0, 2.0]
    postprocess_parameters = postprocess_parameters or {}
    transitions_parameters = transitions_parameters or {}

    # If the text encoder is cache-aware, bind this session for the duration of the
    # model call so embedding lookups hit the per-session cache. Using nullcontext
    # otherwise lets us issue a single model call instead of duplicating it in two
    # branches (the original repeated the identical call in if/else).
    encoder = getattr(model_bundle.model, "text_encoder", None)
    encoder_ctx = (
        encoder.session_context(session) if isinstance(encoder, CachedTextEncoder) else contextlib.nullcontext()
    )
    with encoder_ctx:
        pred_joints_output = model_bundle.model(
            prompts,
            num_frames,
            diffusion_steps,
            multi_prompt=True,
            constraint_lst=model_constraints,
            cfg_weight=cfg_weight,
            num_samples=num_samples,
            cfg_type=cfg_type,
            **(postprocess_parameters | transitions_parameters),
        )  # [B, T, motion_rep_dim]

    joints_pos = pred_joints_output["posed_joints"]  # [B, T, J, 3]
    joints_rot = pred_joints_output["global_rot_mats"]
    foot_contacts = pred_joints_output.get("foot_contacts")  # may be absent -> None

    # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.
    if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
        joints_pos, joints_rot = apply_g1_real_robot_projection(
            session.skeleton,
            pred_joints_output["posed_joints"],
            pred_joints_output["global_rot_mats"],
            clamp_to_limits=True,
        )

    # Display on characters (callbacks keep this module UI-agnostic).
    clear_motions(client_id)
    # Keep one sample centered at the origin so constraints align; spread the rest
    # along x at 1 m intervals.
    spread_factor = 1.0  # meters
    center_idx = num_samples // 2
    x_trans = (np.arange(num_samples) - center_idx) * spread_factor
    for i in range(num_samples):
        cur_joints_pos = joints_pos[i]
        cur_joints_pos[..., 0] += x_trans[i]
        add_character_motion(
            client,
            session.skeleton,
            cur_joints_pos,
            joints_rot[i],
            # Bug fix: foot_contacts can be None (fetched via .get above); the
            # original indexed it unconditionally and would crash.
            foot_contacts[i] if foot_contacts is not None else None,
        )
|
queue_manager.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
"""HF mode user queue and session time limit."""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import viser
|
| 12 |
+
|
| 13 |
+
from .config import DEMO_UI_QUICK_START_MODAL_MD, MAX_SESSION_MINUTES
|
| 14 |
+
|
| 15 |
+
# Link for "Duplicate this Space" on Hugging Face (used in queue and expiry modals).
|
| 16 |
+
DUPLICATE_SPACE_URL = "https://huggingface.co/spaces/nvidia/Kimodo?duplicate=true"
|
| 17 |
+
GITHUB_REPO_URL = "https://github.com/nv-tlabs/kimodo"
|
| 18 |
+
|
| 19 |
+
# How often to refresh queue modal content (position, total, estimated wait).
|
| 20 |
+
QUEUE_MODAL_REFRESH_INTERVAL_SEC = 15
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class UserQueue:
|
| 24 |
+
"""Thread-safe queue: active users (with activation timestamp) and waiting queue."""
|
| 25 |
+
|
| 26 |
+
def __init__(self, max_active: int, max_minutes: float) -> None:
|
| 27 |
+
self._max_active = max_active
|
| 28 |
+
self._max_minutes = max_minutes
|
| 29 |
+
self._max_seconds = max_minutes * 60.0
|
| 30 |
+
self._active: dict[int, float] = {} # client_id -> activation timestamp
|
| 31 |
+
self._queued: list[int] = []
|
| 32 |
+
self._lock = threading.Lock()
|
| 33 |
+
|
| 34 |
+
def try_activate(self, client_id: int) -> bool:
|
| 35 |
+
"""If a slot is free, add client as active and return True.
|
| 36 |
+
|
| 37 |
+
Else return False.
|
| 38 |
+
"""
|
| 39 |
+
with self._lock:
|
| 40 |
+
if len(self._active) < self._max_active:
|
| 41 |
+
self._active[client_id] = time.time()
|
| 42 |
+
return True
|
| 43 |
+
return False
|
| 44 |
+
|
| 45 |
+
def enqueue(self, client_id: int) -> None:
|
| 46 |
+
with self._lock:
|
| 47 |
+
if client_id not in self._queued:
|
| 48 |
+
self._queued.append(client_id)
|
| 49 |
+
|
| 50 |
+
def remove(self, client_id: int) -> bool:
|
| 51 |
+
"""Remove from active or queue.
|
| 52 |
+
|
| 53 |
+
Returns True if was active.
|
| 54 |
+
"""
|
| 55 |
+
with self._lock:
|
| 56 |
+
was_active = client_id in self._active
|
| 57 |
+
self._active.pop(client_id, None)
|
| 58 |
+
if client_id in self._queued:
|
| 59 |
+
self._queued.remove(client_id)
|
| 60 |
+
return was_active
|
| 61 |
+
|
| 62 |
+
def promote_next(self) -> int | None:
|
| 63 |
+
"""If queue non-empty, pop first, activate them, return their client_id.
|
| 64 |
+
|
| 65 |
+
Else None.
|
| 66 |
+
"""
|
| 67 |
+
with self._lock:
|
| 68 |
+
if not self._queued:
|
| 69 |
+
return None
|
| 70 |
+
client_id = self._queued.pop(0)
|
| 71 |
+
self._active[client_id] = time.time()
|
| 72 |
+
return client_id
|
| 73 |
+
|
| 74 |
+
def get_queue_position(self, client_id: int) -> tuple[int, int] | None:
|
| 75 |
+
"""(1-based position, total_in_queue) or None if not queued."""
|
| 76 |
+
with self._lock:
|
| 77 |
+
if client_id not in self._queued:
|
| 78 |
+
return None
|
| 79 |
+
pos = self._queued.index(client_id)
|
| 80 |
+
return (pos + 1, len(self._queued))
|
| 81 |
+
|
| 82 |
+
def get_estimated_wait_seconds(self, client_id: int) -> float:
|
| 83 |
+
"""Estimated seconds until this queued client gets a slot."""
|
| 84 |
+
with self._lock:
|
| 85 |
+
if client_id not in self._queued:
|
| 86 |
+
return 0.0
|
| 87 |
+
pos = self._queued.index(client_id) + 1 # 1-based
|
| 88 |
+
# Expiry times of active users (when they free a slot)
|
| 89 |
+
now = time.time()
|
| 90 |
+
expiries = sorted(now + self._max_seconds - (now - t) for t in self._active.values())
|
| 91 |
+
if not expiries:
|
| 92 |
+
return 0.0
|
| 93 |
+
# Nth slot to free (1-indexed) wraps over expiries
|
| 94 |
+
idx = (pos - 1) % len(expiries)
|
| 95 |
+
cycles = (pos - 1) // len(expiries)
|
| 96 |
+
slot_free_time = expiries[idx] + cycles * self._max_seconds
|
| 97 |
+
return max(0.0, slot_free_time - now)
|
| 98 |
+
|
| 99 |
+
def is_active(self, client_id: int) -> bool:
|
| 100 |
+
with self._lock:
|
| 101 |
+
return client_id in self._active
|
| 102 |
+
|
| 103 |
+
def was_active(self, client_id: int) -> bool:
|
| 104 |
+
"""True if client is currently active (for use when already holding lock)."""
|
| 105 |
+
return client_id in self._active
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _format_wait(seconds: float) -> str:
|
| 109 |
+
if seconds < 60:
|
| 110 |
+
return "less than a minute"
|
| 111 |
+
mins = int(math.ceil(seconds / 60))
|
| 112 |
+
return f"~{mins} minute{'s' if mins != 1 else ''}"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _queue_modal_markdown(position: int, total: int, estimated_wait_sec: float) -> str:
    """Markdown body for the waiting-queue modal (position, total, wait estimate)."""
    wait_str = _format_wait(estimated_wait_sec)
    # Show the session limit as a plain int when it has no fractional part.
    whole_minutes = int(MAX_SESSION_MINUTES)
    mins = whole_minutes if whole_minutes == MAX_SESSION_MINUTES else MAX_SESSION_MINUTES
    return f"""## Kimodo Demo — Please Wait

This demo runs with limited capacity.
Each user gets **{mins} minute{"s" if mins != 1 else ""}** of interactive time.

**Your position in queue:** {position} / {total}

**Estimated wait:** {wait_str}

Please keep this tab open — the demo will start automatically when it's your turn.

---
*Want unlimited access? [Duplicate this Space]({DUPLICATE_SPACE_URL}) or clone the [GitHub repo]({GITHUB_REPO_URL}) to run locally!*
"""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _welcome_modal_markdown() -> str:
    """Markdown body for the welcome modal shown when a session starts."""
    # Show the session limit as a plain int when it has no fractional part.
    whole_minutes = int(MAX_SESSION_MINUTES)
    mins = whole_minutes if whole_minutes == MAX_SESSION_MINUTES else MAX_SESSION_MINUTES
    return f"""## Welcome to Kimodo Demo

You have been granted a **{mins}-minute** demo session.
Your session timer has started.

Click the button below to begin!
"""
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _expiry_modal_markdown() -> str:
    """Markdown body for the session-expired modal."""
    # Show the session limit as a plain int when it has no fractional part.
    whole_minutes = int(MAX_SESSION_MINUTES)
    mins = whole_minutes if whole_minutes == MAX_SESSION_MINUTES else MAX_SESSION_MINUTES
    return f"""## Session Expired

Your {mins}-minute demo session has ended.
Thank you for trying Kimodo!

Refresh this page to rejoin the queue, or [duplicate this Space]({DUPLICATE_SPACE_URL}) for unlimited access.
"""
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class QueueManager:
    """Orchestrates HF mode: queue modals, welcome modal, session timer, promotion."""

    def __init__(
        self,
        queue: UserQueue,
        server: viser.ViserServer,
        setup_demo_for_client: Callable[[viser.ClientHandle], None],
        cleanup_session: Callable[[int], None],
    ) -> None:
        self._queue = queue
        self._server = server
        self._setup_demo_for_client = setup_demo_for_client
        self._cleanup_session = cleanup_session
        # NOTE(review): reaches into UserQueue's private attribute; a public
        # accessor on UserQueue would be cleaner.
        self._max_seconds = queue._max_seconds

        # client_id -> (modal handle, markdown handle) for clients shown the queue modal.
        self._queue_modal_handles: dict[int, tuple[Any, Any]] = {}
        # client_id -> welcome modal handle for newly activated clients.
        self._welcome_modal_handles: dict[int, Any] = {}
        # client_id -> Timer that fires when the session time limit is hit.
        self._expiry_timers: dict[int, threading.Timer] = {}
        # Guards the three handle/timer dicts above (UserQueue has its own lock).
        self._lock = threading.Lock()
        self._refresh_stop = threading.Event()
        # Background daemon thread that keeps queue modals up to date.
        self._refresh_thread = threading.Thread(
            target=self._queue_modal_refresh_loop,
            name="queue-modal-refresh",
            daemon=True,
        )
        self._refresh_thread.start()

    def _queue_modal_refresh_loop(self) -> None:
        """Periodically refresh queue modals so position, total, and estimated wait stay current."""
        # Event.wait doubles as the sleep and the stop signal.
        while not self._refresh_stop.wait(timeout=QUEUE_MODAL_REFRESH_INTERVAL_SEC):
            self._update_all_queue_modals()

    def on_client_connect(self, client: viser.ClientHandle) -> None:
        """Handle new connection: activate if slot free, else enqueue and show queue modal."""
        client_id = client.client_id
        if self._queue.try_activate(client_id):
            try:
                self._setup_demo_for_client(client)
            except RuntimeError as e:
                # Best-effort: swallow CUDA failures for this client only.
                # NOTE(review): the client keeps its active slot but never gets a
                # session timer in this path — confirm this is intended.
                if "CUDA error" in str(e):
                    print(f"CUDA error while setting up client {client_id}: {e}")
                    return
                raise
            self._start_session_timer(client_id)
            self._show_welcome_modal(client)
        else:
            self._queue.enqueue(client_id)
            self._show_queue_modal(client)
            # Other queued clients' positions may have changed; refresh everyone.
            self._update_all_queue_modals()

    def on_client_disconnect(self, client_id: int) -> None:
        """Remove from queue/active, cancel timer, promote next if was active.

        Session/scene cleanup is done by the demo's on_client_disconnect.
        """
        with self._lock:
            self._expiry_timers.pop(client_id, None)
            self._queue_modal_handles.pop(client_id, None)
            self._welcome_modal_handles.pop(client_id, None)
        # Must run outside self._lock: the calls below re-acquire it.
        was_active = self._queue.remove(client_id)
        if was_active:
            self._promote_next_user()
        else:
            self._update_all_queue_modals()

    def _show_queue_modal(self, client: viser.ClientHandle) -> None:
        """Open the waiting modal for a queued client and remember its handles."""
        client_id = client.client_id
        # (0, 0) fallback covers the race where the client was dequeued meanwhile.
        pos, total = self._queue.get_queue_position(client_id) or (0, 0)
        wait_sec = self._queue.get_estimated_wait_seconds(client_id)
        md_content = _queue_modal_markdown(pos, total, wait_sec)

        modal = client.gui.add_modal(
            "Kimodo Demo — Please Wait",
            size="xl",
            show_close_button=False,
        )
        with modal:
            md_handle = client.gui.add_markdown(md_content)
        with self._lock:
            self._queue_modal_handles[client_id] = (modal, md_handle)

    def _show_quick_start_modal(self, client: viser.ClientHandle) -> None:
        """Show the quick start instructions modal (same as non-HF mode)."""
        with client.gui.add_modal(
            "Welcome — Quick Start",
            size="xl",
            show_close_button=True,
            # save_choice lets the browser remember the acknowledgement.
            save_choice="kimodo.demo.quick_start_ack",
        ) as quick_start_modal:
            client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)
            client.gui.add_button("Got it (don't remind me again)").on_click(lambda _: quick_start_modal.close())

    def _show_welcome_modal(self, client: viser.ClientHandle) -> None:
        """Show the timed-session welcome modal; its button opens the quick start."""
        client_id = client.client_id

        def _on_start_demo(_: Any) -> None:
            # Close the welcome modal, then chain into the quick-start modal.
            modal.close()
            self._show_quick_start_modal(client)

        modal = client.gui.add_modal(
            "Welcome to Kimodo Demo",
            size="xl",
            show_close_button=True,
        )
        with modal:
            client.gui.add_markdown(_welcome_modal_markdown())
            client.gui.add_button("Start Demo").on_click(_on_start_demo)
        with self._lock:
            self._welcome_modal_handles[client_id] = modal

    def _update_all_queue_modals(self) -> None:
        """Push fresh position/wait text into every open queue modal."""
        # Snapshot under the lock, then update outside it so slow GUI calls
        # don't hold up other lock users.
        with self._lock:
            handles = list(self._queue_modal_handles.items())
        for client_id, (modal, md_handle) in handles:
            pos_total = self._queue.get_queue_position(client_id)
            if pos_total is None:
                # No longer queued (promoted or gone); skip.
                continue
            pos, total = pos_total
            wait_sec = self._queue.get_estimated_wait_seconds(client_id)
            try:
                md_handle.content = _queue_modal_markdown(pos, total, wait_sec)
            except Exception:
                # Best-effort: the client may have disconnected mid-update.
                pass

    def _promote_next_user(self) -> None:
        """Activate the next queued client: close their queue modal, set up the demo."""
        promoted_id = self._queue.promote_next()
        if promoted_id is None:
            return
        clients = self._server.get_clients()
        client = clients.get(promoted_id)
        if client is None:
            # Client disconnected between promotion and lookup; their slot is
            # reclaimed when their disconnect handler runs.
            return
        with self._lock:
            old = self._queue_modal_handles.pop(promoted_id, None)
            if old is not None:
                try:
                    old[0].close()
                except Exception:
                    # Best-effort close; the modal may already be gone.
                    pass
        try:
            self._setup_demo_for_client(client)
        except RuntimeError as e:
            # Same best-effort CUDA handling as on_client_connect.
            if "CUDA error" in str(e):
                print(f"CUDA error while setting up client {promoted_id}: {e}")
                return
            raise
        self._start_session_timer(promoted_id)
        self._show_welcome_modal(client)
        self._update_all_queue_modals()

    def _start_session_timer(self, client_id: int) -> None:
        """Arm a one-shot daemon timer that expires this client's session."""

        def on_expiry() -> None:
            self._on_session_expired(client_id)

        t = threading.Timer(self._max_seconds, on_expiry)
        t.daemon = True
        with self._lock:
            self._expiry_timers[client_id] = t
        t.start()

    def _on_session_expired(self, client_id: int) -> None:
        """Timer callback: show the expiry modal, clean up, and promote the next user."""
        with self._lock:
            self._expiry_timers.pop(client_id, None)
        if not self._queue.is_active(client_id):
            # Already removed (e.g. disconnected before the timer fired).
            return
        self._queue.remove(client_id)
        clients = self._server.get_clients()
        client = clients.get(client_id)
        if client is not None:
            try:
                # modal_ctx is intentionally unused: the modal stays open for the
                # client; the with-block only scopes widget creation.
                with client.gui.add_modal(
                    "Session Expired",
                    size="lg",
                    show_close_button=False,
                ) as modal_ctx:
                    client.gui.add_markdown(_expiry_modal_markdown())
            except Exception:
                # Best-effort: client may have disconnected mid-expiry.
                pass
        self._cleanup_session(client_id)
        self._promote_next_user()
|
state.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import kimodo.viz.viser_utils as viser_utils
|
| 10 |
+
import viser
|
| 11 |
+
from kimodo.skeleton import SkeletonBase
|
| 12 |
+
from kimodo.viz.viser_utils import GuiElements
|
| 13 |
+
|
| 14 |
+
from .config import (
|
| 15 |
+
DEFAULT_CUR_DURATION,
|
| 16 |
+
DEFAULT_MODEL,
|
| 17 |
+
DEFAULT_PLAYBACK_SPEED,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(frozen=True)
class ModelBundle:
    """Immutable grouping of a loaded model and its associated assets."""

    # The generative motion model; called directly in generation.generate and
    # expected to expose .skeleton and (optionally) .text_encoder.
    model: object
    # Motion representation helper — presumably used to decode model outputs;
    # TODO(review): confirm the expected interface.
    motion_rep: object
    # Skeleton definition bundled with the model (NOTE: generation code reads
    # model.skeleton, not this field — verify which one is authoritative).
    skeleton: SkeletonBase
    # Frames-per-second associated with the model's output.
    model_fps: float
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ClientSession:
    """Per-client session data."""

    # Viser connection handle for this browser client.
    client: viser.ClientHandle
    # Handles to the GUI widgets created for this client.
    gui_elements: GuiElements
    motions: dict  # character_name -> CharacterMotion
    # Constraint tracks keyed by name (e.g. "2D Root", "Full-Body", "End-Effectors").
    constraints: dict[str, viser_utils.ConstraintSet] = field(default_factory=dict)
    timeline_data: object = None
    # Playback state: current frame, play/pause flag, speed multiplier, and the
    # fractional-frame accumulator used between render ticks.
    frame_idx: int = 0
    playing: bool = False
    playback_speed: float = DEFAULT_PLAYBACK_SPEED
    playback_time_accumulator: float = 0.0
    # Debounce timestamp for the spacebar play/pause toggle.
    last_space_toggle_time: float = 0.0
    cur_duration: float = DEFAULT_CUR_DURATION
    max_frame_idx: int = 100  # will be updated based on model_fps
    # True while motions are being rebuilt (guards concurrent UI updates).
    updating_motions: bool = False
    edit_mode: bool = False
    # Currently selected model and its properties (filled after model load).
    model_name: str = DEFAULT_MODEL
    model_fps: float = 0.0
    skeleton: SkeletonBase | None = None
    motion_rep: object | None = None
    # Example browsing/saving state.
    examples_base_dir: str = ""
    example_dict: dict[str, str] = field(default_factory=dict)
    gui_examples_dropdown: Optional[viser.GuiInputHandle] = None
    gui_save_example_path_text: Optional[viser.GuiInputHandle] = None
    gui_model_selector: Optional[viser.GuiInputHandle] = None
    # Per-session prompt-embedding cache (last encoded prompts and results).
    last_prompt_texts: Optional[list[str]] = None
    last_prompt_embeddings: Optional[torch.Tensor] = None
    last_prompt_lengths: Optional[list[int]] = None
    # Snapshots for undo/exit-edit-mode restore; exact schema defined by the UI code.
    edit_mode_snapshot: Optional[dict[int, dict[str, object]]] = None
    undo_drag_snapshot: Optional[dict[str, object]] = None
    show_only_current_constraint: bool = False  # False = Show All, True = Show only Current
|
ui.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|