# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation."""

import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors

from .registry import (
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    DEFAULT_TEXT_ENCODER_URL,
    KIMODO_MODELS,
    MODEL_NAMES,
    TMR_MODELS,
)


def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:
    """Return environment variable value, or default if unset/empty."""
    return os.environ.get(name) or default
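
# Usage sketch (the variable name below is illustrative, not one this module
# requires). Note that an empty string counts as unset:
#
#   os.environ["MOTION_CKPT_DIR"] = ""
#   get_env_var("MOTION_CKPT_DIR", "/tmp/ckpts")  # -> "/tmp/ckpts"
#   get_env_var("MOTION_CKPT_DIR")                # -> None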


def instantiate_from_dict(
    cfg: Dict[str, Any],
    overrides: Optional[Dict[str, Any]] = None,
):
    """Instantiate an object from a config dict (e.g. from OmegaConf.to_container).

    The dict must contain _target_ with a fully qualified class path. Nested
    configs are instantiated recursively.
    """
    if overrides:
        cfg = {**cfg, **overrides}
    conf = OmegaConf.create(cfg)
    return instantiate(conf)
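
# Usage sketch (my_pkg.models.MotionVAE and its latent_dim argument are
# hypothetical placeholders, not part of this codebase):
#
#   cfg = {"_target_": "my_pkg.models.MotionVAE", "latent_dim": 256}
#   model = instantiate_from_dict(cfg, overrides={"latent_dim": 512})
#   # Hydra resolves _target_ and calls MotionVAE(latent_dim=512).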


def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:
    """Load a state dict from a checkpoint file.

    If the checkpoint is a dict with a 'state_dict' key (e.g. PyTorch Lightning),
    that entry is returned; otherwise the whole checkpoint is treated as the
    state dict.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        A state dict suitable for model.load_state_dict().
    """
    ckpt_path = str(ckpt_path)
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_safetensors(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        elif isinstance(checkpoint, dict):
            state_dict = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
    # Detach and move every tensor to CPU so the result is device-agnostic.
    return {key: val.detach().cpu() for key, val in state_dict.items()}
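
# Usage sketch (the paths and `model` below are placeholders):
#
#   state_dict = load_checkpoint_state_dict("weights/model.safetensors")
#   model.load_state_dict(state_dict)
#
#   # Lightning-style .ckpt checkpoints are unwrapped automatically:
#   state_dict = load_checkpoint_state_dict(Path("lightning_logs/last.ckpt"))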


__all__ = [
    "get_env_var",
    "instantiate_from_dict",
    "KIMODO_MODELS",
    "TMR_MODELS",
    "AVAILABLE_MODELS",
    "MODEL_NAMES",
    "DEFAULT_MODEL",
    "DEFAULT_TEXT_ENCODER_URL",
    "load_checkpoint_state_dict",
]