rydlrKE committed on
Commit
de482a9
·
1 Parent(s): 560cef6

Fix text encoder payload parsing and add remote encoder switch

Browse files
Files changed (2) hide show
  1. app.py +21 -5
  2. kimodo/model/text_encoder_api.py +11 -0
app.py CHANGED
@@ -31,10 +31,21 @@ os.environ.setdefault("TEXT_ENCODER", "llm2vec")
31
  os.environ.setdefault("LLM2VEC_BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
32
  os.environ.setdefault(
33
  "LLM2VEC_PEFT_MODEL",
34
- "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
35
  )
 
 
 
 
 
36
  TEXT_ENCODER_PORT = int(os.environ.get("TEXT_ENCODER_PORT", "9550"))
37
- os.environ.setdefault("TEXT_ENCODER_URL", f"http://127.0.0.1:{TEXT_ENCODER_PORT}/")
 
 
 
 
 
 
38
  # Prefer CPU on ZeroGPU to avoid low-level CUDA init crashes during model load.
39
  os.environ.setdefault("KIMODO_DEVICE", "cpu")
40
 
@@ -84,15 +95,20 @@ def main() -> None:
84
  # Invoke GPU function to satisfy HF Spaces startup requirement.
85
  _gpu_healthcheck()
86
 
87
- # Keep existing embedding pipeline (TextEncoderAPI -> local llm2vec server).
88
- text_encoder_proc = _start_text_encoder_server()
 
 
 
 
89
 
90
  import kimodo
91
  from kimodo.demo.app import Demo
92
 
93
  print(f"[movimento][boot] kimodo_module={getattr(kimodo, '__file__', 'unknown')}")
94
  print(f"[movimento][boot] mode=native_direct port={PORT}")
95
- print(f"[movimento][boot] text_encoder_pid={text_encoder_proc.pid}")
 
96
  Demo()
97
 
98
  # Keep the process alive while Viser serves on SERVER_PORT.
 
31
  os.environ.setdefault("LLM2VEC_BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
32
  os.environ.setdefault(
33
  "LLM2VEC_PEFT_MODEL",
34
+ "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
35
  )
36
+ hf_token = os.environ.get("HF_TOKEN")
37
+ if hf_token:
38
+ os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)
39
+ os.environ.setdefault("HF_HUB_TOKEN", hf_token)
40
+ os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", hf_token)
41
  TEXT_ENCODER_PORT = int(os.environ.get("TEXT_ENCODER_PORT", "9550"))
42
+ TEXT_ENCODER_SOURCE = os.environ.get("TEXT_ENCODER_SOURCE", "local").strip().lower()
43
+ if TEXT_ENCODER_SOURCE not in {"local", "remote"}:
44
+ raise RuntimeError("TEXT_ENCODER_SOURCE must be 'local' or 'remote'.")
45
+ if TEXT_ENCODER_SOURCE == "local":
46
+ os.environ.setdefault("TEXT_ENCODER_URL", f"http://127.0.0.1:{TEXT_ENCODER_PORT}/")
47
+ elif "TEXT_ENCODER_URL" not in os.environ:
48
+ raise RuntimeError("TEXT_ENCODER_URL is required when TEXT_ENCODER_SOURCE=remote.")
49
  # Prefer CPU on ZeroGPU to avoid low-level CUDA init crashes during model load.
50
  os.environ.setdefault("KIMODO_DEVICE", "cpu")
51
 
 
95
  # Invoke GPU function to satisfy HF Spaces startup requirement.
96
  _gpu_healthcheck()
97
 
98
+ text_encoder_proc = None
99
+ if TEXT_ENCODER_SOURCE == "local":
100
+ # Keep existing embedding pipeline (TextEncoderAPI -> local llm2vec server).
101
+ text_encoder_proc = _start_text_encoder_server()
102
+ else:
103
+ print(f"[movimento][boot] using remote text encoder: {os.environ['TEXT_ENCODER_URL']}")
104
 
105
  import kimodo
106
  from kimodo.demo.app import Demo
107
 
108
  print(f"[movimento][boot] kimodo_module={getattr(kimodo, '__file__', 'unknown')}")
109
  print(f"[movimento][boot] mode=native_direct port={PORT}")
110
+ if text_encoder_proc is not None:
111
+ print(f"[movimento][boot] text_encoder_pid={text_encoder_proc.pid}")
112
  Demo()
113
 
114
  # Keep the process alive while Viser serves on SERVER_PORT.
kimodo/model/text_encoder_api.py CHANGED
@@ -55,6 +55,11 @@ class TextEncoderAPI:
55
  for item in candidates:
56
  if isinstance(item, str) and item and item.endswith(".npy"):
57
  return item
 
 
 
 
 
58
 
59
  # Second pass: collect all error indicators
60
  error_parts = []
@@ -62,6 +67,12 @@ class TextEncoderAPI:
62
  if isinstance(item, str) and item:
63
  if item.startswith("##") or "failed" in item.lower() or "error" in item.lower():
64
  error_parts.append(item.strip())
 
 
 
 
 
 
65
 
66
  if error_parts:
67
  # Combine all error messages
 
55
  for item in candidates:
56
  if isinstance(item, str) and item and item.endswith(".npy"):
57
  return item
58
+ if isinstance(item, dict):
59
+ for key in ("value", "path", "name"):
60
+ value = item.get(key)
61
+ if isinstance(value, str) and value.endswith(".npy"):
62
+ return value
63
 
64
  # Second pass: collect all error indicators
65
  error_parts = []
 
67
  if isinstance(item, str) and item:
68
  if item.startswith("##") or "failed" in item.lower() or "error" in item.lower():
69
  error_parts.append(item.strip())
70
+ if isinstance(item, dict):
71
+ value = item.get("value")
72
+ if isinstance(value, str) and (
73
+ value.startswith("##") or "failed" in value.lower() or "error" in value.lower()
74
+ ):
75
+ error_parts.append(value.strip())
76
 
77
  if error_parts:
78
  # Combine all error messages