# Hugging Face Spaces page residue — "Spaces: Running on Zero" (ZeroGPU badge).
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Remote text encoder API client (Gradio) for motion generation.""" | |
| import logging | |
| import os | |
| import numpy as np | |
| import torch | |
| from gradio_client import Client | |
# Silence per-request [httpx] logs (one GET line per Gradio poll).
logging.getLogger("httpx").setLevel(logging.WARNING)
# Silence internal gradio_client progress/debug logs.
logging.getLogger("gradio_client").setLevel(logging.WARNING)
class TextEncoderAPI:
    """Gradio client for a remote text-encoder service used in motion generation.

    The remote app encodes a text prompt and writes the embedding to an ``.npy``
    file; this client downloads that file, loads it, and returns a zero-padded
    batch tensor plus the per-prompt sequence lengths.
    """

    def __init__(self, url: str):
        """
        Args:
            url: Base URL of the remote Gradio text-encoder app.
        """
        self.url = url
        self.client = None  # lazily created by _get_client()
        self.device = "cpu"
        self.dtype = torch.float

    def _get_client(self) -> "Client":
        """Lazily create the Gradio client, retrying until the server is ready.

        Retries with exponential backoff (2s, growing x1.5, capped at 20s) until
        ``TEXT_ENCODER_CLIENT_TIMEOUT_SEC`` seconds (default 180) have elapsed.

        Returns:
            Client: a connected gradio_client.Client instance (cached on self).

        Raises:
            RuntimeError: if the server does not become ready within the timeout.
        """
        if self.client is not None:
            return self.client
        import time

        client_timeout_sec = int(os.environ.get("TEXT_ENCODER_CLIENT_TIMEOUT_SEC", "180"))
        deadline = time.monotonic() + client_timeout_sec
        last_exc: Exception | None = None
        delay = 2.0
        while time.monotonic() < deadline:
            try:
                self.client = Client(self.url, verbose=False)
                return self.client
            except Exception as exc:
                last_exc = exc
                print(f"[text_encoder_api] Client init failed ({exc}), retrying in {delay:.0f}s …")
                time.sleep(delay)
                delay = min(delay * 1.5, 20.0)
        raise RuntimeError(
            f"Text encoder at {self.url!r} did not become ready within {client_timeout_sec}s. "
            f"Last error: {last_exc}"
        )

    def _create_np_random_name(self):
        """Return a collision-free ``<uuid4>.npy`` filename for one request."""
        import uuid

        return str(uuid.uuid4()) + ".npy"

    def to(self, device=None, dtype=None):
        """Record the target device/dtype for returned tensors (fluent API).

        Args:
            device: torch device for the output tensors (unchanged if None).
            dtype: torch dtype for the output tensors (unchanged if None).

        Returns:
            TextEncoderAPI: self, to allow ``api.to(...)(...)`` chaining.
        """
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def _extract_result_path(self, result):
        """Extract the ``.npy`` path from heterogeneous gradio_client responses.

        Accepts a bare string, a dict (``value``/``path``/``name`` keys), or a
        list/tuple of either. Server-side error strings (markdown ``## ...``
        messages, or strings containing "failed"/"error") raise RuntimeError.

        The ``.npy`` suffix check deliberately runs BEFORE the "failed"/"error"
        substring heuristic, so legitimate result paths that happen to contain
        those words (e.g. ``/tmp/error_embeds.npy``) are not misreported.

        Raises:
            RuntimeError: on a server error message or an unrecognized payload.
        """
        candidates = []
        if isinstance(result, (list, tuple)):
            candidates = list(result)
        elif result is not None:
            candidates = [result]
        for item in candidates:
            if isinstance(item, str):
                if item and item.startswith("##"):
                    # Markdown-formatted error message from the Gradio server,
                    # e.g. "## Encoder initialization failed".
                    error_msg = item.replace("##", "").strip()
                    if "initialization failed" in error_msg.lower():
                        raise RuntimeError(
                            f"Text encoder initialization failed. This usually indicates:\n"
                            f" - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                            f" - Poor network connectivity during model download\n"
                            f" Original error: {error_msg}"
                        )
                    raise RuntimeError(f"Text encoder API error: {error_msg}")
                # A valid result path wins over the substring error heuristic.
                if item.endswith(".npy"):
                    return item
                if "failed" in item.lower() or "error" in item.lower():
                    raise RuntimeError(f"Text encoder API error: {item}")
                if item:
                    # Log unexpected string for debugging.
                    print(f"[text_encoder_api] unexpected string response: {item[:100]}")
            if isinstance(item, dict):
                for key in ("value", "path", "name"):
                    value = item.get(key)
                    if isinstance(value, str) and value:
                        # Same precedence as above: a .npy path is a result,
                        # not an error, regardless of its spelling.
                        if value.endswith(".npy"):
                            return value
                        if "initialization failed" in value.lower():
                            raise RuntimeError(
                                f"Text encoder initialization failed. This usually indicates:\n"
                                f" - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
                                f" - Poor network connectivity during model download"
                            )
                        if value.startswith("##") or "failed" in value.lower() or "error" in value.lower():
                            raise RuntimeError(f"Text encoder API error: {value}")
        raise RuntimeError(f"Text encoder API returned unexpected payload: {result!r}")

    def __call__(self, texts):
        """Encode text prompts into tensors.

        Args:
            texts (str | list[str]): text prompts to encode

        Returns:
            tuple[torch.Tensor, list[int]]: a (batch, max_len, dim) zero-padded
            tensor on ``self.device``/``self.dtype``, and the per-prompt lengths.

        Raises:
            ValueError: if ``texts`` is empty.
            RuntimeError: if the remote encoder reports an error.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            raise ValueError("texts must contain at least one prompt")
        tensors = []
        lengths = []
        for text in texts:
            filename = self._create_np_random_name()
            # Long result timeout tolerates text-encoder cold start
            # (LLM2Vec model load ~60-120s).
            result = self._get_client().submit(
                text=text,
                filename=filename,
                api_name="/DemoWrapper",
            ).result(timeout=300)
            path = self._extract_result_path(result)
            tensor = np.load(path)
            # Best-effort cleanup of the downloaded temp file so repeated calls
            # do not accumulate one .npy per prompt.
            try:
                os.remove(path)
            except OSError:
                pass
            tensors.append(tensor)
            lengths.append(tensor.shape[0])
        # Zero-pad every embedding to the longest sequence in the batch.
        padded_tensor = np.zeros(
            (len(lengths), max(lengths), tensors[0].shape[-1]), dtype=tensors[0].dtype
        )
        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
            padded_tensor[idx, :length] = tensor
        padded = torch.from_numpy(padded_tensor).to(device=self.device, dtype=self.dtype)
        return padded, lengths