Ghostgim commited on
Commit
551cb99
·
verified ·
1 Parent(s): 088475f

feat(chat): stream tokens as they're generated

Browse files

Convert chat_fn to a generator so Gradio's ChatInterface shows tokens
appearing incrementally rather than waiting 15-25 s for the full reply
to materialize. New helper generate_until_end_stream is the same loop as
generate_until_end but yields the growing token list after every sampled
token; chat_fn decodes and yields the running text snapshot per Gradio's
API contract.

No extra forward-pass cost. The user sees motion within the first ~1-2 s
instead of staring at a static loading state for the full duration. As a
side effect, this also reduces peak memory pressure: Gradio holds only
the latest snapshot in flight rather than the entire response object,
and the streaming yield gives the worker more chances to release
intermediate tensors mid-generation.

Files changed (1) hide show
  1. app.py +54 -7
app.py CHANGED
@@ -197,6 +197,44 @@ def generate_until_end(
197
  return new_ids
198
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  # ---------------------------------------------------------------------------
201
  # Module-level state
202
  # ---------------------------------------------------------------------------
@@ -248,7 +286,13 @@ def chat_fn(message: str, history: list, temperature: float, top_k: int,
248
  else:
249
  break
250
 
251
- new_ids = generate_until_end(
 
 
 
 
 
 
252
  MODEL, prompt_ids,
253
  end_id=END_ID,
254
  max_new_tokens=int(max_tokens),
@@ -256,22 +300,25 @@ def chat_fn(message: str, history: list, temperature: float, top_k: int,
256
  top_k=int(top_k),
257
  top_p=float(top_p),
258
  repetition_penalty=float(repetition_penalty),
259
- )
260
- result = TOKENIZER.decode(new_ids).strip() or "(no response)"
 
 
 
 
 
 
261
 
262
  # Free intermediate tensors before returning. Without this, on
263
  # HF Spaces (CPU runtime, ~16GB RAM) the activation memory from
264
  # consecutive generations accumulates and the worker errors out
265
- # after 2-3 turns. The user-visible bug is "model errors after 2
266
- # generations and needs page reload"; this block fixes it.
267
  if torch.backends.mps.is_available():
268
  torch.mps.empty_cache()
269
  elif torch.cuda.is_available():
270
  torch.cuda.empty_cache()
271
  gc.collect()
272
 
273
- return result
274
-
275
 
276
  # ---------------------------------------------------------------------------
277
  # UI
 
197
  return new_ids
198
 
199
 
200
+ def generate_until_end_stream(
201
+ model,
202
+ prompt_ids: List[int],
203
+ *,
204
+ end_id: int,
205
+ max_new_tokens: int,
206
+ temperature: float,
207
+ top_k: int,
208
+ top_p: float,
209
+ repetition_penalty: float,
210
+ ):
211
+ """Streaming variant: same as ``generate_until_end`` but yields the
212
+ growing list of new token ids after every sampled token.
213
+
214
+ Used by Gradio's chat interface so the user sees text appear
215
+ incrementally rather than waiting 15-25 s for the full response.
216
+ The yields happen with no extra forward-pass cost; the generator
217
+ just surfaces what each iteration of the loop produces."""
218
+ ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)
219
+ new_ids: List[int] = []
220
+ ctx = model.config.context_length
221
+ with torch.no_grad():
222
+ for _ in range(max_new_tokens):
223
+ cond = ids[:, -ctx:]
224
+ logits, _ = model(cond)
225
+ next_logits = logits[:, -1, :].squeeze(0).clone()
226
+ tok = sample_next(
227
+ next_logits,
228
+ temperature=temperature, top_k=top_k, top_p=top_p,
229
+ prev_ids=new_ids[-128:], repetition_penalty=repetition_penalty,
230
+ )
231
+ if tok == end_id:
232
+ break
233
+ new_ids.append(tok)
234
+ ids = torch.cat([ids, torch.tensor([[tok]])], dim=1)
235
+ yield new_ids
236
+
237
+
238
  # ---------------------------------------------------------------------------
239
  # Module-level state
240
  # ---------------------------------------------------------------------------
 
286
  else:
287
  break
288
 
289
+ # Streaming: yield the growing decoded text after each sampled token
290
+ # so Gradio shows incremental output. Same total wall-clock as the
291
+ # non-streaming path, but the user sees motion immediately and the
292
+ # demo feels alive instead of frozen for 15-25 s. Each yield is a
293
+ # full snapshot of the response so far (Gradio's ChatInterface API).
294
+ last_text = ""
295
+ for new_ids in generate_until_end_stream(
296
  MODEL, prompt_ids,
297
  end_id=END_ID,
298
  max_new_tokens=int(max_tokens),
 
300
  top_k=int(top_k),
301
  top_p=float(top_p),
302
  repetition_penalty=float(repetition_penalty),
303
+ ):
304
+ text = TOKENIZER.decode(new_ids).strip()
305
+ if text and text != last_text:
306
+ last_text = text
307
+ yield text
308
+
309
+ if not last_text:
310
+ yield "(no response)"
311
 
312
  # Free intermediate tensors before returning. Without this, on
313
  # HF Spaces (CPU runtime, ~16GB RAM) the activation memory from
314
  # consecutive generations accumulates and the worker errors out
315
+ # after 2-3 turns.
 
316
  if torch.backends.mps.is_available():
317
  torch.mps.empty_cache()
318
  elif torch.cuda.is_available():
319
  torch.cuda.empty_cache()
320
  gc.collect()
321
 
 
 
322
 
323
  # ---------------------------------------------------------------------------
324
  # UI