# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import gradio as gr
import numpy as np
from huggingface_hub import HfApi
from kimodo.model import resolve_target
from .gradio_theme import get_gradio_theme
# Opt in to the Hugging Face hub's parallel weight loading to speed up model
# startup (read by the hub's loading machinery).
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
# UI and server defaults; most can be overridden via env vars or CLI flags.
DEFAULT_TEXT = "A person walks and falls to the ground."
DEFAULT_SERVER_NAME = "0.0.0.0"
DEFAULT_SERVER_PORT = 9550
DEFAULT_TMP_FOLDER = "/tmp/text_encoder/"
DEFAULT_TEXT_ENCODER = "llm2vec"
# Registry of supported text encoders. Each preset holds the dotted path of
# the encoder class ("target"), the keyword arguments used to construct it,
# and a human-readable name shown in the UI.
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "target": "kimodo.model.LLM2VecEncoder",
        "kwargs": {
            "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
            "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
            "dtype": "bfloat16",
            "llm_dim": 4096,
        },
        "display_name": "LLM2Vec",
    }
}
def _get_hf_token() -> str | None:
return (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HF_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
def _validate_text_encoder_startup(text_encoder_name: str) -> None:
    """Fail fast before launching Gradio if the text encoder cannot be resolved.

    Raises:
        ValueError: ``text_encoder_name`` is not a known preset.
        RuntimeError: local model paths are missing (when TEXT_ENCODERS_DIR is
            set), no HF token is configured, or a Hub repository is inaccessible.
    """
    if text_encoder_name not in TEXT_ENCODER_PRESETS:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")
    kwargs = TEXT_ENCODER_PRESETS[text_encoder_name]["kwargs"]
    local_root = os.environ.get("TEXT_ENCODERS_DIR")
    if local_root:
        # Local mode: both model directories must already exist on disk;
        # no Hub access (and hence no token) is required.
        missing = [
            candidate
            for candidate in (
                os.path.join(local_root, kwargs["base_model_name_or_path"]),
                os.path.join(local_root, kwargs["peft_model_name_or_path"]),
            )
            if not os.path.exists(candidate)
        ]
        if missing:
            raise RuntimeError(
                "TEXT_ENCODERS_DIR is set, but the following local model paths are missing: "
                + ", ".join(missing)
            )
        return
    token = _get_hf_token()
    if not token:
        raise RuntimeError(
            "HF token is missing. Set one of HF_TOKEN, HUGGING_FACE_HUB_TOKEN, HF_HUB_TOKEN, or "
            "HUGGINGFACEHUB_API_TOKEN before starting the text encoder server."
        )
    # Hub mode: verify both repos are reachable with the configured token.
    api = HfApi()
    for repo_id, label in (
        (kwargs["base_model_name_or_path"], "base model"),
        (kwargs["peft_model_name_or_path"], "PEFT adapter"),
    ):
        try:
            api.model_info(repo_id=repo_id, token=token)
        except Exception as error:
            raise RuntimeError(f"Failed to access {label} '{repo_id}' with the configured HF token: {error}") from error
class DemoWrapper:
    """Gradio callable that lazily builds a text encoder and serves embeddings.

    The encoder is constructed on the first request rather than at startup so
    the UI can come up even if initialization would fail; the first failure is
    cached and re-raised on every subsequent call instead of retrying.
    """

    def __init__(self, text_encoder_name, tmp_folder):
        self.text_encoder_name = text_encoder_name
        self.tmp_folder = tmp_folder
        # Built lazily on the first request.
        self.text_encoder = None
        # First initialization failure, cached so we fail fast afterwards.
        self.init_error = None

    def _get_text_encoder(self):
        """Return the encoder, building it on first use; re-raise cached errors."""
        if self.text_encoder is not None:
            return self.text_encoder
        if self.init_error is not None:
            # A previous attempt already failed: surface the same error.
            raise RuntimeError(self.init_error)
        try:
            self.text_encoder = _build_text_encoder(self.text_encoder_name)
        except Exception as error:
            self.init_error = error
            raise
        return self.text_encoder

    def __call__(self, text, filename, progress=gr.Progress()):
        """Encode `text`, save the embedding to tmp_folder/filename, and return
        (download button, output title, output text) Gradio component updates."""
        try:
            encoder = self._get_text_encoder()
        except Exception as error:
            message = (
                "Text encoder could not initialize. "
                "If you use gated Hugging Face models, configure a valid HF token in the runtime env.\n\n"
                f"Error: `{type(error).__name__}: {error}`"
            )
            return (
                gr.DownloadButton(visible=False),
                gr.Markdown(visible=True, value="## Encoder initialization failed"),
                gr.Markdown(visible=True, value=message),
            )
        # Encode: the encoder yields a tensor plus its valid length — slice off
        # any padding before persisting the embedding as a NumPy file.
        raw_tensor, valid_length = encoder(text)
        embedding = raw_tensor[:valid_length].cpu().numpy()
        save_path = os.path.join(self.tmp_folder, filename)
        np.save(save_path, embedding)
        return (
            gr.DownloadButton(visible=True, value=save_path),
            gr.Markdown(visible=True),
            gr.Markdown(visible=True, value=f"Text: {text}"),
        )
def _get_env(name: str, default):
return os.getenv(name, default)
def _build_text_encoder(name: str):
    """Instantiate the text encoder class configured for preset `name`.

    Raises:
        ValueError: `name` is not a key of TEXT_ENCODER_PRESETS.
    """
    if name not in TEXT_ENCODER_PRESETS:
        known = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{name}'. Available: {known}")
    config = TEXT_ENCODER_PRESETS[name]
    encoder_cls = resolve_target(config["target"])
    return encoder_cls(**config["kwargs"])
def parse_args():
    """Parse CLI arguments, with environment variables as dynamic defaults."""
    parser = argparse.ArgumentParser(description="Run text encoder Gradio server.")
    parser.add_argument(
        "--text-encoder",
        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
        default=_get_env("TEXT_ENCODER", DEFAULT_TEXT_ENCODER),
        help="Text encoder preset.",
    )
    parser.add_argument(
        "--tmp-folder",
        default=_get_env("TEXT_ENCODER_TMP_FOLDER", DEFAULT_TMP_FOLDER),
    )
    return parser.parse_args()
def main():
    """Configure and launch the Gradio text-encoder server.

    Host and port come from GRADIO_SERVER_NAME and GRADIO_SERVER_PORT (or
    PORT), falling back to module defaults. Startup validation of the encoder
    preset can be skipped with TEXT_ENCODER_VALIDATE_STARTUP=0; the encoder
    itself is built lazily inside DemoWrapper on the first request.
    """
    args = parse_args()
    server_name = _get_env("GRADIO_SERVER_NAME", DEFAULT_SERVER_NAME)
    # Prefer GRADIO_SERVER_PORT, then PORT (common on hosted platforms).
    server_port = int(os.environ.get("GRADIO_SERVER_PORT") or os.environ.get("PORT", str(DEFAULT_SERVER_PORT)))
    theme, css = get_gradio_theme()
    # Avoid Spaces hot-reload watcher importing `spaces` after CUDA init.
    os.environ.setdefault("GRADIO_HOT_RELOAD", "false")
    os.makedirs(args.tmp_folder, exist_ok=True)
    display_name = TEXT_ENCODER_PRESETS[args.text_encoder]["display_name"]
    if _get_env("TEXT_ENCODER_VALIDATE_STARTUP", "1") != "0":
        _validate_text_encoder_startup(args.text_encoder)
    # Suppress model loading during DemoWrapper initialization to allow graceful degradation
    # Model will be loaded lazily on first request
    demo_wrapper_fn = DemoWrapper(args.text_encoder, args.tmp_folder)
    with gr.Blocks(title="Text encoder") as demo:
        gr.Markdown(f"# Text encoder: {display_name}")
        gr.Markdown("## Description")
        # Fix: user-facing typo ("Get a embeddings" -> "Get embeddings").
        gr.Markdown("Get embeddings from a text.")
        gr.Markdown("## Inputs")
        with gr.Row():
            text = gr.Textbox(
                placeholder="Type the motion you want to generate with a sentence",
                show_label=True,
                label="Text prompt",
                value=DEFAULT_TEXT,
                type="text",
            )
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                btn = gr.Button("Encode", variant="primary")
            with gr.Column(scale=1):
                clear = gr.Button("Clear", variant="secondary")
            with gr.Column(scale=3):
                pass
        # Output widgets start hidden; they are revealed once encoding runs.
        output_title = gr.Markdown("## Outputs", visible=False)
        output_text = gr.Markdown("", visible=False)
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                download = gr.DownloadButton("Download", variant="primary", visible=False)
            with gr.Column(scale=4):
                pass
        # Hidden textbox carrying the fixed output filename for the encoder fn.
        filename = gr.Textbox(
            visible=False,
            value="embedding.npy",
        )

        def clear_fn():
            # Hide the three output widgets again.
            return [
                gr.DownloadButton(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=False),
            ]

        outputs = [download, output_title, output_text]
        # On submit/click: first reset outputs, then run the encoder.
        gr.on(
            triggers=[text.submit, btn.click],
            fn=clear_fn,
            inputs=None,
            outputs=outputs,
        ).then(
            fn=demo_wrapper_fn,
            inputs=[text, filename],
            outputs=outputs,
        )

        def download_file():
            # No-op update: clicking keeps the DownloadButton as-is.
            return gr.DownloadButton()

        download.click(
            fn=download_file,
            inputs=None,
            outputs=[download],
        )
        clear.click(fn=clear_fn, inputs=None, outputs=outputs)
    demo.launch(server_name=server_name, server_port=server_port, css=css, theme=theme)
# Script entry point: start the Gradio text-encoder server.
if __name__ == "__main__":
    main()