Yang2001 commited on
Commit
702d7bc
·
1 Parent(s): b385a26

Fix progress tracking: replace SSE+queue with file-based polling for HF Spaces compatibility

Browse files
Files changed (2) hide show
  1. app.py +34 -45
  2. index.html +11 -12
app.py CHANGED
@@ -226,39 +226,46 @@ def unpack_state(state_path):
226
  return shape_slat, tex_slat, int(data['res'])
227
 
228
  # ============================================================================
229
- # Progress Tracking (SSE-based, tqdm interception, multi-session)
230
  # ============================================================================
231
 
232
  import asyncio
233
- import queue
234
- from fastapi.responses import StreamingResponse
235
  from fastapi import Request
236
 
237
- # Per-session progress queues
238
- _progress_queues: Dict[str, queue.Queue] = {}
 
239
  _thread_local = threading.local()
240
 
 
 
 
 
241
  def _reset_progress(session_id: str):
242
  _thread_local.active_session = session_id
243
- # Always recreate the queue (old one may have been consumed or stale)
244
- _progress_queues[session_id] = queue.Queue()
245
 
246
  def _update_progress(stage: str, step: int, total: int):
247
- data = {"stage": stage, "step": step, "total": total, "done": False}
248
  session_id = getattr(_thread_local, 'active_session', '')
249
- if session_id and session_id in _progress_queues:
250
- try:
251
- _progress_queues[session_id].put_nowait(data)
252
- except:
253
- pass
254
 
255
  def _finish_progress():
256
  session_id = getattr(_thread_local, 'active_session', '')
257
- if session_id and session_id in _progress_queues:
258
- try:
259
- _progress_queues[session_id].put_nowait({"done": True})
260
- except:
261
- pass
 
 
 
 
 
 
 
 
262
 
263
  # Monkey-patch tqdm to intercept progress
264
  import tqdm as _tqdm_module
@@ -302,34 +309,16 @@ async def homepage():
302
  return HTMLResponse(content=f.read())
303
 
304
  @app.get("/progress")
305
- async def progress_sse(request: Request):
306
- """SSE endpoint for real-time progress updates during generation."""
307
  session_id = request.query_params.get("session_id", "")
308
- if session_id and session_id not in _progress_queues:
309
- _progress_queues[session_id] = queue.Queue()
310
-
311
- async def event_stream():
312
- timeout_count = 0
313
- while True:
314
- q = _progress_queues.get(session_id)
315
- if q:
316
- try:
317
- data = q.get_nowait()
318
- yield f"data: {json.dumps(data)}\n\n"
319
- if data.get("done"):
320
- break
321
- timeout_count = 0
322
- except queue.Empty:
323
- yield f": keepalive\n\n"
324
- timeout_count += 1
325
- else:
326
- yield f": keepalive\n\n"
327
- timeout_count += 1
328
- # Timeout after 5 minutes of no data
329
- if timeout_count > 1000:
330
- break
331
- await asyncio.sleep(0.3)
332
- return StreamingResponse(event_stream(), media_type="text/event-stream")
333
 
334
  @app.api()
335
  @spaces.GPU(duration=30)
 
226
  return shape_slat, tex_slat, int(data['res'])
227
 
228
  # ============================================================================
229
+ # Progress Tracking (file-based, cross-process safe for @spaces.GPU)
230
  # ============================================================================
231
 
232
  import asyncio
233
+ from fastapi.responses import JSONResponse
 
234
  from fastapi import Request
235
 
236
+ PROGRESS_DIR = os.path.join(TMP_DIR, '_progress')
237
+ os.makedirs(PROGRESS_DIR, exist_ok=True)
238
+
239
  _thread_local = threading.local()
240
 
241
+ def _progress_file(session_id: str) -> str:
242
+ """Return path to a session's progress JSON file."""
243
+ return os.path.join(PROGRESS_DIR, f"{session_id}.json")
244
+
245
  def _reset_progress(session_id: str):
246
  _thread_local.active_session = session_id
