# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""LLM2Vec encoder wrapper for Kimodo text conditioning."""

import os

import numpy as np
import torch

from .llm2vec import LLM2Vec


class LLM2VecEncoder:
    """Frozen LLM2Vec text-embedding model.

    Wraps an LLM2Vec model (base checkpoint + PEFT adapter), freezes it,
    and exposes a callable that maps text to pooled embedding tensors of
    width ``llm_dim``.
    """

    def __init__(
        self,
        base_model_name_or_path: str,
        peft_model_name_or_path: str,
        dtype: str,
        llm_dim: int,
    ) -> None:
        """Load the LLM2Vec model and freeze it for inference.

        Args:
            base_model_name_or_path: Base model identifier or local path.
            peft_model_name_or_path: PEFT adapter identifier or local path.
            dtype: Name of a ``torch`` dtype attribute, e.g. ``"bfloat16"``.
            llm_dim: Expected embedding dimensionality of the model output.
        """
        torch_dtype = getattr(torch, dtype)
        self.llm_dim = llm_dim
        cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR")
        # Accept any of the common Hugging Face token env-var spellings.
        hf_token = (
            os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGING_FACE_HUB_TOKEN")
            or os.environ.get("HF_HUB_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        # Optionally resolve both model paths against a shared local
        # text-encoder directory (offline / pre-downloaded checkpoints).
        if "TEXT_ENCODERS_DIR" in os.environ:
            base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
            peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)
        self.model = LLM2Vec.from_pretrained(
            base_model_name_or_path=base_model_name_or_path,
            peft_model_name_or_path=peft_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            token=hf_token,
        )
        # Inference-only encoder: eval mode and no gradients anywhere.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def to(self, device: torch.device) -> "LLM2VecEncoder":
        """Move the underlying model to ``device``; returns ``self`` for chaining."""
        self.model = self.model.to(device)
        return self

    def eval(self) -> "LLM2VecEncoder":
        """Put the underlying model in eval mode; returns ``self`` for chaining."""
        self.model.eval()
        return self

    def get_device(self) -> torch.device:
        """Return the device the wrapped model currently lives on."""
        return self.model.model.device

    def __call__(self, text: list[str] | str):
        """Embed ``text`` with the frozen LLM2Vec model.

        Args:
            text: A single string or a list of strings.

        Returns:
            Tuple ``(encoded_text, lengths)``. For a list input,
            ``encoded_text`` has shape ``(num_texts, 1, llm_dim)`` and
            ``lengths`` is a list of 1s (one pooled vector per text).
            For a single-string input the batch axis is stripped:
            ``encoded_text`` has shape ``(1, llm_dim)`` and ``lengths``
            is the scalar ``1``.
        """
        is_string = False
        if isinstance(text, str):
            text = [text]
            is_string = True
        with torch.no_grad():
            encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False)
        # encode() yields one pooled vector per input: expect exactly
        # (num_texts, llm_dim). The bare `assert len(encoded_text.shape)`
        # was vacuously true for any non-scalar output; check rank 2.
        assert len(encoded_text.shape) == 2
        assert self.llm_dim == encoded_text.shape[-1]
        # Insert a singleton "sequence" axis: (num_texts, 1, llm_dim).
        encoded_text = encoded_text[:, None]
        # Each embedding is a single pooled vector, so every length is 1.
        lengths = [1] * len(encoded_text)
        if is_string:
            # Unbatch for single-string callers.
            encoded_text = encoded_text[0]
            lengths = lengths[0]
        encoded_text = torch.as_tensor(encoded_text, device=self.get_device())
        return encoded_text, lengths