Yang2001 commited on
Commit
95d9430
·
1 Parent(s): 8cb4c62

Fix race condition: disable generate during preprocess, use preprocessed image, thread-safe progress

Browse files
Files changed (2) hide show
  1. app.py +4 -5
  2. index.html +9 -2
app.py CHANGED
@@ -236,11 +236,10 @@ from fastapi import Request
236
 
237
  # Per-session progress queues
238
  _progress_queues: Dict[str, queue.Queue] = {}
239
- _active_session: str = "" # Which session is currently running GPU work
240
 
241
  def _reset_progress(session_id: str):
242
- global _active_session
243
- _active_session = session_id
244
  if session_id not in _progress_queues:
245
  _progress_queues[session_id] = queue.Queue()
246
  # Drain old items
@@ -253,7 +252,7 @@ def _reset_progress(session_id: str):
253
 
254
  def _update_progress(stage: str, step: int, total: int):
255
  data = {"stage": stage, "step": step, "total": total, "done": False}
256
- session_id = _active_session
257
  if session_id and session_id in _progress_queues:
258
  try:
259
  _progress_queues[session_id].put_nowait(data)
@@ -261,7 +260,7 @@ def _update_progress(stage: str, step: int, total: int):
261
  pass
262
 
263
  def _finish_progress():
264
- session_id = _active_session
265
  if session_id and session_id in _progress_queues:
266
  try:
267
  _progress_queues[session_id].put_nowait({"done": True})
 
236
 
237
  # Per-session progress queues
238
  _progress_queues: Dict[str, queue.Queue] = {}
239
+ _thread_local = threading.local()
240
 
241
  def _reset_progress(session_id: str):
242
+ _thread_local.active_session = session_id
 
243
  if session_id not in _progress_queues:
244
  _progress_queues[session_id] = queue.Queue()
245
  # Drain old items
 
252
 
253
  def _update_progress(stage: str, step: int, total: int):
254
  data = {"stage": stage, "step": step, "total": total, "done": False}
255
+ session_id = getattr(_thread_local, 'active_session', '')
256
  if session_id and session_id in _progress_queues:
257
  try:
258
  _progress_queues[session_id].put_nowait(data)
 
260
  pass
261
 
262
  def _finish_progress():
263
+ session_id = getattr(_thread_local, 'active_session', '')
264
  if session_id and session_id in _progress_queues:
265
  try:
266
  _progress_queues[session_id].put_nowait({"done": True})
index.html CHANGED
@@ -768,6 +768,8 @@
768
 
769
  let client;
770
  let currentFile = null;
 
 
771
  let generationResult = null;
772
  let currentMode = "shaded_forest";
773
  let currentFrame = 0;
@@ -889,6 +891,9 @@
889
 
890
  async function handleImageUpload(file) {
891
  currentFile = file;
 
 
 
892
  const reader = new FileReader();
893
  reader.onload = (e) => {
894
  const img = document.getElementById('source-preview');
@@ -896,7 +901,6 @@
896
  img.src = e.target.result;
897
  img.style.display = 'block';
898
  hint.style.display = 'none';
899
- document.getElementById('generate-btn').disabled = false;
900
  setStep(1);
901
  };
902
  reader.readAsDataURL(file);
@@ -906,6 +910,7 @@
906
  const result = await client.predict("/preprocess", { image: handle_file(file) });
907
  const processedUrl = result.data[0].url;
908
  if (processedUrl) {
 
909
  document.getElementById('source-preview').src = processedUrl;
910
  document.getElementById('ref-thumb-2').src = processedUrl;
911
  document.getElementById('ref-thumb-2').style.display = 'block';
@@ -915,6 +920,8 @@
915
  } catch (err) {
916
  console.error("Preprocess failed:", err);
917
  }
 
 
918
  }
919
 
920
  function setStep(num) {
@@ -941,7 +948,7 @@
941
  startProgressListener();
942
  try {
943
  const params = {
944
- image: handle_file(currentFile),
945
  seed: parseInt(document.getElementById('seed').value),
946
  resolution: parseInt(document.getElementById('resolution').value),
947
  ss_guidance_strength: parseFloat(document.getElementById('ss_gs').value),
 
768
 
769
  let client;
770
  let currentFile = null;
771
+ let preprocessedFile = null;
772
+ let isPreprocessing = false;
773
  let generationResult = null;
774
  let currentMode = "shaded_forest";
775
  let currentFrame = 0;
 
891
 
892
  async function handleImageUpload(file) {
893
  currentFile = file;
894
+ preprocessedFile = null;
895
+ isPreprocessing = true;
896
+ document.getElementById('generate-btn').disabled = true;
897
  const reader = new FileReader();
898
  reader.onload = (e) => {
899
  const img = document.getElementById('source-preview');
 
901
  img.src = e.target.result;
902
  img.style.display = 'block';
903
  hint.style.display = 'none';
 
904
  setStep(1);
905
  };
906
  reader.readAsDataURL(file);
 
910
  const result = await client.predict("/preprocess", { image: handle_file(file) });
911
  const processedUrl = result.data[0].url;
912
  if (processedUrl) {
913
+ preprocessedFile = processedUrl;
914
  document.getElementById('source-preview').src = processedUrl;
915
  document.getElementById('ref-thumb-2').src = processedUrl;
916
  document.getElementById('ref-thumb-2').style.display = 'block';
 
920
  } catch (err) {
921
  console.error("Preprocess failed:", err);
922
  }
923
+ isPreprocessing = false;
924
+ document.getElementById('generate-btn').disabled = false;
925
  }
926
 
927
  function setStep(num) {
 
948
  startProgressListener();
949
  try {
950
  const params = {
951
+ image: preprocessedFile ? handle_file(preprocessedFile) : handle_file(currentFile),
952
  seed: parseInt(document.getElementById('seed').value),
953
  resolution: parseInt(document.getElementById('resolution').value),
954
  ss_guidance_strength: parseFloat(document.getElementById('ss_gs').value),