File size: 8,713 Bytes
6d5047c
 
 
 
 
 
 
 
4be5ba2
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560cef6
 
6d5047c
 
 
 
 
 
 
 
4be5ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5430059
 
6d5047c
 
4be5ba2
 
 
6d5047c
 
 
 
 
5430059
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5430059
6d5047c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

import gradio as gr
import numpy as np
from huggingface_hub import HfApi

from kimodo.model import resolve_target

from .gradio_theme import get_gradio_theme

# Opt in to huggingface_hub's parallel shard loading; set before any model load.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
# Default prompt pre-filled in the UI textbox.
DEFAULT_TEXT = "A person walks and falls to the ground."
# Default bind address/port; overridable via GRADIO_SERVER_NAME and
# GRADIO_SERVER_PORT / PORT (see main()).
DEFAULT_SERVER_NAME = "0.0.0.0"
DEFAULT_SERVER_PORT = 9550
# Directory where embedding .npy files are staged for download.
DEFAULT_TMP_FOLDER = "/tmp/text_encoder/"
DEFAULT_TEXT_ENCODER = "llm2vec"
# Registry of supported encoders: dotted class path ("target") plus the
# constructor kwargs passed verbatim by _build_text_encoder().
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "target": "kimodo.model.LLM2VecEncoder",
        "kwargs": {
            "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
            "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
            "dtype": "bfloat16",
            "llm_dim": 4096,
        },
        # Human-readable name shown in the UI header.
        "display_name": "LLM2Vec",
    }
}


