Buckets:
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Remote text encoder API client (Gradio) for motion generation.""" | |
| import logging | |
| import numpy as np | |
| import torch | |
| from gradio_client import Client | |
# Quiet the chatty third-party loggers: httpx logs every GET request and
# gradio_client emits its own internal messages. Only warnings and above pass.
for _noisy_logger in ("httpx", "gradio_client"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger  # do not leave the loop variable behind as a module global
class TextEncoderAPI:
    """Client for a remote text encoder served through a Gradio endpoint.

    Each prompt is submitted to the ``/DemoWrapper`` endpoint, the ``.npy`` file
    referenced by the response is loaded, and the per-prompt arrays are
    zero-padded into one batch tensor. ``to()`` configures the device and dtype
    of the tensor returned by ``__call__``.
    """

    # Default number of seconds to wait for a single encoding job. Kept long to
    # tolerate text-encoder cold-start (LLM2Vec model load ~60-120s).
    timeout = 300

    def __init__(self, url: str, timeout=None):
        """Connect to the Gradio app at ``url``.

        Args:
            url (str): address of the Gradio app, passed to ``gradio_client.Client``.
            timeout (float | None): seconds to wait for each encoding result;
                ``None`` keeps the class default of 300 seconds.
        """
        self.client = Client(url, verbose=False)
        self.device = "cpu"
        self.dtype = torch.float
        if timeout is not None:
            self.timeout = timeout

    def _create_np_random_name(self):
        """Return a random, UUID4-based ``.npy`` file name."""
        import uuid

        return str(uuid.uuid4()) + ".npy"

    def to(self, device=None, dtype=None):
        """Set the device and/or dtype used for returned tensors.

        Arguments left as ``None`` keep their current value.

        Returns:
            TextEncoderAPI: ``self``, so calls can be chained.
        """
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def __call__(self, texts):
        """Encode text prompts into a zero-padded batch tensor.

        Args:
            texts (str | list[str]): text prompts to encode; a single string is
                treated as a batch of one.

        Returns:
            tuple[torch.Tensor, list[int]]: the padded tensor, laid out as
                ``(len(texts), max(lengths), feature_dim)``, moved to the
                configured device/dtype, and the unpadded length of every prompt.

        Raises:
            ValueError: if ``texts`` contains no prompt.
        """
        prompts = [texts] if isinstance(texts, str) else list(texts)
        if not prompts:
            # Fail early with a clear message instead of the opaque
            # "max() arg is an empty sequence" raised further down.
            raise ValueError("texts must contain at least one prompt")

        tensors = []
        lengths = []
        for text in prompts:
            filename = self._create_np_random_name()
            result = self.client.submit(
                text=text,
                filename=filename,
                api_name="/DemoWrapper",
            ).result(timeout=self.timeout)
            # The first output of the endpoint carries the path of the .npy file.
            path = result[0]["value"]
            # presumably a (length, feature_dim) array -- TODO confirm with the server
            tensor = np.load(path)
            tensors.append(tensor)
            lengths.append(tensor.shape[0])

        # Feature size and dtype are taken from the first array.
        padded = np.zeros(
            (len(lengths), max(lengths), tensors[0].shape[-1]),
            dtype=tensors[0].dtype,
        )
        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):
            padded[idx, :length] = tensor

        padded_tensor = torch.from_numpy(padded)
        padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)
        return padded_tensor, lengths
Xet Storage Details
- Size:
- 2.33 kB
- Xet hash:
- 8b5b17931c20d4ea8aef8cd6ca55d74df0b3905cf78cdde5d726aa984ec04a7c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.