techfreakworm committed on
Commit
a81cc03
·
unverified ·
1 Parent(s): 901e296

fix(lyrics): wrap mlx-lm generate in mx.stream(mx.gpu) for Gradio worker threads

Browse files

Gradio dispatches handlers via anyio.to_thread.run_sync; the worker
thread's per-thread MLX default stream is uninitialised, so
mlx-lm's wired_limit() fails with RuntimeError 'There is no Stream(gpu, 0)
in current thread.' Wrapping the generate() call in an mx.stream(mx.gpu)
context installs a GPU stream for the duration of the call.

Files changed (1) hide show
  1. lyrics_lm.py +15 -6
lyrics_lm.py CHANGED
@@ -132,6 +132,7 @@ class _MLXLM:
132
  tokenizer: Any
133
 
134
  def generate(self, system: str, user: str, **kw: Any) -> str:
 
135
  from mlx_lm import generate # type: ignore[import-not-found]
136
 
137
  # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
@@ -141,12 +142,20 @@ class _MLXLM:
141
  f"<|im_start|>user\n{user}<|im_end|>\n"
142
  f"<|im_start|>assistant\n"
143
  )
144
- return generate(
145
- self.model,
146
- self.tokenizer,
147
- prompt=prompt,
148
- max_tokens=int(kw.get("max_new_tokens", 600)),
149
- )
 
 
 
 
 
 
 
 
150
 
151
 
152
  @dataclass
 
132
  tokenizer: Any
133
 
134
  def generate(self, system: str, user: str, **kw: Any) -> str:
135
+ import mlx.core as mx # type: ignore[import-not-found]
136
  from mlx_lm import generate # type: ignore[import-not-found]
137
 
138
  # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
 
142
  f"<|im_start|>user\n{user}<|im_end|>\n"
143
  f"<|im_start|>assistant\n"
144
  )
145
+ # Gradio runs handlers in anyio worker threads. MLX maintains a
146
+ # *per-thread* default stream and bails with "There is no
147
+ # Stream(gpu, 0) in current thread" when a worker thread that
148
+ # didn't create the GPU stream tries to use it. Wrapping the
149
+ # generate() call in ``mx.stream(mx.gpu)`` installs a GPU stream
150
+ # for the current thread for the duration of the context, which
151
+ # is what mlx-lm's wired_limit() helper expects.
152
+ with mx.stream(mx.gpu):
153
+ return generate(
154
+ self.model,
155
+ self.tokenizer,
156
+ prompt=prompt,
157
+ max_tokens=int(kw.get("max_new_tokens", 600)),
158
+ )
159
 
160
 
161
  @dataclass