def _get_hf_token() -> str | None:
    return (
        os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        or os.environ.get("HF_HUB_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    )


def _validate_text_encoder_startup(text_encoder_name: str) -> None:
    """Fail fast before launching Gradio if the text encoder cannot be resolved.

    Raises ValueError for an unknown preset name. With TEXT_ENCODERS_DIR set,
    requires both model directories to exist locally; otherwise requires an HF
    token and verifies both repos are reachable via the Hub API.
    """
    if text_encoder_name not in TEXT_ENCODER_PRESETS:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{text_encoder_name}'. Available: {available}")

    model_kwargs = TEXT_ENCODER_PRESETS[text_encoder_name]["kwargs"]
    base_repo = model_kwargs["base_model_name_or_path"]
    peft_repo = model_kwargs["peft_model_name_or_path"]
    local_root = os.environ.get("TEXT_ENCODERS_DIR")

    if local_root:
        # Local mode: both model directories must already be on disk.
        missing = []
        for repo in (base_repo, peft_repo):
            candidate = os.path.join(local_root, repo)
            if not os.path.exists(candidate):
                missing.append(candidate)
        if missing:
            raise RuntimeError(
                "TEXT_ENCODERS_DIR is set, but the following local model paths are missing: "
                + ", ".join(missing)
            )
        return

    token = _get_hf_token()
    if not token:
        raise RuntimeError(
            "HF token is missing. Set one of HF_TOKEN, HUGGING_FACE_HUB_TOKEN, HF_HUB_TOKEN, or "
            "HUGGINGFACEHUB_API_TOKEN before starting the text encoder server."
        )

    # Remote mode: probe both repos so gated/missing models fail here, not mid-request.
    api = HfApi()
    for repo_id, label in ((base_repo, "base model"), (peft_repo, "PEFT adapter")):
        try:
            api.model_info(repo_id=repo_id, token=token)
        except Exception as error:
            raise RuntimeError(f"Failed to access {label} '{repo_id}' with the configured HF token: {error}") from error


class DemoWrapper:
    """Callable Gradio handler that builds the text encoder lazily.

    The encoder is NOT created in __init__, so the UI can start even when
    model initialization would fail (e.g. missing HF token); any failure
    surfaces on the first request instead.
    """

    def __init__(self, text_encoder_name, tmp_folder):
        # Preset key into TEXT_ENCODER_PRESETS; resolved lazily in _get_text_encoder.
        self.text_encoder_name = text_encoder_name
        # Encoder instance; None until the first successful build.
        self.text_encoder = None
        # Cached first initialization failure, so later calls fail fast without retrying.
        self.init_error = None
        # Directory where embedding .npy files are written for download.
        self.tmp_folder = tmp_folder

    def _get_text_encoder(self):
        """Return the encoder, building it on first use; cache any build failure."""
        if self.text_encoder is not None:
            return self.text_encoder
        if self.init_error is not None:
            # A previous build failed: re-raise the cached error instead of retrying.
            raise RuntimeError(self.init_error)
        try:
            self.text_encoder = _build_text_encoder(self.text_encoder_name)
            return self.text_encoder
        except Exception as error:
            self.init_error = error
            raise

    def __call__(self, text, filename, progress=gr.Progress()):
        """Encode `text` and stage the embedding as `filename` for download.

        Returns a (download_button, output_title, output_text) tuple of Gradio
        component updates. NOTE(review): the gr.Progress() default is Gradio's
        documented idiom for progress tracking; it is evaluated once at
        definition time — confirm this matches the installed Gradio version.
        """
        try:
            text_encoder = self._get_text_encoder()
        except Exception as error:
            # Graceful degradation: report the failure in the UI instead of crashing.
            output_title = gr.Markdown(visible=True, value="## Encoder initialization failed")
            output_text = gr.Markdown(
                visible=True,
                value=(
                    "Text encoder could not initialize. "
                    "If you use gated Hugging Face models, configure a valid HF token in the runtime env.\n\n"
                    f"Error: `{type(error).__name__}: {error}`"
                ),
            )
            download = gr.DownloadButton(visible=False)
            return download, output_title, output_text

        # Compute text embedding. The encoder returns (tensor, length); only the
        # first `length` rows are kept — presumably padding follows; verify against encoder.
        tensor, length = text_encoder(text)
        embedding = tensor[:length]
        embedding = embedding.cpu().numpy()

        # Save text embedding. np.save appends ".npy" when the name lacks it; the
        # default filename already ends in .npy, so `path` matches the file on disk.
        path = os.path.join(self.tmp_folder, filename)
        np.save(path, embedding)

        output_title = gr.Markdown(visible=True)
        output_text = gr.Markdown(visible=True, value=f"Text: {text}")
        download = gr.DownloadButton(visible=True, value=path)
        return download, output_title, output_text


def _get_env(name: str, default):
    return os.getenv(name, default)


def _build_text_encoder(name: str):
    """Instantiate the encoder class configured for preset `name`.

    Raises ValueError when `name` is not a known preset.
    """
    preset = TEXT_ENCODER_PRESETS.get(name)
    if preset is None:
        available = ", ".join(sorted(TEXT_ENCODER_PRESETS))
        raise ValueError(f"Unknown TEXT_ENCODER='{name}'. Available: {available}")
    encoder_cls = resolve_target(preset["target"])
    return encoder_cls(**preset["kwargs"])


def parse_args():
    """Parse CLI arguments; environment variables supply the defaults."""
    parser = argparse.ArgumentParser(description="Run text encoder Gradio server.")
    parser.add_argument(
        "--text-encoder",
        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
        default=_get_env("TEXT_ENCODER", DEFAULT_TEXT_ENCODER),
        help="Text encoder preset.",
    )
    parser.add_argument("--tmp-folder", default=_get_env("TEXT_ENCODER_TMP_FOLDER", DEFAULT_TMP_FOLDER))
    return parser.parse_args()


def main():
    """Configure and launch the text-encoder Gradio server.

    Host/port come from the environment (GRADIO_SERVER_PORT takes precedence
    over PORT). Startup validation of the encoder preset/credentials can be
    disabled with TEXT_ENCODER_VALIDATE_STARTUP=0.
    """
    args = parse_args()
    server_name = _get_env("GRADIO_SERVER_NAME", DEFAULT_SERVER_NAME)
    server_port = int(os.environ.get("GRADIO_SERVER_PORT") or os.environ.get("PORT", str(DEFAULT_SERVER_PORT)))
    theme, css = get_gradio_theme()
    # Avoid Spaces hot-reload watcher importing `spaces` after CUDA init.
    os.environ.setdefault("GRADIO_HOT_RELOAD", "false")
    os.makedirs(args.tmp_folder, exist_ok=True)
    display_name = TEXT_ENCODER_PRESETS[args.text_encoder]["display_name"]

    # Fail fast on bad presets / missing credentials unless explicitly disabled.
    if _get_env("TEXT_ENCODER_VALIDATE_STARTUP", "1") != "0":
        _validate_text_encoder_startup(args.text_encoder)

    # Model loading is deferred: DemoWrapper builds the encoder on first
    # request so the UI can come up even if initialization would fail.
    demo_wrapper_fn = DemoWrapper(args.text_encoder, args.tmp_folder)

    # BUG FIX: `theme` and `css` are gr.Blocks constructor parameters;
    # Blocks.launch() does not accept them and would raise TypeError.
    with gr.Blocks(title="Text encoder", theme=theme, css=css) as demo:
        gr.Markdown(f"# Text encoder: {display_name}")
        gr.Markdown("## Description")
        gr.Markdown("Get embeddings from a text.")

        gr.Markdown("## Inputs")
        with gr.Row():
            text = gr.Textbox(
                placeholder="Type the motion you want to generate with a sentence",
                show_label=True,
                label="Text prompt",
                value=DEFAULT_TEXT,
                type="text",
            )
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                btn = gr.Button("Encode", variant="primary")
            with gr.Column(scale=1):
                clear = gr.Button("Clear", variant="secondary")
            with gr.Column(scale=3):
                pass

        output_title = gr.Markdown("## Outputs", visible=False)
        output_text = gr.Markdown("", visible=False)
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                download = gr.DownloadButton("Download", variant="primary", visible=False)
            with gr.Column(scale=4):
                pass

        # Hidden textbox holding the output filename passed to the handler.
        filename = gr.Textbox(
            visible=False,
            value="embedding.npy",
        )

        def clear_fn():
            """Hide the output components (used before encoding and by Clear)."""
            return [
                gr.DownloadButton(visible=False),
                gr.Markdown(visible=False),
                gr.Markdown(visible=False),
            ]

        outputs = [download, output_title, output_text]

        # Clear previous outputs first, then run the encoder on submit/click.
        gr.on(
            triggers=[text.submit, btn.click],
            fn=clear_fn,
            inputs=None,
            outputs=outputs,
        ).then(
            fn=demo_wrapper_fn,
            inputs=[text, filename],
            outputs=outputs,
        )

        def download_file():
            """No-op update so the click event resolves; the button serves its file."""
            return gr.DownloadButton()

        download.click(
            fn=download_file,
            inputs=None,
            outputs=[download],
        )
        clear.click(fn=clear_fn, inputs=None, outputs=outputs)

    demo.launch(server_name=server_name, server_port=server_port)


# Script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    main()