Spaces:
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import argparse | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import HfApi | |
| from kimodo.model import resolve_target | |
| from .gradio_theme import get_gradio_theme | |
# Opt in to huggingface_hub's parallel weight loading to speed up model startup.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"

# Default prompt pre-filled in the Gradio textbox.
DEFAULT_TEXT = "A person walks and falls to the ground."
# Default bind address and port for the Gradio server (overridable via env vars in main()).
DEFAULT_SERVER_NAME = "0.0.0.0"
DEFAULT_SERVER_PORT = 9550
# Folder where computed embeddings are written before download.
DEFAULT_TMP_FOLDER = "/tmp/text_encoder/"
# Preset key used when no --text-encoder / TEXT_ENCODER override is given.
DEFAULT_TEXT_ENCODER = "llm2vec"

# Registry of supported text encoders: dotted import target plus constructor kwargs
# and a human-readable display name for the UI header.
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "target": "kimodo.model.LLM2VecEncoder",
        "kwargs": {
            "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
            "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
            "dtype": "bfloat16",
            "llm_dim": 4096,
        },
        "display_name": "LLM2Vec",
    }
}
| def _get_hf_token() -> str | None: | |
| return ( | |
| os.environ.get("HF_TOKEN") | |
| or os.environ.get("HUGGING_FACE_HUB_TOKEN") | |
| or os.environ.get("HF_HUB_TOKEN") | |
| or os.environ.get("HUGGINGFACEHUB_API_TOKEN") | |
| ) | |
def _validate_text_encoder_startup(text_encoder_name: str) -> None:
    """Fail fast before launching Gradio if the text encoder cannot be resolved."""
    if text_encoder_name not in TEXT_ENCODER_PRESETS:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")

    kwargs = TEXT_ENCODER_PRESETS[text_encoder_name]["kwargs"]
    base_repo = kwargs["base_model_name_or_path"]
    peft_repo = kwargs["peft_model_name_or_path"]

    # Local checkout mode: only verify that both model directories exist on disk.
    text_encoders_dir = os.environ.get("TEXT_ENCODERS_DIR")
    if text_encoders_dir:
        candidates = (
            os.path.join(text_encoders_dir, base_repo),
            os.path.join(text_encoders_dir, peft_repo),
        )
        missing = [path for path in candidates if not os.path.exists(path)]
        if missing:
            raise RuntimeError(
                "TEXT_ENCODERS_DIR is set, but the following local model paths are missing: "
                + ", ".join(missing)
            )
        return

    # Hub mode: a token is required and both repos must be reachable with it.
    token = _get_hf_token()
    if not token:
        raise RuntimeError(
            "HF token is missing. Set one of HF_TOKEN, HUGGING_FACE_HUB_TOKEN, HF_HUB_TOKEN, or "
            "HUGGINGFACEHUB_API_TOKEN before starting the text encoder server."
        )
    api = HfApi()
    for label, repo_id in (("base model", base_repo), ("PEFT adapter", peft_repo)):
        try:
            api.model_info(repo_id=repo_id, token=token)
        except Exception as error:
            raise RuntimeError(
                f"Failed to access {label} '{repo_id}' with the configured HF token: {error}"
            ) from error
class DemoWrapper:
    """Gradio callback that lazily builds the text encoder and encodes prompts.

    The heavyweight encoder is not constructed until the first request, so the
    UI can start even when model loading would fail; the failure is then
    reported inside the UI instead of crashing the server.
    """

    def __init__(self, text_encoder_name, tmp_folder):
        # Preset key into TEXT_ENCODER_PRESETS.
        self.text_encoder_name = text_encoder_name
        # Lazily-built encoder instance (None until the first successful build).
        self.text_encoder = None
        # First initialization failure, replayed on subsequent requests.
        self.init_error = None
        # Directory where .npy embeddings are written.
        self.tmp_folder = tmp_folder

    def _get_text_encoder(self):
        """Return the encoder, building it on first use and caching any failure."""
        if self.text_encoder is None:
            if self.init_error is not None:
                # A previous attempt already failed; surface the same error without retrying.
                raise RuntimeError(self.init_error)
            try:
                self.text_encoder = _build_text_encoder(self.text_encoder_name)
            except Exception as error:
                self.init_error = error
                raise
        return self.text_encoder

    def __call__(self, text, filename, progress=gr.Progress()):
        """Encode *text*, save the embedding under *filename*, and update the UI widgets."""
        try:
            encoder = self._get_text_encoder()
        except Exception as error:
            # Initialization failed: show the error message instead of an embedding.
            title = gr.Markdown(visible=True, value="## Encoder initialization failed")
            body = gr.Markdown(
                visible=True,
                value=(
                    "Text encoder could not initialize. "
                    "If you use gated Hugging Face models, configure a valid HF token in the runtime env.\n\n"
                    f"Error: `{type(error).__name__}: {error}`"
                ),
            )
            return gr.DownloadButton(visible=False), title, body

        # Compute the embedding and trim to its valid length.
        tensor, length = encoder(text)
        embedding = tensor[:length].cpu().numpy()

        # Persist it so the DownloadButton can serve the file.
        path = os.path.join(self.tmp_folder, filename)
        np.save(path, embedding)
        return (
            gr.DownloadButton(visible=True, value=path),
            gr.Markdown(visible=True),
            gr.Markdown(visible=True, value=f"Text: {text}"),
        )
