# movimento/kimodo/model/llm2vec/llm2vec_wrapper.py
# Upstream commit 1e9ec46 ("ZeroGPU-safe startup device + llm2vec tensor warning cleanup")
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""LLM2Vec encoder wrapper for Kimodo text conditioning."""
import os
import numpy as np
import torch
from .llm2vec import LLM2Vec
class LLM2VecEncoder:
"""LLM2Vec text embeddings."""
def __init__(
self,
base_model_name_or_path: str,
peft_model_name_or_path: str,
dtype: str,
llm_dim: int,
) -> None:
torch_dtype = getattr(torch, dtype)
self.llm_dim = llm_dim
cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR")
hf_token = (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HF_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
if "TEXT_ENCODERS_DIR" in os.environ:
base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)
self.model = LLM2Vec.from_pretrained(
base_model_name_or_path=base_model_name_or_path,
peft_model_name_or_path=peft_model_name_or_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
token=hf_token,
)
self.model.eval()
for p in self.model.parameters():
p.requires_grad = False
def to(self, device: torch.device):
self.model = self.model.to(device)
return self
def eval(self):
self.model.eval()
return self
def get_device(self):
return self.model.model.device
def __call__(self, text: list[str] | str):
is_string = False
if isinstance(text, str):
text = [text]
is_string = True
with torch.no_grad():
encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False)
assert len(encoded_text.shape)
assert self.llm_dim == encoded_text.shape[-1]
encoded_text = encoded_text[:, None]
lengths = np.ones(len(encoded_text), dtype=int).tolist()
if is_string:
encoded_text = encoded_text[0]
lengths = lengths[0]
encoded_text = torch.as_tensor(encoded_text, device=self.get_device())
return encoded_text, lengths