File size: 17,398 Bytes
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Registry of model names and Hugging Face repo IDs for Kimodo and TMR.

The canonical source of truth is the list of repo IDs. Short keys (e.g. kimodo-soma-rp) and metadata
(family, skeleton, dataset, version, display name) are derived by parsing each repo ID.
"""

import re
from dataclasses import dataclass
from typing import Optional

# Canonical list: repo IDs in the same syntax as Hugging Face (org/Model-Name-v1).
# Parser expects: org/Family-SKELETON-DATASET-version (e.g. Kimodo-SOMA-RP-v1).
# Every entry must match _REPO_NAME_PATTERN below; _build_registry() raises
# ValueError at import time for any entry that does not.
KIMODO_REPO_IDS = [
    "nvidia/Kimodo-SOMA-RP-v1",
    "nvidia/Kimodo-SMPLX-RP-v1",
    "nvidia/Kimodo-G1-RP-v1",
    "nvidia/Kimodo-SOMA-SEED-v1",
    "nvidia/Kimodo-G1-SEED-v1",
]
TMR_REPO_IDS = [
    "nvidia/TMR-SOMA-RP-v1",
]

# Pattern for the repo name with the org prefix removed (e.g. Kimodo-SOMA-RP-v1).
# Groups: (family, skeleton, dataset, version digits). Matching is case-sensitive,
# and only the RP and SEED datasets are accepted.
_REPO_NAME_PATTERN = re.compile(r"^(Kimodo|TMR)-([A-Za-z0-9]+)-(RP|SEED)-v(\d+)$")


@dataclass
class ModelInfo:
    """Metadata describing one registered model, as parsed from its repo ID."""

    repo_id: str  # repo ID exactly as listed in the registry, e.g. "nvidia/Kimodo-SOMA-RP-v1"
    short_key: str  # lowercase "family-skeleton-dataset", e.g. "kimodo-soma-rp" (no version)
    family: str  # "Kimodo" or "TMR"
    skeleton: str  # skeleton token from the repo name, e.g. "SOMA"
    dataset: str  # "RP" or "SEED"
    version: str  # "v" followed by the version digits, e.g. "v1"
    display_name: str  # repo name without the org prefix, e.g. "Kimodo-SOMA-RP-v1"

    @property
    def dataset_ui_label(self) -> str:
        """Return the UI label for the dataset: "Rigplay" for RP, "SEED" otherwise."""
        if self.dataset == "RP":
            return "Rigplay"
        return "SEED"


def _parse_repo_id(repo_id: str) -> Optional[ModelInfo]:
    """Turn a repo ID ("org/Family-SKELETON-DATASET-vN") into a ModelInfo.

    The org prefix is optional. Returns None when the repo name does not match
    _REPO_NAME_PATTERN.
    """
    # Keep only what follows the first "/" (the whole string when there is no "/").
    repo_name = repo_id.split("/", 1)[-1]
    matched = _REPO_NAME_PATTERN.match(repo_name)
    if matched is None:
        return None

    family, skeleton, dataset, version_digits = matched.groups()
    # The family is part of the key so Kimodo-SOMA-RP and TMR-SOMA-RP stay distinct.
    key_parts = (family, skeleton, dataset)
    return ModelInfo(
        repo_id=repo_id,
        short_key="-".join(part.lower() for part in key_parts),
        family=family,
        skeleton=skeleton,  # stored as written in the repo name; no display normalization yet
        dataset=dataset,
        version="v" + version_digits,
        display_name=repo_name,
    )


def _build_registry() -> tuple[list[ModelInfo], dict[str, str], list[str]]:
    """Parse every registered repo ID and derive the short-key lookup tables.

    Returns a tuple of (all model infos in registry order, short_key -> repo_id
    map, list of short keys). When several versions share one (family,
    skeleton, dataset), the short key maps to the repo ID with the highest
    version number.

    Raises:
        ValueError: If a registered repo ID does not match the expected pattern.
    """
    parsed: list[ModelInfo] = []
    for repo_id in KIMODO_REPO_IDS + TMR_REPO_IDS:
        entry = _parse_repo_id(repo_id)
        if entry is None:
            raise ValueError(f"Registry repo ID does not match expected pattern: {repo_id}")
        parsed.append(entry)

    def _numeric_version(entry: ModelInfo) -> int:
        # "v12" -> 12; anything that is not "v<digits>" ranks lowest.
        label = entry.version
        return int(label[1:]) if label.startswith("v") and label[1:].isdigit() else 0

    key_to_repo: dict[str, str] = {}
    for entry in parsed:
        # The first occurrence of a short key decides its position in the mapping.
        if entry.short_key in key_to_repo:
            continue
        identity = (entry.family, entry.skeleton, entry.dataset)
        same_model = [other for other in parsed if (other.family, other.skeleton, other.dataset) == identity]
        key_to_repo[entry.short_key] = max(same_model, key=_numeric_version).repo_id

    return parsed, key_to_repo, list(key_to_repo)


# Build the registry once, at import time.
#   MODEL_INFOS: one ModelInfo per registered repo ID, in registry order.
#   MODEL_NAMES: short_key -> repo_id of the highest-version model for that key.
#   _SHORT_KEYS: the keys of MODEL_NAMES, in insertion order.
MODEL_INFOS, MODEL_NAMES, _SHORT_KEYS = _build_registry()
# Public name for the short-key list; this is the same list object as _SHORT_KEYS, not a copy.
AVAILABLE_MODELS = _SHORT_KEYS

# Short-key lists for Kimodo vs TMR (load_model uses TMR_MODELS to branch).
# NOTE: built from MODEL_INFOS (one entry per repo ID), so a short key appears once per
# registered version of the same model.
KIMODO_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "Kimodo"]
TMR_MODELS = [info.short_key for info in MODEL_INFOS if info.family == "TMR"]

# Backward compatibility: FRIENDLY_NAMES for any code that still expects it.
# Maps short_key -> display name; when several versions share a short key, the entry
# that comes last in MODEL_INFOS wins.
FRIENDLY_NAMES = {info.short_key: info.display_name for info in MODEL_INFOS}

# Short key used when no model is specified.
DEFAULT_MODEL = "kimodo-soma-rp"
# Default text-encoder endpoint (a local address).
DEFAULT_TEXT_ENCODER_URL = "http://127.0.0.1:9550/"

# Friendly names for skeleton dropdown (key -> label).
SKELETON_DISPLAY_NAMES = {
    "SOMA": "SOMA Human Body",
    "SMPLX": "SMPLX Human Body",
    "G1": "Unitree G1 Humanoid Robot",
}

# Order for skeleton dropdown: SOMA, SMPLX, G1.
# Skeletons missing from this tuple are dropped by get_skeletons_for_dataset().
SKELETON_ORDER = ("SOMA", "SMPLX", "G1")


def get_skeleton_display_name(skeleton_key: str) -> str:
    """Look up the UI label for a skeleton key, falling back to the key itself.

    Example: "SOMA" -> "SOMA Human Body"; an unknown key is returned unchanged.
    """
    try:
        return SKELETON_DISPLAY_NAMES[skeleton_key]
    except KeyError:
        return skeleton_key


def get_skeleton_key_from_display_name(display_name: str) -> Optional[str]:
    """Reverse lookup: map a skeleton UI label back to its key (None if unknown)."""
    matching_keys = (key for key, label in SKELETON_DISPLAY_NAMES.items() if label == display_name)
    return next(matching_keys, None)


def get_skeleton_display_names_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
    """List the skeleton UI labels available for a dataset.

    The order follows get_skeletons_for_dataset. When family is given (e.g.
    "Kimodo"), only skeletons that have a model of that family are listed.
    """
    skeleton_keys = get_skeletons_for_dataset(dataset_ui_label, family=family)
    return list(map(get_skeleton_display_name, skeleton_keys))


def get_short_key(repo_id: str) -> Optional[str]:
    """Look up the short key of a registered repo ID (None when it is not registered)."""
    return next((entry.short_key for entry in MODEL_INFOS if entry.repo_id == repo_id), None)


def get_model_info(short_key: str) -> Optional[ModelInfo]:
    """Look up the ModelInfo that a short key loads.

    The short key is first mapped through MODEL_NAMES, so when several versions
    share a key the info for the latest version is returned. Returns None for
    an unknown short key.
    """
    target_repo = MODEL_NAMES.get(short_key)
    if target_repo is None:
        return None
    return next((entry for entry in MODEL_INFOS if entry.repo_id == target_repo), None)


def get_short_key_from_display_name(display_name: str) -> Optional[str]:
    """Map a display name (e.g. Kimodo-SOMA-RP-v1) to its short key; None if unknown."""
    return next(
        (entry.short_key for entry in MODEL_INFOS if entry.display_name == display_name),
        None,
    )


def get_models_for_demo() -> list[ModelInfo]:
    """Return a new list of every registered model info, in registry order."""
    return [*MODEL_INFOS]


def get_datasets(family: Optional[str] = None) -> list[str]:
    """List the distinct dataset UI labels (Rigplay, SEED) in the registry, sorted.

    When family is given (e.g. "Kimodo"), only models of that family are
    considered.
    """
    labels = {entry.dataset_ui_label for entry in MODEL_INFOS if family is None or entry.family == family}
    return sorted(labels)


def get_skeletons_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:
    """List the skeleton keys that have at least one model for a dataset.

    The result follows SKELETON_ORDER (SOMA, SMPLX, G1), keeping only the
    skeletons present for the dataset. When family is given (e.g. "Kimodo"),
    only models of that family are considered.
    """
    # Any label other than "Rigplay" is treated as SEED.
    if dataset_ui_label == "Rigplay":
        wanted_dataset = "RP"
    else:
        wanted_dataset = "SEED"

    present = {
        entry.skeleton
        for entry in MODEL_INFOS
        if entry.dataset == wanted_dataset and (family is None or entry.family == family)
    }
    return [skeleton for skeleton in SKELETON_ORDER if skeleton in present]


def get_versions_for_dataset_skeleton(
    dataset_ui_label: str, skeleton: str, family: Optional[str] = None
) -> list[str]:
    """Return the distinct version strings (e.g. v1) for a dataset/skeleton pair.

    Args:
        dataset_ui_label: "Rigplay" selects the RP dataset; any other label selects SEED.
        skeleton: Skeleton key as stored in the registry (e.g. "SOMA").
        family: Optional family filter (e.g. "Kimodo"). Several families can register
            the same dataset/skeleton pair, so without this filter the versions of all
            families are merged, which is the original behaviour.

    Returns:
        Versions sorted by numeric part, so the last element is the highest (v1, v2, ...).
    """
    dataset = "RP" if dataset_ui_label == "Rigplay" else "SEED"
    versions = {
        info.version
        for info in MODEL_INFOS
        if info.dataset == dataset and info.skeleton == skeleton and (family is None or info.family == family)
    }

    def version_key(v: str) -> int:
        # Sort by the numeric part so v10 comes after v2; malformed versions sort first.
        if v.startswith("v") and v[1:].isdigit():
            return int(v[1:])
        return 0

    return sorted(versions, key=version_key)


def get_models_for_dataset_skeleton(
    dataset_ui_label: str, skeleton: str, family: Optional[str] = None
) -> list[ModelInfo]:
    """Collect the model infos for a dataset/skeleton pair, highest version first.

    This feeds the Version dropdown (one option per model, labelled with the
    full display name). When family is given (e.g. "Kimodo"), only models of
    that family are returned.
    """
    # Any label other than "Rigplay" is treated as SEED.
    if dataset_ui_label == "Rigplay":
        wanted_dataset = "RP"
    else:
        wanted_dataset = "SEED"

    selected = [
        entry
        for entry in MODEL_INFOS
        if entry.dataset == wanted_dataset
        and entry.skeleton == skeleton
        and (family is None or entry.family == family)
    ]

    def _numeric_version(entry: ModelInfo) -> int:
        label = entry.version
        return int(label[1:]) if label.startswith("v") and label[1:].isdigit() else 0

    return sorted(selected, key=_numeric_version, reverse=True)


def resolve_to_short_key(
    dataset_ui_label: str, skeleton: str, version: str, family: Optional[str] = None
) -> Optional[str]:
    """Return the short key for (dataset, skeleton, version), or None.

    Args:
        dataset_ui_label: Dataset UI label ("Rigplay" or "SEED").
        skeleton: Skeleton key as stored in the registry (e.g. "SOMA").
        version: Version string including the "v" prefix (e.g. "v1").
        family: Optional family filter (e.g. "Kimodo"). The same (dataset, skeleton,
            version) can be registered for more than one family. Without this filter
            the first match in registry order is returned, which is the original
            behaviour.

    Returns:
        The matching model's short key, or None when nothing matches.
    """
    for info in MODEL_INFOS:
        if family is not None and info.family != family:
            continue
        if info.dataset_ui_label == dataset_ui_label and info.skeleton == skeleton and info.version == version:
            return info.short_key
    return None


# -----------------------------------------------------------------------------
# Flexible model name resolution (partial names, case-insensitive, defaults)
# -----------------------------------------------------------------------------

# Alias tables: lowercase user token -> canonical registry value.
# The _normalize_* helpers strip and lowercase their input before looking it up here,
# so every key must be lowercase.
_FAMILY_ALIASES = {"kimodo": "Kimodo", "tmr": "TMR"}
_DATASET_ALIASES = {"rp": "RP", "rigplay": "RP", "seed": "SEED"}
_SKELETON_ALIASES = {
    "soma": "SOMA",
    "smplx": "SMPLX",
    "g1": "G1",
}


def _normalize_family(s: str) -> Optional[str]:
    """Map a family token to "Kimodo"/"TMR", ignoring case and outer whitespace.

    Returns None for an unrecognized token.
    """
    token = s.strip().lower()
    return _FAMILY_ALIASES.get(token)


def _normalize_dataset(s: str) -> Optional[str]:
    """Map a dataset token to "RP"/"SEED", ignoring case and outer whitespace.

    Returns None for an unrecognized token.
    """
    token = s.strip().lower()
    return _DATASET_ALIASES.get(token)


def _normalize_skeleton(s: str) -> Optional[str]:
    """Map a skeleton token to "SOMA"/"SMPLX"/"G1", ignoring case and outer whitespace.

    Returns None for an unrecognized token.
    """
    token = s.strip().lower()
    return _SKELETON_ALIASES.get(token)


def _get_latest_for_family_skeleton_dataset(family: str, skeleton: str, dataset: str) -> Optional[ModelInfo]:
    """Pick the highest-version model registered for (family, skeleton, dataset).

    Returns None when no registered model matches all three values.
    """
    wanted = (family, skeleton, dataset)

    def _numeric_version(entry: ModelInfo) -> int:
        # "v12" -> 12; anything that is not "v<digits>" ranks lowest.
        label = entry.version
        return int(label[1:]) if label.startswith("v") and label[1:].isdigit() else 0

    matching = (entry for entry in MODEL_INFOS if (entry.family, entry.skeleton, entry.dataset) == wanted)
    return max(matching, key=_numeric_version, default=None)


def kimodo_short_key_for_skeleton_dataset(skeleton: str, dataset: str) -> Optional[str]:
    """Return the short key of the latest Kimodo model for ``skeleton`` and ``dataset``.

    ``dataset`` is the registry value (RP/SEED). Returns None when no Kimodo model is
    registered for the pair.
    """
    latest = _get_latest_for_family_skeleton_dataset("Kimodo", skeleton, dataset)
    if latest is None:
        return None
    return latest.short_key


def registry_skeleton_for_joint_count(nb_joints: int) -> str:
    """Translate a motion's joint count into a registry skeleton key.

    34 joints -> "G1", 22 -> "SMPLX", 77 or 30 -> "SOMA".

    Raises:
        ValueError: If no skeleton is associated with the joint count.
    """
    joint_counts_by_skeleton = (
        ("G1", (34,)),
        ("SMPLX", (22,)),
        ("SOMA", (77, 30)),
    )
    for skeleton_key, joint_counts in joint_counts_by_skeleton:
        if nb_joints in joint_counts:
            return skeleton_key
    raise ValueError(f"No Kimodo model registered for motion with J={nb_joints}")


# Optional version: Family-Skeleton-Dataset-vN or Family-Skeleton-Dataset
# Groups: (family, skeleton token, dataset, version digits or None). Parts may be
# separated by "-" or "_". The pattern is compiled with re.IGNORECASE, so the explicit
# lowercase alternatives are redundant but harmless. The skeleton group accepts any
# alphanumeric token; callers validate it with _normalize_skeleton.
_RESOLVE_FULL_PATTERN = re.compile(
    r"^(Kimodo|TMR|kimodo|tmr)[\-_]" r"([A-Za-z0-9]+)[\-_]" r"(RP|SEED|rp|seed)" r"(?:[\-_]v(\d+))?$",
    re.IGNORECASE,
)
# Partial: Skeleton-Dataset or Skeleton or Dataset (no family)
# Groups: (first token, dataset or None, version digits or None). The first token is
# any alphanumeric run, so this pattern also matches unknown single words; callers must
# check what the token actually is.
_RESOLVE_PARTIAL_PATTERN = re.compile(
    r"^([A-Za-z0-9]+)(?:[\-_](RP|SEED|rp|seed))?(?:[\-_]v(\d+))?$",
    re.IGNORECASE,
)


def _short_key_for_combo(family: str, skeleton: str, dataset: str, ver_num: Optional[str]) -> Optional[str]:
    """Return the short key registered for (family, skeleton, dataset[, version]), or None.

    ver_num is the version's digit string (e.g. "1"); None selects the latest version.
    """
    if ver_num is None:
        latest = _get_latest_for_family_skeleton_dataset(family, skeleton, dataset)
        return latest.short_key if latest is not None else None
    wanted = (family, skeleton, dataset, f"v{ver_num}")
    for info in MODEL_INFOS:
        if (info.family, info.skeleton, info.dataset, info.version) == wanted:
            return info.short_key
    return None


def resolve_model_name(name: Optional[str], default_family: Optional[str] = None) -> str:
    """Resolve a user-facing model name to a short_key.

    Accepts full names (e.g. Kimodo-SOMA-RP-v1), case-insensitive matching,
    and partial names with defaults: dataset=RP, skeleton=SOMA, family from
    default_family (Kimodo for demo/generation, TMR for embed script).
    Omitted version resolves to the latest for that model.

    Resolution order:
        1. Empty name: the latest SOMA-RP model of default_family.
        2. Exact short key.
        3. Case-insensitive short key, display name, or repo ID.
        4. Full form Family-Skeleton-Dataset[-vN].
        5. Partial form (requires default_family): Skeleton[-Dataset][-vN],
           Dataset[-vN], or Family[-Dataset][-vN].

    A partial name whose first token is not a known skeleton, dataset, or family
    is rejected rather than silently mapped to the default model.

    Args:
        name: User-provided name (can be None or empty).
        default_family: "Kimodo" or "TMR" when name is empty or omits family.

    Returns:
        Short key (e.g. kimodo-soma-rp) for use with load_model / MODEL_NAMES.

    Raises:
        ValueError: If name cannot be resolved or default_family is missing when needed.
    """
    if name is not None:
        name = name.strip()
    if not name:
        if default_family is None:
            raise ValueError('Model name is empty; provide a name or set default_family ("Kimodo" or "TMR").')
        fam = _normalize_family(default_family)
        if fam is None:
            raise ValueError(f"default_family must be 'Kimodo' or 'TMR', got {default_family!r}")
        info = _get_latest_for_family_skeleton_dataset(fam, "SOMA", "RP")
        if info is None:
            raise ValueError(f"No model found for {fam}-SOMA-RP. Available: {list(MODEL_NAMES.keys())}")
        return info.short_key

    # Exact short_key
    if name in MODEL_NAMES:
        return name

    # Case-insensitive match against short_key, display_name, or repo ID; the first
    # registry entry that matches wins.
    name_lower = name.lower()
    for info in MODEL_INFOS:
        disp = info.display_name.lower()
        if name_lower in (info.short_key.lower(), disp, "nvidia/" + disp, info.repo_id.lower()):
            return info.short_key

    # Parsed full form: Family-Skeleton-Dataset or Family-Skeleton-Dataset-vN
    m = _RESOLVE_FULL_PATTERN.match(name)
    if m:
        fam_raw, skel_raw, ds_raw, ver_num = m.groups()
        fam = _normalize_family(fam_raw)
        skel = _normalize_skeleton(skel_raw)
        ds = _normalize_dataset(ds_raw)
        if fam is not None and skel is not None and ds is not None:
            key = _short_key_for_combo(fam, skel, ds, ver_num)
            if key is not None:
                return key

    # Parsed partial: the missing parts are filled from default_family and the
    # SOMA / RP defaults, so this step is skipped when default_family is absent or invalid.
    default_fam = _normalize_family(default_family) if default_family is not None else None
    m = _RESOLVE_PARTIAL_PATTERN.match(name) if default_fam is not None else None
    if m:
        token, ds_raw, ver_num = m.groups()
        explicit_ds = _normalize_dataset(ds_raw) if ds_raw else None
        token_skel = _normalize_skeleton(token)
        token_ds = _normalize_dataset(token)
        token_fam = _normalize_family(token)

        combo = None
        if token_skel is not None:
            # Skeleton or Skeleton-Dataset
            combo = (default_fam, token_skel, explicit_ds or "RP")
        elif token_ds is not None:
            # Dataset only; a second, conflicting dataset (e.g. "rp-seed") is rejected.
            if explicit_ds in (None, token_ds):
                combo = (default_fam, "SOMA", token_ds)
        elif token_fam is not None:
            # Family or Family-Dataset: the named family overrides default_family.
            combo = (token_fam, "SOMA", explicit_ds or "RP")
        # Any other token is unknown and falls through to the error below.

        if combo is not None:
            key = _short_key_for_combo(combo[0], combo[1], combo[2], ver_num)
            if key is not None:
                return key

    raise ValueError(
        f"Model name {name!r} could not be resolved. "
        f"Use a short key (e.g. {list(MODEL_NAMES.keys())[:3]}...), "
        "a full name (e.g. Kimodo-SOMA-RP-v1), or a partial (e.g. SOMA-RP, SOMA) "
        "with default_family set."
    )