| def _get_env(name: str, default): | |
| return os.getenv(name, default) | |
def _build_text_encoder(name: str):
    """Instantiate the encoder class registered for preset *name*.

    Raises:
        ValueError: if *name* is not a key of TEXT_ENCODER_PRESETS.
    """
    if name not in TEXT_ENCODER_PRESETS:
        options = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{name}'. Available: {options}")
    spec = TEXT_ENCODER_PRESETS[name]
    encoder_cls = resolve_target(spec["target"])
    return encoder_cls(**spec["kwargs"])
def parse_args():
    """Build and parse the CLI; environment variables supply the defaults."""
    parser = argparse.ArgumentParser(description="Run text encoder Gradio server.")
    parser.add_argument(
        "--text-encoder",
        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
        default=_get_env("TEXT_ENCODER", DEFAULT_TEXT_ENCODER),
        help="Text encoder preset.",
    )
    parser.add_argument("--tmp-folder", default=_get_env("TEXT_ENCODER_TMP_FOLDER", DEFAULT_TMP_FOLDER))
    return parser.parse_args()
def main():
    """Configure, validate, and launch the text-encoder Gradio server.

    Server settings come from env vars (GRADIO_SERVER_NAME, and
    GRADIO_SERVER_PORT or PORT); encoder availability is optionally
    validated up front before the UI is served.
    """
    args = parse_args()
    server_name = _get_env("GRADIO_SERVER_NAME", DEFAULT_SERVER_NAME)
    # GRADIO_SERVER_PORT takes precedence over PORT; both fall back to the default.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT") or os.environ.get("PORT", str(DEFAULT_SERVER_PORT)))
    theme, css = get_gradio_theme()
    # Avoid Spaces hot-reload watcher importing `spaces` after CUDA init.
    os.environ.setdefault("GRADIO_HOT_RELOAD", "false")
    os.makedirs(args.tmp_folder, exist_ok=True)
    display_name = TEXT_ENCODER_PRESETS[args.text_encoder]["display_name"]
    # Startup validation can be disabled with TEXT_ENCODER_VALIDATE_STARTUP=0.
    if _get_env("TEXT_ENCODER_VALIDATE_STARTUP", "1") != "0":
        _validate_text_encoder_startup(args.text_encoder)
    # Suppress model loading during DemoWrapper initialization to allow graceful degradation;
    # the model is loaded lazily on the first request.
    demo_wrapper_fn = DemoWrapper(args.text_encoder, args.tmp_folder)
    with gr.Blocks(title="Text encoder") as demo:
        gr.Markdown(f"# Text encoder: {display_name}")
        gr.Markdown("## Description")
        # BUGFIX: corrected grammar in the user-facing description ("Get a embeddings").
        gr.Markdown("Get embeddings from a text.")
        gr.Markdown("## Inputs")
        with gr.Row():
            text = gr.Textbox(
                placeholder="Type the motion you want to generate with a sentence",
                show_label=True,
                label="Text prompt",
                value=DEFAULT_TEXT,
                type="text",
            )
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                btn = gr.Button("Encode", variant="primary")
            with gr.Column(scale=1):
                clear = gr.Button("Clear", variant="secondary")
            with gr.Column(scale=3):
                pass  # spacer column to keep the buttons left-aligned
        # Output widgets start hidden and are revealed once an embedding is ready.
        output_title = gr.Markdown("## Outputs", visible=False)
        output_text = gr.Markdown("", visible=False)
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                download = gr.DownloadButton("Download", variant="primary", visible=False)
            with gr.Column(scale=4):
                pass  # spacer column
        # Hidden textbox carries the output filename into the callback.
        filename = gr.Textbox(
            visible=False,
            value="embedding.npy",
        )

        def clear_fn():
            # Hide every output widget again.
            return [
                gr.DownloadButton(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=False),
            ]

        outputs = [download, output_title, output_text]
        # Reset the outputs first, then run the encoder on submit or click.
        gr.on(
            triggers=[text.submit, btn.click],
            fn=clear_fn,
            inputs=None,
            outputs=outputs,
        ).then(
            fn=demo_wrapper_fn,
            inputs=[text, filename],
            outputs=outputs,
        )

        def download_file():
            # No-op update; the button's `value` already points at the saved file.
            return gr.DownloadButton()

        download.click(
            fn=download_file,
            inputs=None,
            outputs=[download],
        )
        clear.click(fn=clear_fn, inputs=None, outputs=outputs)
    demo.launch(server_name=server_name, server_port=server_port, css=css, theme=theme)
| if __name__ == "__main__": | |
| main() | |