247
+ _write_progress_file(session_id, {"stage": "Initializing...", "step": 0, "total": 0, "done": False})
 
248
 
249
  def _update_progress(stage: str, step: int, total: int):
 
250
  session_id = getattr(_thread_local, 'active_session', '')
251
+ if session_id:
252
+ _write_progress_file(session_id, {"stage": stage, "step": step, "total": total, "done": False})
 
 
 
253
 
254
  def _finish_progress():
255
  session_id = getattr(_thread_local, 'active_session', '')
256
+ if session_id:
257
+ _write_progress_file(session_id, {"done": True})
258
+
259
+ def _write_progress_file(session_id: str, data: dict):
260
+ """Atomically write progress JSON to a file (cross-process safe)."""
261
+ path = _progress_file(session_id)
262
+ tmp_path = path + ".tmp"
263
+ try:
264
+ with open(tmp_path, 'w') as f:
265
+ json.dump(data, f)
266
+ os.replace(tmp_path, path) # atomic on POSIX
267
+ except Exception:
268
+ pass
269
 
270
  # Monkey-patch tqdm to intercept progress
271
  import tqdm as _tqdm_module
 
309
  return HTMLResponse(content=f.read())
310
 
311
  @app.get("/progress")
312
+ async def progress_poll(request: Request):
313
+ """Polling endpoint for real-time progress updates during generation."""
314
  session_id = request.query_params.get("session_id", "")
315
+ path = _progress_file(session_id)
316
+ try:
317
+ with open(path, 'r') as f:
318
+ data = json.load(f)
319
+ return JSONResponse(data)
320
+ except (FileNotFoundError, json.JSONDecodeError):
321
+ return JSONResponse({"stage": "Waiting...", "step": 0, "total": 0, "done": False})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  @app.api()
324
  @spaces.GPU(duration=30)
index.html CHANGED
@@ -1054,8 +1054,8 @@
1054
  }
1055
  }
1056
 
1057
- // SSE Progress Listener
1058
- let progressEventSource = null;
1059
  let lastStageName = "";
1060
 
1061
  function startProgressListener() {
@@ -1067,26 +1067,25 @@
1067
  document.getElementById('progress-bar-fill').style.width = '0%';
1068
  lastStageName = "";
1069
 
1070
- progressEventSource = new EventSource(`/progress?session_id=${sessionId}`);
1071
- progressEventSource.onmessage = (event) => {
1072
  try {
1073
- const data = JSON.parse(event.data);
 
 
1074
  if (data.done) {
1075
  stopProgressListener();
1076
  return;
1077
  }
1078
  updateProgressUI(data);
1079
  } catch (e) {}
1080
- };
1081
- progressEventSource.onerror = () => {
1082
- // Silently ignore SSE errors, generation continues
1083
- };
1084
  }
1085
 
1086
  function stopProgressListener() {
1087
- if (progressEventSource) {
1088
- progressEventSource.close();
1089
- progressEventSource = null;
1090
  }
1091
  }
1092
 
 
1054
  }
1055
  }
1056
 
1057
+ // Progress Polling
1058
+ let progressInterval = null;
1059
  let lastStageName = "";
1060
 
1061
  function startProgressListener() {
 
1067
  document.getElementById('progress-bar-fill').style.width = '0%';
1068
  lastStageName = "";
1069
 
1070
+ // Poll every 500ms instead of SSE
1071
+ progressInterval = setInterval(async () => {
1072
  try {
1073
+ const resp = await fetch(`/progress?session_id=${sessionId}`);
1074
+ if (!resp.ok) return;
1075
+ const data = await resp.json();
1076
  if (data.done) {
1077
  stopProgressListener();
1078
  return;
1079
  }
1080
  updateProgressUI(data);
1081
  } catch (e) {}
1082
+ }, 500);
 
 
 
1083
  }
1084
 
1085
  function stopProgressListener() {
1086
+ if (progressInterval) {
1087
+ clearInterval(progressInterval);
1088
+ progressInterval = null;
1089
  }
1090
  }
1091