Spaces:
Running on Zero
Running on Zero
fix(lyrics): re-create mlx_lm generation_stream in worker thread
Browse filesThe previous fix wrapped generate() in mx.stream(mx.gpu) but only
installed the *default* gpu stream for the current thread. mlx_lm
also uses a module-level generation_stream = mx.new_stream(...)
created at import time on the main thread; wired_limit().__exit__
then calls mx.synchronize(generation_stream) which crashes from the
Gradio worker thread with 'There is no Stream(gpu, 0) in current thread.'
Re-assign mlx_lm.generate.generation_stream to a fresh stream created
in the current thread before every generate() call. Safe because
Gradio runs at default_concurrency_limit=1.
- lyrics_lm.py +16 -6
lyrics_lm.py
CHANGED
|
@@ -133,6 +133,7 @@ class _MLXLM:
|
|
| 133 |
|
| 134 |
def generate(self, system: str, user: str, **kw: Any) -> str:
|
| 135 |
import mlx.core as mx # type: ignore[import-not-found]
|
|
|
|
| 136 |
from mlx_lm import generate # type: ignore[import-not-found]
|
| 137 |
|
| 138 |
# Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
|
|
@@ -143,13 +144,22 @@ class _MLXLM:
|
|
| 143 |
f"<|im_start|>assistant\n"
|
| 144 |
)
|
| 145 |
# Gradio runs handlers in anyio worker threads. MLX maintains a
|
| 146 |
-
# *per-thread* default stream and
|
| 147 |
-
#
|
| 148 |
-
#
|
| 149 |
-
#
|
| 150 |
-
#
|
| 151 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
with mx.stream(mx.gpu):
|
|
|
|
| 153 |
return generate(
|
| 154 |
self.model,
|
| 155 |
self.tokenizer,
|
|
|
|
| 133 |
|
| 134 |
def generate(self, system: str, user: str, **kw: Any) -> str:
|
| 135 |
import mlx.core as mx # type: ignore[import-not-found]
|
| 136 |
+
import mlx_lm.generate as mlx_gen_mod # type: ignore[import-not-found]
|
| 137 |
from mlx_lm import generate # type: ignore[import-not-found]
|
| 138 |
|
| 139 |
# Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
|
|
|
|
| 144 |
f"<|im_start|>assistant\n"
|
| 145 |
)
|
| 146 |
# Gradio runs handlers in anyio worker threads. MLX maintains a
|
| 147 |
+
# *per-thread* default stream and a module-level ``generation_stream``
|
| 148 |
+
# that was created at mlx_lm import time on the MAIN thread. Both
|
| 149 |
+
# need to be valid in the *current* (worker) thread or
|
| 150 |
+
# ``wired_limit().__exit__`` crashes with "There is no Stream(gpu, 0)
|
| 151 |
+
# in current thread" when it calls ``mx.synchronize(generation_stream)``.
|
| 152 |
+
#
|
| 153 |
+
# Two-part fix:
|
| 154 |
+
# 1. ``mx.stream(mx.gpu)`` wrap installs the default GPU stream
|
| 155 |
+
# for the current thread for the duration of the call.
|
| 156 |
+
# 2. Re-assign ``mlx_lm.generate.generation_stream`` to a stream
|
| 157 |
+
# created in the *current* thread so ``mx.synchronize`` doesn't
|
| 158 |
+
# reach across thread boundaries. The reassignment is safe
|
| 159 |
+
# because Gradio's queue runs at default_concurrency_limit=1 —
|
| 160 |
+
# no two lyrics drafts run concurrently.
|
| 161 |
with mx.stream(mx.gpu):
|
| 162 |
+
mlx_gen_mod.generation_stream = mx.new_stream(mx.default_device())
|
| 163 |
return generate(
|
| 164 |
self.model,
|
| 165 |
self.tokenizer,
|