# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Remote text encoder API client (Gradio) for motion generation."""
import logging
import os
import numpy as np
import torch
from gradio_client import Client
# Quiet noisy third-party loggers: httpx per-request GET lines and
# gradio_client's internal chatter.
for _noisy_logger in ("httpx", "gradio_client"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
class TextEncoderAPI:
    """Client for a remote Gradio text-encoder service used in motion generation.

    The remote endpoint (``api_name="/DemoWrapper"``) encodes one text prompt,
    writes the embedding to an ``.npy`` file, and returns that file's path.
    This class loads each array, pads a batch to a common sequence length, and
    returns a torch tensor on the configured device/dtype.
    """

    def __init__(self, url: str):
        """Store the endpoint URL; the Gradio client itself is created lazily.

        Args:
            url: base URL of the remote text-encoder Gradio app.
        """
        self.url = url
        self.client = None          # lazily created gradio_client.Client
        self.device = "cpu"         # target device for returned tensors
        self.dtype = torch.float    # target dtype for returned tensors

    def _get_client(self) -> "Client":
        """Lazily create the Gradio client, retrying until the server is ready.

        Retries with capped exponential backoff (2s, growing x1.5 up to 20s)
        until TEXT_ENCODER_CLIENT_TIMEOUT_SEC (default 180s) elapses.

        Returns:
            The connected gradio_client.Client (cached on self.client).

        Raises:
            RuntimeError: if the server never became reachable in time.
        """
        if self.client is not None:
            return self.client
        import time

        client_timeout_sec = int(os.environ.get("TEXT_ENCODER_CLIENT_TIMEOUT_SEC", "180"))
        deadline = time.monotonic() + client_timeout_sec
        last_exc: Exception | None = None
        delay = 2.0
        while time.monotonic() < deadline:
            try:
                self.client = Client(self.url, verbose=False)
                return self.client
            except Exception as exc:  # server still starting up — keep retrying
                last_exc = exc
                print(f"[text_encoder_api] Client init failed ({exc}), retrying in {delay:.0f}s …")
                time.sleep(delay)
                delay = min(delay * 1.5, 20.0)
        raise RuntimeError(
            f"Text encoder at {self.url!r} did not become ready within {client_timeout_sec}s. "
            f"Last error: {last_exc}"
        )

    def _create_np_random_name(self) -> str:
        """Return a unique ``.npy`` filename for the server to write results to."""
        import uuid
        return str(uuid.uuid4()) + ".npy"

    def to(self, device=None, dtype=None):
        """Set target device/dtype for returned tensors (mirrors torch ``.to()``).

        Args:
            device: optional torch device (or device string) for output tensors.
            dtype: optional torch dtype for output tensors.

        Returns:
            self, to allow chaining.
        """
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def _extract_result_path(self, result):
        """Extract the ``.npy`` path from heterogeneous gradio_client responses.

        Accepts a bare string, a dict, or a list/tuple of either. Server-side
        error strings (``## ...`` markdown, or anything containing
        "failed"/"error") are surfaced as RuntimeError instead of a path.

        Raises:
            RuntimeError: on a server-reported error or an unrecognized payload.
        """
        candidates = []
        if isinstance(result, (list, tuple)):
            candidates = list(result)
        elif result is not None:
            candidates = [result]
        for item in candidates:
            # Check for error messages first (e.g., "## Encoder initialization failed")
            if isinstance(item, str):
                if item and item.startswith("##"):
                    # This is an error message from the Gradio server
                    error_msg = item.replace("##", "").strip()
                    if "initialization failed" in error_msg.lower():
                        raise RuntimeError(
                            f"Text encoder initialization failed. This usually indicates:\n"
                            f"  - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                            f"  - Poor network connectivity during model download\n"
                            f"  Original error: {error_msg}"
                        )
                    raise RuntimeError(f"Text encoder API error: {error_msg}")
                if "failed" in item.lower() or "error" in item.lower():
                    raise RuntimeError(f"Text encoder API error: {item}")
                if item and item.endswith(".npy"):
                    return item
                if item:
                    # Log unexpected string for debugging
                    print(f"[text_encoder_api] unexpected string response: {item[:100]}")
            if isinstance(item, dict):
                for key in ("value", "path", "name"):
                    value = item.get(key)
                    if isinstance(value, str) and value:
                        # Check for errors in dict values too
                        if "initialization failed" in value.lower():
                            raise RuntimeError(
                                f"Text encoder initialization failed. This usually indicates:\n"
                                f"  - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                                f"  - Poor network connectivity during model download"
                            )
                        if value.startswith("##") or "failed" in value.lower() or "error" in value.lower():
                            raise RuntimeError(f"Text encoder API error: {value}")
                        if value.endswith(".npy"):
                            return value
        raise RuntimeError(f"Text encoder API returned unexpected payload: {result!r}")

    def __call__(self, texts):
        """Encode text prompts into a padded batch tensor.

        Args:
            texts (str | list[str]): text prompt(s) to encode.

        Returns:
            tuple[torch.Tensor, list[int]]: padded (batch, max_len, dim) tensor
            on self.device/self.dtype, and the unpadded length of each prompt.

        Raises:
            ValueError: if no prompts are given (previously surfaced as an
                opaque ``max() arg is an empty sequence``).
            RuntimeError: if the remote encoder reports an error.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            raise ValueError("texts must contain at least one prompt")
        tensors = []
        lengths = []
        for text in texts:
            filename = self._create_np_random_name()
            # Use a long result timeout to tolerate text-encoder cold-start (LLM2Vec model load ~60-120s).
            result = self._get_client().submit(
                text=text,
                filename=filename,
                api_name="/DemoWrapper",
            ).result(timeout=300)
            path = self._extract_result_path(result)
            tensor = np.load(path)
            # gradio_client downloads the result into a local temp file; remove
            # it once loaded so repeated calls don't leak temp files.
            try:
                os.remove(path)
            except OSError:
                pass
            length = tensor.shape[0]
            tensors.append(tensor)
            lengths.append(length)
        padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype)
        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
            padded_tensor[idx, :length] = tensor
        padded_tensor = torch.from_numpy(padded_tensor)
        padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)
        return padded_tensor, lengths