rydlrKE committed on
Commit
0d13d79
·
verified ·
1 Parent(s): 3f43f78

fix: lazy TextEncoderAPI client with retry + HTTP readiness gate

Browse files
Files changed (1) hide show
  1. kimodo/model/text_encoder_api.py +58 -29
kimodo/model/text_encoder_api.py CHANGED
@@ -4,6 +4,8 @@
4
 
5
  import logging
6
 
 
 
7
  import numpy as np
8
  import torch
9
  from gradio_client import Client
@@ -19,17 +21,34 @@ class TextEncoderAPI:
19
  """Text encoder API client for motion generation."""
20
 
21
  def __init__(self, url: str):
22
- # Keep startup resilient: do not connect during app/model initialization.
23
- # In strict API mode, we only attempt network calls when embeddings are requested.
24
  self.url = url
25
  self.client = None
26
  self.device = "cpu"
27
  self.dtype = torch.float
28
 
29
- def _get_client(self):
30
- if self.client is None:
31
- self.client = Client(self.url, verbose=False)
32
- return self.client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def _create_np_random_name(self):
35
  import uuid
@@ -51,33 +70,43 @@ class TextEncoderAPI:
51
  elif result is not None:
52
  candidates = [result]
53
 
54
- # First pass: check for valid .npy paths
55
  for item in candidates:
56
- if isinstance(item, str) and item and item.endswith(".npy"):
57
- return item
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  if isinstance(item, dict):
59
  for key in ("value", "path", "name"):
60
  value = item.get(key)
61
- if isinstance(value, str) and value.endswith(".npy"):
62
- return value
63
-
64
- # Second pass: collect all error indicators
65
- error_parts = []
66
- for item in candidates:
67
- if isinstance(item, str) and item:
68
- if item.startswith("##") or "failed" in item.lower() or "error" in item.lower():
69
- error_parts.append(item.strip())
70
- if isinstance(item, dict):
71
- value = item.get("value")
72
- if isinstance(value, str) and (
73
- value.startswith("##") or "failed" in value.lower() or "error" in value.lower()
74
- ):
75
- error_parts.append(value.strip())
76
-
77
- if error_parts:
78
- # Combine all error messages
79
- full_error = "\n".join(error_parts)
80
- raise RuntimeError(f"Text encoder initialization failed:\n{full_error}")
81
 
82
  raise RuntimeError(f"Text encoder API returned unexpected payload: {result!r}")
83
 
 
4
 
5
  import logging
6
 
7
+ import os
8
+
9
  import numpy as np
10
  import torch
11
  from gradio_client import Client
 
21
  """Text encoder API client for motion generation."""
22
 
23
  def __init__(self, url: str):
 
 
24
  self.url = url
25
  self.client = None
26
  self.device = "cpu"
27
  self.dtype = torch.float
28
 
29
+ def _get_client(self) -> Client:
30
+ """Lazily create the Gradio client, retrying until the server is ready."""
31
+ if self.client is not None:
32
+ return self.client
33
+ import time
34
+
35
+ client_timeout_sec = int(os.environ.get("TEXT_ENCODER_CLIENT_TIMEOUT_SEC", "180"))
36
+ deadline = time.monotonic() + client_timeout_sec
37
+ last_exc: Exception | None = None
38
+ delay = 2.0
39
+ while time.monotonic() < deadline:
40
+ try:
41
+ self.client = Client(self.url, verbose=False)
42
+ return self.client
43
+ except Exception as exc:
44
+ last_exc = exc
45
+ print(f"[text_encoder_api] Client init failed ({exc}), retrying in {delay:.0f}s …")
46
+ time.sleep(delay)
47
+ delay = min(delay * 1.5, 20.0)
48
+ raise RuntimeError(
49
+ f"Text encoder at {self.url!r} did not become ready within {client_timeout_sec}s. "
50
+ f"Last error: {last_exc}"
51
+ )
52
 
53
  def _create_np_random_name(self):
54
  import uuid
 
70
  elif result is not None:
71
  candidates = [result]
72
 
 
73
  for item in candidates:
74
+ # Check for error messages first (e.g., "## Encoder initialization failed")
75
+ if isinstance(item, str):
76
+ if item and item.startswith("##"):
77
+ # This is an error message from the Gradio server
78
+ error_msg = item.replace("##", "").strip()
79
+ if "initialization failed" in error_msg.lower():
80
+ raise RuntimeError(
81
+ f"Text encoder initialization failed. This usually indicates:\n"
82
+ f" - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
83
+ f" - Poor network connectivity during model download\n"
84
+ f" Original error: {error_msg}"
85
+ )
86
+ raise RuntimeError(f"Text encoder API error: {error_msg}")
87
+ if "failed" in item.lower() or "error" in item.lower():
88
+ raise RuntimeError(f"Text encoder API error: {item}")
89
+ if item and item.endswith(".npy"):
90
+ return item
91
+ if item:
92
+ # Log unexpected string for debugging
93
+ print(f"[text_encoder_api] unexpected string response: {item[:100]}")
94
+
95
  if isinstance(item, dict):
96
  for key in ("value", "path", "name"):
97
  value = item.get(key)
98
+ if isinstance(value, str) and value:
99
+ # Check for errors in dict values too
100
+ if "initialization failed" in value.lower():
101
+ raise RuntimeError(
102
+ f"Text encoder initialization failed. This usually indicates:\n"
103
+ f" - Missing or invalid HF_TOKEN for gated models (Llama-3)\n"
104
+ f" - Poor network connectivity during model download"
105
+ )
106
+ if value.startswith("##") or "failed" in value.lower() or "error" in value.lower():
107
+ raise RuntimeError(f"Text encoder API error: {value}")
108
+ if value.endswith(".npy"):
109
+ return value
 
 
 
 
 
 
 
 
110
 
111
  raise RuntimeError(f"Text encoder API returned unexpected payload: {result!r}")
112