# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation."""

import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors

from .registry import (
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    DEFAULT_TEXT_ENCODER_URL,
    KIMODO_MODELS,
    MODEL_NAMES,
    TMR_MODELS,
)


def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:
    """Fetch *name* from the process environment.

    An unset variable and a variable set to the empty string are both
    treated as missing; in either case ``default`` is returned.
    """
    value = os.environ.get(name)
    if value:
        return value
    return default


def instantiate_from_dict(
    cfg: Dict[str, Any],
    overrides: Optional[Dict[str, Any]] = None,
):
    """Build an object from a plain config dict via Hydra.

    ``cfg`` must carry a ``_target_`` key holding the fully qualified class
    path (e.g. the output of ``OmegaConf.to_container``).  Entries in
    ``overrides``, when given, take precedence over ``cfg``.  Nested configs
    with their own ``_target_`` are instantiated recursively by Hydra.
    """
    merged = dict(cfg)
    if overrides:
        merged.update(overrides)
    # OmegaConf normalizes the dict so hydra.utils.instantiate can resolve it.
    return instantiate(OmegaConf.create(merged))


def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:
    """Load a state dict from a checkpoint file.

    ``.safetensors`` files are read with safetensors; everything else goes
    through ``torch.load``.  For torch checkpoints, a dict containing a
    'state_dict' key (e.g. PyTorch Lightning) yields that entry, a plain
    dict is used as-is, and anything else is rejected.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        state_dict with every tensor detached and moved to CPU, suitable
        for model.load_state_dict().

    Raises:
        ValueError: If a torch checkpoint is not a dict.
    """
    path_str = str(ckpt_path)

    if path_str.endswith(".safetensors"):
        state_dict = load_safetensors(path_str)
    else:
        # NOTE(review): weights_only=False permits arbitrary pickle code
        # execution on load — only open checkpoints from trusted sources.
        checkpoint = torch.load(path_str, map_location="cpu", weights_only=False)
        if not isinstance(checkpoint, dict):
            raise ValueError(f"Unsupported checkpoint format: {path_str}")
        # Lightning-style checkpoints nest the weights under 'state_dict'.
        state_dict = checkpoint.get("state_dict", checkpoint)

    return {name: tensor.detach().cpu() for name, tensor in state_dict.items()}


# Public API of this module.  The registry names (KIMODO_MODELS, TMR_MODELS,
# AVAILABLE_MODELS, MODEL_NAMES, DEFAULT_MODEL, DEFAULT_TEXT_ENCODER_URL) are
# imported above from .registry and re-exported here so downstream code can
# import them from this module directly.
__all__ = [
    "get_env_var",
    "instantiate_from_dict",
    "KIMODO_MODELS",
    "TMR_MODELS",
    "AVAILABLE_MODELS",
    "MODEL_NAMES",
    "DEFAULT_MODEL",
    "DEFAULT_TEXT_ENCODER_URL",
    "load_checkpoint_state_dict",
]