File size: 6,042 Bytes
6d5047c
 
 
 
 
 
0d13d79
 
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8f47a
 
6d5047c
 
 
0d13d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8f47a
6d5047c
 
 
 
 
 
 
 
 
 
 
 
540ad5a
b2d6609
540ad5a
 
 
 
 
 
 
0d13d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de482a9
 
 
0d13d79
 
 
 
 
 
 
 
 
 
 
 
540ad5a
0c5e838
540ad5a
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8f47a
6d5047c
 
 
 
540ad5a
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Remote text encoder API client (Gradio) for motion generation."""

import logging

import os

import numpy as np
import torch
from gradio_client import Client

# Quiet chatty third-party loggers: httpx emits a line per GET request and
# gradio_client logs its internals at INFO; neither is useful to callers.
for _noisy_logger in ("httpx", "gradio_client"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)


class TextEncoderAPI:
    """Client for a remote Gradio text-encoder service used in motion generation.

    Mimics the ``.to(device, dtype)`` / ``__call__`` surface of a local torch
    text encoder, so callers can swap a remote encoder in transparently. The
    underlying Gradio client is created lazily on first use.
    """

    def __init__(self, url: str):
        """Store the server URL; no network activity happens until first call.

        Args:
            url: base URL of the Gradio text-encoder server.
        """
        self.url = url
        self.client = None  # lazily created by _get_client()
        self.device = "cpu"
        self.dtype = torch.float

    def _get_client(self) -> "Client":
        """Lazily create the Gradio client, retrying until the server is ready.

        Retries with exponential backoff (2s initial delay, capped at 20s)
        until ``TEXT_ENCODER_CLIENT_TIMEOUT_SEC`` (default 180s) elapses.

        Returns:
            The connected ``gradio_client.Client`` (cached on ``self.client``).

        Raises:
            RuntimeError: if the server does not become reachable in time.
        """
        if self.client is not None:
            return self.client
        import time

        client_timeout_sec = int(os.environ.get("TEXT_ENCODER_CLIENT_TIMEOUT_SEC", "180"))
        deadline = time.monotonic() + client_timeout_sec
        last_exc: Exception | None = None
        delay = 2.0
        while time.monotonic() < deadline:
            try:
                self.client = Client(self.url, verbose=False)
                return self.client
            except Exception as exc:  # any connection/handshake error warrants a retry
                last_exc = exc
                print(f"[text_encoder_api] Client init failed ({exc}), retrying in {delay:.0f}s …")
                time.sleep(delay)
                delay = min(delay * 1.5, 20.0)
        raise RuntimeError(
            f"Text encoder at {self.url!r} did not become ready within {client_timeout_sec}s. "
            f"Last error: {last_exc}"
        )

    def _create_np_random_name(self) -> str:
        """Return a collision-free ``.npy`` filename for the server to write to."""
        import uuid

        return str(uuid.uuid4()) + ".npy"

    def to(self, device=None, dtype=None) -> "TextEncoderAPI":
        """Record the target device/dtype applied to tensors returned by ``__call__``.

        Fluent interface mirroring ``torch.nn.Module.to``; ``None`` leaves the
        corresponding setting unchanged.
        """
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def _extract_result_path(self, result):
        """Extract the ``.npy`` path from a heterogeneous gradio_client response.

        Responses may be a bare string, a dict with ``value``/``path``/``name``
        keys, or a list/tuple of either. Server-side failures are surfaced as
        ``"## ..."`` markdown strings or error text, which are converted into
        ``RuntimeError`` with actionable guidance.

        Raises:
            RuntimeError: on a server-reported error or an unrecognized payload.
        """
        candidates = []
        if isinstance(result, (list, tuple)):
            candidates = list(result)
        elif result is not None:
            candidates = [result]

        for item in candidates:
            if isinstance(item, str):
                # Accept a valid .npy path FIRST: the error-substring
                # heuristics below would otherwise misreport a legitimate
                # path that happens to contain "error"/"failed" (e.g. a
                # temp directory named "error_run/") as a server failure.
                if item.endswith(".npy"):
                    return item
                # Check for error messages (e.g. "## Encoder initialization failed").
                if item and item.startswith("##"):
                    # This is an error message from the Gradio server.
                    error_msg = item.replace("##", "").strip()
                    if "initialization failed" in error_msg.lower():
                        raise RuntimeError(
                            f"Text encoder initialization failed. This usually indicates:\n"
                            f"  - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                            f"  - Poor network connectivity during model download\n"
                            f"  Original error: {error_msg}"
                        )
                    raise RuntimeError(f"Text encoder API error: {error_msg}")
                if "failed" in item.lower() or "error" in item.lower():
                    raise RuntimeError(f"Text encoder API error: {item}")
                if item:
                    # Log unexpected string for debugging (truncated).
                    print(f"[text_encoder_api] unexpected string response: {item[:100]}")

            if isinstance(item, dict):
                for key in ("value", "path", "name"):
                    value = item.get(key)
                    if isinstance(value, str) and value:
                        # Same ordering as above: a valid path wins over heuristics.
                        if value.endswith(".npy"):
                            return value
                        if "initialization failed" in value.lower():
                            raise RuntimeError(
                                f"Text encoder initialization failed. This usually indicates:\n"
                                f"  - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                                f"  - Poor network connectivity during model download"
                            )
                        if value.startswith("##") or "failed" in value.lower() or "error" in value.lower():
                            raise RuntimeError(f"Text encoder API error: {value}")

        raise RuntimeError(f"Text encoder API returned unexpected payload: {result!r}")

    def __call__(self, texts):
        """Encode text prompts into a padded tensor via the remote encoder.

        Args:
            texts (str | list[str]): text prompt(s) to encode; a bare string
                is treated as a single-element list. Must be non-empty.

        Returns:
            tuple[torch.Tensor, list[int]]: a zero-padded tensor of shape
                (batch, max_len, dim) on ``self.device``/``self.dtype``, and
                the true (unpadded) length of each prompt's encoding.

        Raises:
            ValueError: if ``texts`` is an empty list.
            RuntimeError: if the server is unreachable or returns an error.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # Fail fast with a clear message instead of the opaque
            # "max() arg is an empty sequence" the padding step would raise.
            raise ValueError("texts must contain at least one prompt")

        tensors = []
        lengths = []
        for text in texts:
            filename = self._create_np_random_name()

            # Use a long result timeout to tolerate text-encoder cold-start (LLM2Vec model load ~60-120s).
            result = self._get_client().submit(
                text=text,
                filename=filename,
                api_name="/DemoWrapper",
            ).result(timeout=300)
            path = self._extract_result_path(result)
            tensor = np.load(path)  # gradio_client downloads the .npy locally first
            length = tensor.shape[0]

            tensors.append(tensor)
            lengths.append(length)

        # Zero-pad every encoding to the longest one in the batch.
        padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype)
        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
            padded_tensor[idx, :length] = tensor

        padded_tensor = torch.from_numpy(padded_tensor)
        padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)
        return padded_tensor, lengths