# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""LLM2Vec encoder wrapper for Kimodo text conditioning."""
import os
import numpy as np
import torch
from .llm2vec import LLM2Vec
class LLM2VecEncoder:
"""LLM2Vec text embeddings."""
def __init__(
self,
base_model_name_or_path: str,
peft_model_name_or_path: str,
dtype: str,
llm_dim: int,
) -> None:
torch_dtype = getattr(torch, dtype)
self.llm_dim = llm_dim
cache_dir = os.environ.get("HUGGINGFACE_CACHE_DIR")
hf_token = (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HF_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
if "TEXT_ENCODERS_DIR" in os.environ:
base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)
self.model = LLM2Vec.from_pretrained(
base_model_name_or_path=base_model_name_or_path,
peft_model_name_or_path=peft_model_name_or_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
token=hf_token,
)
self.model.eval()
for p in self.model.parameters():
p.requires_grad = False
def to(self, device: torch.device):
self.model = self.model.to(device)
return self
def eval(self):
self.model.eval()
return self
def get_device(self):
return self.model.model.device
def __call__(self, text: list[str] | str):
is_string = False
if isinstance(text, str):
text = [text]
is_string = True
with torch.no_grad():
encoded_text = self.model.encode(text, batch_size=len(text), show_progress_bar=False)
assert len(encoded_text.shape)
assert self.llm_dim == encoded_text.shape[-1]
encoded_text = encoded_text[:, None]
lengths = np.ones(len(encoded_text), dtype=int).tolist()
if is_string:
encoded_text = encoded_text[0]
lengths = lengths[0]
encoded_text = torch.as_tensor(encoded_text, device=self.get_device())
return encoded_text, lengths