Spaces:
Running on Zero
Running on Zero
fix(lyrics): wrap mlx-lm generate in mx.stream(mx.gpu) for Gradio worker threads
Browse filesGradio dispatches handlers via anyio.to_thread.run_sync; the worker
thread's per-thread MLX default stream is uninitialised, so
mlx-lm's wired_limit() fails with RuntimeError 'There is no Stream(gpu, 0)
in current thread.' Wrapping the generate() call in an mx.stream(mx.gpu)
context installs a GPU stream for the duration of the call.
- lyrics_lm.py +15 -6
lyrics_lm.py
CHANGED
|
@@ -132,6 +132,7 @@ class _MLXLM:
|
|
| 132 |
tokenizer: Any
|
| 133 |
|
| 134 |
def generate(self, system: str, user: str, **kw: Any) -> str:
|
|
|
|
| 135 |
from mlx_lm import generate # type: ignore[import-not-found]
|
| 136 |
|
| 137 |
# Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
|
|
@@ -141,12 +142,20 @@ class _MLXLM:
|
|
| 141 |
f"<|im_start|>user\n{user}<|im_end|>\n"
|
| 142 |
f"<|im_start|>assistant\n"
|
| 143 |
)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
@dataclass
|
|
|
|
| 132 |
tokenizer: Any
|
| 133 |
|
| 134 |
def generate(self, system: str, user: str, **kw: Any) -> str:
|
| 135 |
+
import mlx.core as mx # type: ignore[import-not-found]
|
| 136 |
from mlx_lm import generate # type: ignore[import-not-found]
|
| 137 |
|
| 138 |
# Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
|
|
|
|
| 142 |
f"<|im_start|>user\n{user}<|im_end|>\n"
|
| 143 |
f"<|im_start|>assistant\n"
|
| 144 |
)
|
| 145 |
+
# Gradio runs handlers in anyio worker threads. MLX maintains a
|
| 146 |
+
# *per-thread* default stream and bails with "There is no
|
| 147 |
+
# Stream(gpu, 0) in current thread" when a worker thread that
|
| 148 |
+
# didn't create the GPU stream tries to use it. Wrapping the
|
| 149 |
+
# generate() call in ``mx.stream(mx.gpu)`` installs a GPU stream
|
| 150 |
+
# for the current thread for the duration of the context, which
|
| 151 |
+
# is what mlx-lm's wired_limit() helper expects.
|
| 152 |
+
with mx.stream(mx.gpu):
|
| 153 |
+
return generate(
|
| 154 |
+
self.model,
|
| 155 |
+
self.tokenizer,
|
| 156 |
+
prompt=prompt,
|
| 157 |
+
max_tokens=int(kw.get("max_new_tokens", 600)),
|
| 158 |
+
)
|
| 159 |
|
| 160 |
|
| 161 |
@dataclass
|