techfreakworm commited on
Commit
0d2d5a8
·
unverified ·
1 Parent(s): fc8c46f

fix(lyrics): re-create mlx_lm generation_stream in worker thread

Browse files

The previous fix wrapped generate() in mx.stream(mx.gpu) but only
installed the *default* gpu stream for the current thread. mlx_lm
also uses a module-level generation_stream = mx.new_stream(...)
created at import time on the main thread; wired_limit().__exit__
then calls mx.synchronize(generation_stream) which crashes from the
Gradio worker thread with 'There is no Stream(gpu, 0) in current thread.'

Re-assign mlx_lm.generate.generation_stream to a fresh stream created
in the current thread before every generate() call. Safe because
Gradio runs at default_concurrency_limit=1.

Files changed (1) hide show
  1. lyrics_lm.py +16 -6
lyrics_lm.py CHANGED
@@ -133,6 +133,7 @@ class _MLXLM:
133
 
134
  def generate(self, system: str, user: str, **kw: Any) -> str:
135
  import mlx.core as mx # type: ignore[import-not-found]
 
136
  from mlx_lm import generate # type: ignore[import-not-found]
137
 
138
  # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
@@ -143,13 +144,22 @@ class _MLXLM:
143
  f"<|im_start|>assistant\n"
144
  )
145
  # Gradio runs handlers in anyio worker threads. MLX maintains a
146
- # *per-thread* default stream and bails with "There is no
147
- # Stream(gpu, 0) in current thread" when a worker thread that
148
- # didn't create the GPU stream tries to use it. Wrapping the
149
- # generate() call in ``mx.stream(mx.gpu)`` installs a GPU stream
150
- # for the current thread for the duration of the context, which
151
- # is what mlx-lm's wired_limit() helper expects.
 
 
 
 
 
 
 
 
152
  with mx.stream(mx.gpu):
 
153
  return generate(
154
  self.model,
155
  self.tokenizer,
 
133
 
134
  def generate(self, system: str, user: str, **kw: Any) -> str:
135
  import mlx.core as mx # type: ignore[import-not-found]
136
+ import mlx_lm.generate as mlx_gen_mod # type: ignore[import-not-found]
137
  from mlx_lm import generate # type: ignore[import-not-found]
138
 
139
  # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
 
144
  f"<|im_start|>assistant\n"
145
  )
146
  # Gradio runs handlers in anyio worker threads. MLX maintains a
147
+ # *per-thread* default stream and a module-level ``generation_stream``
148
+ # that was created at mlx_lm import time on the MAIN thread. Both
149
+ # need to be valid in the *current* (worker) thread or
150
+ # ``wired_limit().__exit__`` crashes with "There is no Stream(gpu, 0)
151
+ # in current thread" when it calls ``mx.synchronize(generation_stream)``.
152
+ #
153
+ # Two-part fix:
154
+ # 1. ``mx.stream(mx.gpu)`` wrap installs the default GPU stream
155
+ # for the current thread for the duration of the call.
156
+ # 2. Re-assign ``mlx_lm.generate.generation_stream`` to a stream
157
+ # created in the *current* thread so ``mx.synchronize`` doesn't
158
+ # reach across thread boundaries. The reassignment is safe
159
+ # because Gradio's queue runs at default_concurrency_limit=1 —
160
+ # no two lyrics drafts run concurrently.
161
  with mx.stream(mx.gpu):
162
+ mlx_gen_mod.generation_stream = mx.new_stream(mx.default_device())
163
  return generate(
164
  self.model,
165
  self.tokenizer,