Spaces:
Running on Zero
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """LLM2Vec encoder wrapper for Kimodo text conditioning.""" | |
| import os | |
| import numpy as np | |
| import torch | |
| from .llm2vec import LLM2Vec | |
| class LLM2VecEncoder: | |
| """LLM2Vec text embeddings.""" | |
| def __init__( | |
| self, | |
| base_model_name_or_path: str, | |
| peft_model_name_or_path: str, | |
| dtype: str, | |
| llm_dim: int, | |
| ) -> None: | |
| torch_dtype = getattr(torch, dtype) | |
| self.llm_dim = llm_dim | |
| cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR") | |
| hf_token = ( | |
| os.environ.get("HF_TOKEN") | |
| or os.environ.get("HUGGING_FACE_HUB_TOKEN") | |
| or os.environ.get("HF_HUB_TOKEN") | |
| or os.environ.get("HUGGINGFACEHUB_API_TOKEN") | |
| ) | |
| if "TEXT_ENCODERS_DIR" in os.environ: | |
| base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path) | |
| peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path) | |
| self.model = LLM2Vec.from_pretrained( | |
| base_model_name_or_path=base_model_name_or_path, | |
| peft_model_name_or_path=peft_model_name_or_path, | |
| torch_dtype=torch_dtype, | |
| cache_dir=cache_dir, | |
| token=hf_token, | |
| ) | |
| self.model.eval() | |
| for p in self.model.parameters(): | |
| p.requires_grad = False | |
| def to(self, device: torch.device): | |
| self.model = self.model.to(device) | |
| return self | |
| def eval(self): | |
| self.model.eval() | |
| return self | |
| def get_device(self): | |
| return self.model.model.device | |
| def __call__(self, text: list[str] | str): | |
| is_string = False | |
| if isinstance(text, str): | |
| text = [text] | |
| is_string = True | |
| with torch.no_grad(): | |
| encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False) | |
| assert len(encoded_text.shape) | |
| assert self.llm_dim == encoded_text.shape[-1] | |
| encoded_text = encoded_text[:, None] | |
| lengths = np.ones(len(encoded_text), dtype=int).tolist() | |
| if is_string: | |
| encoded_text = encoded_text[0] | |
| lengths = lengths[0] | |
| encoded_text = torch.as_tensor(encoded_text, device=self.get_device()) | |
| return encoded_text, lengths | |