Buckets:

rydlrKE's picture
download
raw
7.52 kB
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Load Kimodo diffusion models from local checkpoints or Hugging Face."""
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from .loading import (
AVAILABLE_MODELS,
DEFAULT_MODEL,
DEFAULT_TEXT_ENCODER_URL,
MODEL_NAMES,
TMR_MODELS,
get_env_var,
instantiate_from_dict,
)
from .registry import get_model_info, resolve_model_name
# Preset used when the TEXT_ENCODER environment variable is unset.
DEFAULT_TEXT_ENCODER = "llm2vec"
# Local text-encoder presets, keyed by the value of the TEXT_ENCODER env var.
# Each preset carries the import path ("target") and the constructor kwargs
# that _build_local_text_encoder_conf flattens into a single Hydra-style
# config dict (with "_target_").
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "target": "kimodo.model.LLM2VecEncoder",
        "kwargs": {
            "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
            "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
            "dtype": "bfloat16",
            "llm_dim": 4096,
        },
    }
}
def _resolve_hf_model_path(modelname: str) -> Path:
    """Resolve a model name to a local snapshot path via the Hugging Face cache.

    When the LOCAL_CACHE env var is "true" (case-insensitive), the local HF
    cache is consulted first and the network is used only on a cache miss;
    otherwise ``snapshot_download`` always revalidates online.

    Args:
        modelname: Short model key; must be present in MODEL_NAMES.

    Returns:
        Path to the resolved snapshot directory.

    Raises:
        ValueError: If ``modelname`` is not a known key in MODEL_NAMES.
        RuntimeError: If the snapshot cannot be resolved locally or online.
    """
    try:
        repo_id = MODEL_NAMES[modelname]
    except KeyError:
        # Render a readable, sorted list instead of the raw dict_keys repr,
        # and suppress the uninformative KeyError chain.
        available = ", ".join(sorted(MODEL_NAMES))
        raise ValueError(
            f"Model '{modelname}' not found. Available models: {available}"
        ) from None
    local_cache = get_env_var("LOCAL_CACHE", "False").lower() == "true"
    if not local_cache:
        # Always hit the Hub (the cache is reused, but availability is
        # checked online no matter what).
        return Path(snapshot_download(repo_id=repo_id))
    try:
        # Cache-only lookup: no network access.
        return Path(snapshot_download(repo_id=repo_id, local_files_only=True))
    except Exception:
        # Cache miss: fall back to an online download.
        try:
            return Path(snapshot_download(repo_id=repo_id))
        except Exception as error:
            # Chain the underlying huggingface_hub error so the real cause
            # (auth, network, missing repo) stays visible in the traceback.
            raise RuntimeError(
                f"Could not resolve model '{modelname}' from Hugging Face (repo: {repo_id})."
            ) from error
def _build_api_text_encoder_conf(text_encoder_url: str) -> dict:
return {
"_target_": "kimodo.model.text_encoder_api.TextEncoderAPI",
"url": text_encoder_url,
}
def _build_local_text_encoder_conf() -> dict:
    """Build a Hydra-style config dict for a local text-encoder preset.

    The preset is selected via the TEXT_ENCODER env var (default: llm2vec).

    Raises:
        ValueError: If the env var names a preset missing from
            TEXT_ENCODER_PRESETS.
    """
    text_encoder_name = get_env_var("TEXT_ENCODER", DEFAULT_TEXT_ENCODER)
    preset = TEXT_ENCODER_PRESETS.get(text_encoder_name)
    if preset is None:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")
    # Flatten the preset into a single instantiation dict: target first,
    # then the constructor kwargs.
    conf = {"_target_": preset["target"]}
    conf.update(preset["kwargs"])
    return conf
def _select_text_encoder_conf(text_encoder_url: str) -> dict:
    """Pick the text-encoder config based on the TEXT_ENCODER_MODE env var.

    Modes:
        "api":   always use the remote TextEncoderAPI.
        "local": always use the local LLM2Vec encoder.
        "auto":  probe the API once; fall back to local if unreachable.
    """
    mode = get_env_var("TEXT_ENCODER_MODE", "auto").lower()
    if mode == "local":
        return _build_local_text_encoder_conf()
    api_conf = _build_api_text_encoder_conf(text_encoder_url)
    if mode == "api":
        return api_conf
    # "auto": instantiate the API client and issue one tiny request now, so
    # a dead service is detected here instead of failing mid-inference.
    try:
        encoder = instantiate_from_dict(api_conf)
        encoder(["healthcheck"])
    except Exception as error:
        print(
            "Text encoder service is unreachable, falling back to local LLM2Vec "
            f"encoder. ({type(error).__name__}: {error})"
        )
        return _build_local_text_encoder_conf()
    return api_conf
def load_model(
    modelname=None,
    device=None,
    eval_mode: bool = True,
    default_family: Optional[str] = "Kimodo",
    return_resolved_name: bool = False,
):
    """Load a kimodo model by name (e.g. 'g1', 'soma').

    Resolution of partial/full names (e.g. Kimodo-SOMA-RP-v1, SOMA) is done
    inside this function using default_family when the name is not a known
    short key.

    Args:
        modelname: Model identifier; uses DEFAULT_MODEL if None. Can be a short key,
            a full name (e.g. Kimodo-SOMA-RP-v1), or a partial name; unknown names
            are resolved via resolve_model_name using default_family.
        device: Target device for the model (e.g. 'cuda', 'cpu').
        eval_mode: If True, set model to eval mode.
        default_family: Used when modelname is not in AVAILABLE_MODELS to resolve
            partial names ("Kimodo" for demo/generation, "TMR" for embed script).
            Default "Kimodo".
        return_resolved_name: If True, return (model, resolved_short_key). If False,
            return only the model.

    Returns:
        Loaded model in eval mode, or (model, resolved short key) if
        return_resolved_name is True.

    Raises:
        ValueError: If modelname is not in AVAILABLE_MODELS and cannot be resolved.
        FileNotFoundError: If config.yaml is missing in the checkpoint folder.
    """
    if modelname is None:
        modelname = DEFAULT_MODEL
    if modelname not in AVAILABLE_MODELS:
        if default_family is not None:
            # Not a known short key: try to resolve a partial/full display
            # name (e.g. "SOMA", "Kimodo-SOMA-RP-v1") within the family.
            modelname = resolve_model_name(modelname, default_family)
        else:
            raise ValueError(
                f"""The model is not recognized.
                Please choose between: {AVAILABLE_MODELS}"""
            )
    # Keep the resolved short key so it can be returned to the caller.
    resolved_modelname = modelname
    # In case, we specify a custom checkpoint directory
    configured_checkpoint_dir = get_env_var("CHECKPOINT_DIR")
    if configured_checkpoint_dir:
        print(f"CHECKPOINT_DIR is set to {configured_checkpoint_dir}, checking the local cache...")
        # Checkpoint folders are named by display name (e.g. Kimodo-SOMA-RP-v1)
        info = get_model_info(modelname)
        checkpoint_folder_name = info.display_name if info is not None else modelname
        model_path = Path(configured_checkpoint_dir) / checkpoint_folder_name
        if not model_path.exists() and modelname != checkpoint_folder_name:
            # Fallback: try short_key for backward compatibility
            model_path = Path(configured_checkpoint_dir) / modelname
        if not model_path.exists():
            # Neither folder layout exists locally: fall back to the Hub.
            print(f"Model folder not found at '{model_path}', downloading it from Hugging Face...")
            model_path = _resolve_hf_model_path(modelname)
    else:
        # Otherwise, we load the model from the local cache or download it from Hugging Face.
        model_path = _resolve_hf_model_path(modelname)
    model_config_path = model_path / "config.yaml"
    if not model_config_path.exists():
        raise FileNotFoundError(f"The model checkpoint folder exists but config.yaml is missing: {model_config_path}")
    model_conf = OmegaConf.load(model_config_path)
    if modelname in TMR_MODELS:
        # Same process at the moment for TMR and Kimodo
        pass
    text_encoder_url = get_env_var("TEXT_ENCODER_URL", DEFAULT_TEXT_ENCODER_URL)
    # Runtime overrides layered on top of the checkpoint's config.yaml:
    # the selected text encoder (API vs local) plus checkpoint_dir —
    # presumably consumed by ${checkpoint_dir} interpolations in config.yaml,
    # since it is resolved during the merge and then dropped (TODO confirm
    # against the checkpoint configs).
    runtime_conf = OmegaConf.create(
        {
            "checkpoint_dir": str(model_path),
            "text_encoder": _select_text_encoder_conf(text_encoder_url),
        }
    )
    model_cfg = OmegaConf.to_container(OmegaConf.merge(model_conf, runtime_conf), resolve=True)
    # checkpoint_dir is not a constructor argument; remove it before
    # instantiation now that interpolations have been resolved.
    model_cfg.pop("checkpoint_dir", None)
    model = instantiate_from_dict(model_cfg, overrides={"device": device})
    if eval_mode:
        model = model.eval()
    if return_resolved_name:
        return model, resolved_modelname
    return model

Xet Storage Details

Size:
7.52 kB
·
Xet hash:
ef307dce73f800c0612bffe18b8e1779d6d392b34981ffde7c172dc7afd746f4

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.