Spaces:
Running on Zero
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation.""" | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Union | |
| import torch | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from safetensors.torch import load_file as load_safetensors | |
| from .registry import ( | |
| AVAILABLE_MODELS, | |
| DEFAULT_MODEL, | |
| DEFAULT_TEXT_ENCODER_URL, | |
| KIMODO_MODELS, | |
| MODEL_NAMES, | |
| TMR_MODELS, | |
| ) | |
def get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:
    """Look up an environment variable, falling back to ``default``.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is unset or set to an
            empty string.

    Returns:
        The variable's value when it is set and non-empty, otherwise ``default``.
    """
    value = os.environ.get(name)
    # An empty string is treated the same as an unset variable.
    if not value:
        return default
    return value
def instantiate_from_dict(
    cfg: Dict[str, Any],
    overrides: Optional[Dict[str, Any]] = None,
):
    """Build an object from a plain config dict via Hydra.

    Args:
        cfg: Config mapping (e.g. the result of ``OmegaConf.to_container``). It
            must contain a ``_target_`` entry with a fully qualified class path.
            Nested configs are instantiated recursively.
        overrides: Optional top-level keys layered over ``cfg``; on a key clash
            the override wins. The merge is shallow and builds a new dict, so
            ``cfg`` itself is not modified.

    Returns:
        The object produced by ``hydra.utils.instantiate`` for the resulting config.
    """
    # Only build a merged copy when there is something to merge.
    merged = {**cfg, **overrides} if overrides else cfg
    return instantiate(OmegaConf.create(merged))
def load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:
    """Load a model state dict from a checkpoint file.

    Paths ending in ``.safetensors`` are read with safetensors; anything else is
    read with ``torch.load``. If the loaded checkpoint is a dict with a
    ``"state_dict"`` key (e.g. PyTorch Lightning), that entry is used; otherwise
    the whole dict is treated as the state dict.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        A new dict suitable for ``model.load_state_dict()``. Tensor values are
        detached and moved to CPU; non-tensor values (for example a module's
        ``_extra_state`` entry) are passed through unchanged.

    Raises:
        ValueError: If a non-safetensors checkpoint does not deserialize to a dict.
    """
    ckpt_path = str(ckpt_path)
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_safetensors(ckpt_path)
    else:
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code embedded in the file. Only load checkpoints from
        # trusted sources. Left as-is because checkpoints that store non-tensor
        # objects may not load under weights_only=True.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if not isinstance(checkpoint, dict):
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
        # Use the nested "state_dict" when present, else the dict itself.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # State dicts may carry non-tensor entries; calling .detach() on those would
    # raise AttributeError, so only tensors are detached and moved.
    return {
        key: val.detach().cpu() if isinstance(val, torch.Tensor) else val
        for key, val in state_dict.items()
    }
# Public API of this module: the local helpers plus the names imported from
# .registry so callers can import them from here.
__all__ = [
    # Environment / instantiation helpers defined in this module.
    "get_env_var",
    "instantiate_from_dict",
    # Names imported from .registry.
    "KIMODO_MODELS",
    "TMR_MODELS",
    "AVAILABLE_MODELS",
    "MODEL_NAMES",
    "DEFAULT_MODEL",
    "DEFAULT_TEXT_ENCODER_URL",
    # Checkpoint loading helper defined in this module.
    "load_checkpoint_state_dict",
]