SpringWang08 committed on
Commit
d9a0039
·
1 Parent(s): 0ee23ae

Add backend progress polling for job predictions

Browse files
Files changed (2) hide show
  1. web/main.py +171 -18
  2. web/static/index.html +80 -3
web/main.py CHANGED
@@ -5,7 +5,9 @@ import io
5
  import json
6
  import os
7
  import re
 
8
  import time
 
9
  from pathlib import Path
10
  from typing import Any, Optional
11
 
@@ -131,6 +133,20 @@ class VQAServerState:
131
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
132
  # Chạy lần lượt và giải phóng model sau mỗi lượt để giảm đỉnh RAM/VRAM.
133
  self.release_after_predict = os.getenv("WEB_RELEASE_AFTER_PREDICT", "1") == "1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  @property
136
  def phobert_model(self) -> str:
@@ -149,6 +165,31 @@ def _artifact_exists(path: Path) -> bool:
149
  return path.exists()
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def _release_variant_cache(variant: str) -> None:
153
  if variant in {"A1", "A2"}:
154
  bundle = state.a_models.pop(variant, None)
@@ -927,6 +968,81 @@ async def predict_variant(variant: str, question: str, image: Image.Image) -> di
927
  _release_variant_cache(variant)
928
 
929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optional[str]) -> list[str]:
931
  if raw_model_names:
932
  try:
@@ -1009,28 +1125,65 @@ async def predict(
1009
  raise HTTPException(status_code=400, detail=f"Failed to read image file: {exc}") from exc
1010
 
1011
  selected_models = _parse_model_selection(model_name, model_names)
1012
- results = []
1013
- async with load_lock:
1014
- for variant in selected_models:
1015
- if state.release_after_predict:
1016
- _release_variant_cache(variant)
1017
- results.append(await predict_variant(variant, question, pil_img))
1018
 
1019
- predictions = {item["variant"]: item["prediction"] for item in results if item.get("status") == "ok"}
1020
- summary = {
1021
- "majority_vote": majority_answer(list(predictions.values())) if predictions else "",
1022
- "success_count": sum(1 for item in results if item.get("status") == "ok"),
1023
- "error_count": sum(1 for item in results if item.get("status", "").startswith("error")),
1024
- }
1025
 
1026
- return JSONResponse(
1027
- {
1028
- "question": question,
1029
- "selected_models": selected_models,
1030
- "results": results,
1031
- "summary": summary,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1032
  }
 
 
 
 
 
1033
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1034
 
1035
 
1036
  @app.get("/v1/question-suggestions")
 
5
  import json
6
  import os
7
  import re
8
+ import threading
9
  import time
10
+ import uuid
11
  from pathlib import Path
12
  from typing import Any, Optional
13
 
 
133
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
134
  # Chạy lần lượt và giải phóng model sau mỗi lượt để giảm đỉnh RAM/VRAM.
135
  self.release_after_predict = os.getenv("WEB_RELEASE_AFTER_PREDICT", "1") == "1"
136
+ self.progress_state: dict[str, Any] = {
137
+ "job_id": "",
138
+ "active": False,
139
+ "status": "idle",
140
+ "current_variant": "",
141
+ "current_index": 0,
142
+ "total": 0,
143
+ "completed": 0,
144
+ "message": "Idle",
145
+ "updated_at": time.time(),
146
+ }
147
+ self.latest_result: dict[str, Any] | None = None
148
+ self.latest_error: str = ""
149
+ self.progress_lock = threading.Lock()
150
 
151
  @property
152
  def phobert_model(self) -> str:
 
165
  return path.exists()
166
 
167
 
168
def _set_progress(
    *,
    job_id: str = "",
    active: bool,
    status: str,
    message: str,
    current_variant: str = "",
    current_index: int = 0,
    total: int = 0,
    completed: int = 0,
) -> None:
    """Publish a new progress snapshot for the currently tracked job.

    Builds the full snapshot dict first, then swaps it into
    ``state.progress_state`` under ``state.progress_lock`` so readers never
    observe a partially written state.
    """
    snapshot = {
        "job_id": job_id,
        "active": active,
        "status": status,
        "current_variant": current_variant,
        "current_index": current_index,
        "total": total,
        "completed": completed,
        "message": message,
        "updated_at": time.time(),
    }
    # Whole-dict replacement keeps the critical section minimal.
    with state.progress_lock:
        state.progress_state = snapshot
191
+
192
+
193
  def _release_variant_cache(variant: str) -> None:
194
  if variant in {"A1", "A2"}:
195
  bundle = state.a_models.pop(variant, None)
 
968
  _release_variant_cache(variant)
969
 
970
 
971
async def _predict_models(
    selected_models: list[str],
    question: str,
    pil_img: Image.Image,
    job_id: str = "",
) -> dict[str, Any]:
    """Run every selected model variant sequentially and build the comparison payload.

    Emits a progress update before and after each variant, serializes all
    model work behind ``load_lock``, and finishes by marking the job "done".
    Returns a dict with the question, selection, per-model results and a
    summary (majority vote + success/error counts).
    """
    total = len(selected_models)

    def report(
        message: str,
        *,
        variant: str = "",
        index: int = 0,
        completed: int,
        active: bool = True,
        status: str = "running",
    ) -> None:
        # Single funnel for progress updates so each call site stays short.
        _set_progress(
            job_id=job_id,
            active=active,
            status=status,
            message=message,
            current_variant=variant,
            current_index=index,
            total=total,
            completed=completed,
        )

    report("Starting comparison...", completed=0)

    outcomes: list[dict[str, Any]] = []
    async with load_lock:
        for position, variant in enumerate(selected_models, start=1):
            report(
                f"Running {variant} ({position}/{total})",
                variant=variant,
                index=position,
                completed=position - 1,
            )
            outcomes.append(await predict_variant(variant, question, pil_img))
            report(
                f"Finished {variant} ({position}/{total})",
                variant=variant,
                index=position,
                completed=position,
            )

    ok_items = [item for item in outcomes if item.get("status") == "ok"]
    predictions = {item["variant"]: item["prediction"] for item in ok_items}
    summary = {
        "majority_vote": majority_answer(list(predictions.values())) if predictions else "",
        "success_count": len(ok_items),
        "error_count": sum(1 for item in outcomes if item.get("status", "").startswith("error")),
    }
    payload = {
        "question": question,
        "selected_models": selected_models,
        "results": outcomes,
        "summary": summary,
    }
    report(f"Finished {total}/{total} models.", completed=total, active=False, status="done")
    return payload
1026
+
1027
+
1028
def _run_predict_job(job_id: str, selected_models: list[str], question: str, image_bytes: bytes) -> None:
    """Background-thread entry point for a queued prediction job.

    Decodes the uploaded image bytes, runs the full model comparison, and
    publishes either the payload or the error through the shared single-slot
    job state (``state.latest_result`` / ``state.latest_error``).

    NOTE(review): ``asyncio.run`` creates a fresh event loop in this worker
    thread, while ``_predict_models`` awaits the module-level ``load_lock``.
    If that is an ``asyncio.Lock`` that was already awaited on the server's
    own loop (e.g. via ``/v1/predict``), re-using it here can raise
    "is bound to a different event loop" — confirm before mixing endpoints.
    """
    try:
        # Close the decoder handle explicitly (the original leaked it);
        # convert() returns an independent RGB copy, so closing is safe.
        with Image.open(io.BytesIO(image_bytes)) as decoded:
            pil_img = decoded.convert("RGB")
        payload = asyncio.run(_predict_models(selected_models, question, pil_img, job_id=job_id))
        with state.progress_lock:
            state.latest_result = {"job_id": job_id, "payload": payload, "status": "done"}
            state.latest_error = ""
    except Exception as exc:  # job boundary: record the failure instead of killing the thread
        with state.progress_lock:
            state.latest_result = None
            state.latest_error = str(exc)
        _set_progress(job_id=job_id, active=False, status="error", message=f"Failed: {exc}")
    finally:
        # Trim Python and CUDA allocator caches after each job to keep peak memory low.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
1044
+
1045
+
1046
  def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optional[str]) -> list[str]:
1047
  if raw_model_names:
1048
  try:
 
1125
  raise HTTPException(status_code=400, detail=f"Failed to read image file: {exc}") from exc
1126
 
1127
  selected_models = _parse_model_selection(model_name, model_names)
1128
+ payload = await _predict_models(selected_models, question, pil_img)
1129
+ return JSONResponse(payload)
 
 
 
 
1130
 
 
 
 
 
 
 
1131
 
1132
@app.post("/v1/predict-job")
async def predict_job(
    question: str = Form(..., description="Question for VQA"),
    model_name: Optional[str] = Form(None, description="Legacy single model name"),
    model_names: Optional[str] = Form(None, description="Comma-separated or JSON list of models"),
    image: UploadFile = File(..., description="Image input (JPEG/PNG)"),
) -> JSONResponse:
    """Queue an asynchronous prediction job and return 202 with its job id.

    The job runs in a daemon thread; clients poll ``/v1/progress`` and fetch
    the outcome from ``/v1/result``. There is a single job slot — submitting
    a new job resets the previous job's result and progress state.
    """
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question is required.")

    try:
        img_bytes = await image.read()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to read image file: {exc}") from exc

    selected_models = _parse_model_selection(model_name, model_names)
    job_id = uuid.uuid4().hex

    queued_snapshot = {
        "job_id": job_id,
        "active": True,
        "status": "queued",
        "current_variant": "",
        "current_index": 0,
        "total": len(selected_models),
        "completed": 0,
        "message": "Queued for prediction...",
        "updated_at": time.time(),
    }
    # Reset the single result slot and publish the queued snapshot atomically.
    with state.progress_lock:
        state.latest_result = None
        state.latest_error = ""
        state.progress_state = queued_snapshot

    worker = threading.Thread(
        target=_run_predict_job,
        args=(job_id, selected_models, question, img_bytes),
        daemon=True,
    )
    worker.start()

    return JSONResponse({"job_id": job_id, "status": "queued", "selected_models": selected_models}, status_code=202)
1172
+
1173
+
1174
@app.get("/v1/progress")
def predict_progress() -> JSONResponse:
    """Return a point-in-time snapshot of the shared job progress state.

    The worker thread replaces ``state.progress_state`` concurrently; the
    original read it without holding ``state.progress_lock``. Copying under
    the lock serves a private, consistent snapshot to the JSON serializer.
    """
    with state.progress_lock:
        snapshot = dict(state.progress_state)
    return JSONResponse(snapshot)
1177
+
1178
+
1179
@app.get("/v1/result")
def predict_result() -> JSONResponse:
    """Return the finished job payload, the latest error (500), or 202 while pending."""
    # Snapshot both fields in one critical section, then respond outside the lock.
    with state.progress_lock:
        result = state.latest_result
        error = state.latest_error
    if result is not None:
        return JSONResponse(result)
    if error:
        return JSONResponse({"status": "error", "error": error}, status_code=500)
    return JSONResponse({"status": "pending"}, status_code=202)
1187
 
1188
 
1189
  @app.get("/v1/question-suggestions")
web/static/index.html CHANGED
@@ -269,6 +269,16 @@ Reset
269
  </div>
270
 
271
  <div class="space-y-5 pt-2">
 
 
 
 
 
 
 
 
 
 
272
  <div class="flex items-center gap-3">
273
  <span class="text-xs font-bold uppercase tracking-widest text-china-gold">Model set:</span>
274
  <div class="flex gap-2 overflow-x-auto pb-1 no-scrollbar">
@@ -383,11 +393,15 @@ Medical VQA web demo for six-model comparison.
383
  resetBtn: document.getElementById("reset-btn"),
384
  statusText: document.getElementById("status-text"),
385
  resultsGrid: document.getElementById("results-grid"),
 
 
 
386
  };
387
 
388
  let currentImageFile = null;
389
  let selectedModels = new Set(MODEL_ORDER);
390
  let questionSuggestions = [];
 
391
 
392
  function escapeHtml(value) {
393
  return String(value ?? "")
@@ -405,6 +419,56 @@ Medical VQA web demo for six-model comparison.
405
  el.statusText.textContent = message;
406
  }
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  function setPreview(file) {
409
  currentImageFile = file || null;
410
  if (!file) {
@@ -641,6 +705,7 @@ Medical VQA web demo for six-model comparison.
641
  setStatus("Running all selected models...");
642
  renderRunningModelGrid();
643
  applyTiltEffect(".tilt-card", 5);
 
644
 
645
  try {
646
  const formData = new FormData();
@@ -648,19 +713,30 @@ Medical VQA web demo for six-model comparison.
648
  formData.append("model_names", JSON.stringify(Array.from(selectedModels)));
649
  formData.append("image", currentImageFile);
650
 
651
- const res = await fetch("/v1/predict", { method: "POST", body: formData });
652
  const data = await res.json();
653
  if (!res.ok) {
654
  throw new Error(data?.detail || "Prediction failed");
655
  }
656
- renderModelGrid(data.results || []);
 
 
 
 
 
 
 
 
 
 
657
  applyTiltEffect(".tilt-card", 5);
658
- setStatus(`Done. ${data.summary?.success_count ?? 0} models succeeded.`);
659
  } catch (err) {
660
  setStatus(err.message || "Prediction failed");
661
  } finally {
662
  el.runBtn.disabled = false;
663
  el.runBtn.querySelector("span").textContent = "Run Comparison";
 
664
  }
665
  });
666
 
@@ -669,6 +745,7 @@ Medical VQA web demo for six-model comparison.
669
  loadModels();
670
  loadQuestionSuggestions();
671
  renderModelGrid([], "", null);
 
672
  applyTiltEffect(".tilt-card", 5);
673
  </script>
674
 
 
269
  </div>
270
 
271
  <div class="space-y-5 pt-2">
272
+ <div class="space-y-2">
273
+ <div class="flex items-center justify-between text-[12px] uppercase tracking-[0.22em] text-china-gold font-bold">
274
+ <span>Backend Progress</span>
275
+ <span id="progress-label">Idle</span>
276
+ </div>
277
+ <div class="h-3 rounded-full bg-[#E7E1D6] overflow-hidden border border-china-gold/25">
278
+ <div id="progress-bar" class="h-full w-0 bg-gradient-to-r from-imperial-red via-china-gold to-gold-light transition-[width] duration-300 ease-out"></div>
279
+ </div>
280
+ <div id="progress-detail" class="text-[12px] italic font-serif text-ink-black/60">Waiting for a request.</div>
281
+ </div>
282
  <div class="flex items-center gap-3">
283
  <span class="text-xs font-bold uppercase tracking-widest text-china-gold">Model set:</span>
284
  <div class="flex gap-2 overflow-x-auto pb-1 no-scrollbar">
 
393
  resetBtn: document.getElementById("reset-btn"),
394
  statusText: document.getElementById("status-text"),
395
  resultsGrid: document.getElementById("results-grid"),
396
+ progressBar: document.getElementById("progress-bar"),
397
+ progressLabel: document.getElementById("progress-label"),
398
+ progressDetail: document.getElementById("progress-detail"),
399
  };
400
 
401
  let currentImageFile = null;
402
  let selectedModels = new Set(MODEL_ORDER);
403
  let questionSuggestions = [];
404
+ let progressTimer = null;
405
 
406
  function escapeHtml(value) {
407
  return String(value ?? "")
 
419
  el.statusText.textContent = message;
420
  }
421
 
422
function setProgressUI(state) {
  // Render the polled backend progress snapshot into the bar and labels.
  const total = Number(state?.total || 0);
  const completed = Number(state?.completed || 0);
  let pct = 0;
  if (total > 0) {
    // Clamp to [0, 100] so a stale/overshooting snapshot never breaks the bar.
    pct = Math.min(100, Math.max(0, Math.round((completed / total) * 100)));
  }
  el.progressBar.style.width = `${pct}%`;
  el.progressLabel.textContent = state?.active ? (state?.status || "running").toUpperCase() : "IDLE";
  el.progressDetail.textContent = state?.message || "Waiting for a request.";
}
430
+
431
async function refreshProgress() {
  // Fetch one progress snapshot; update the UI and return the snapshot,
  // or return undefined/null when the backend is unreachable.
  try {
    const res = await fetch("/v1/progress", { cache: "no-store" });
    if (!res.ok) return;
    const snapshot = await res.json();
    setProgressUI(snapshot);
    // Once the backend goes idle, tear down the polling interval.
    const stillActive = Boolean(snapshot?.active);
    if (!stillActive && progressTimer) {
      clearInterval(progressTimer);
      progressTimer = null;
    }
    return snapshot;
  } catch (err) {
    // ignore polling noise
  }
  return null;
}
447
+
448
function startProgressPolling() {
  // Idempotent: never stack multiple intervals.
  if (progressTimer !== null) return;
  // Refresh immediately so the UI reacts before the first tick.
  refreshProgress();
  progressTimer = setInterval(refreshProgress, 750);
}
453
+
454
function stopProgressPolling() {
  if (progressTimer !== null) {
    clearInterval(progressTimer);
    progressTimer = null;
  }
  // One last fetch so the UI settles on the final backend state.
  refreshProgress();
}
461
+
462
async function waitForJobCompletion(maxWaitMs = Infinity) {
  // Poll until the backend reports a terminal status ("done" or "error").
  // The original looped forever; maxWaitMs guards against an unreachable
  // backend. The default (Infinity) preserves the legacy wait-forever
  // behavior for existing callers.
  const deadline = Date.now() + maxWaitMs;
  while (true) {
    const data = await refreshProgress();
    const status = data?.status;
    if (status === "done" || status === "error") {
      return data;
    }
    if (Date.now() >= deadline) {
      throw new Error("Timed out waiting for the prediction job to finish.");
    }
    await new Promise((resolve) => setTimeout(resolve, 750));
  }
}
471
+
472
  function setPreview(file) {
473
  currentImageFile = file || null;
474
  if (!file) {
 
705
  setStatus("Running all selected models...");
706
  renderRunningModelGrid();
707
  applyTiltEffect(".tilt-card", 5);
708
+ startProgressPolling();
709
 
710
  try {
711
  const formData = new FormData();
 
713
  formData.append("model_names", JSON.stringify(Array.from(selectedModels)));
714
  formData.append("image", currentImageFile);
715
 
716
+ const res = await fetch("/v1/predict-job", { method: "POST", body: formData });
717
  const data = await res.json();
718
  if (!res.ok) {
719
  throw new Error(data?.detail || "Prediction failed");
720
  }
721
+
722
+ setStatus(`Job queued: ${data.job_id}`);
723
+ await waitForJobCompletion();
724
+
725
+ const resultRes = await fetch("/v1/result", { cache: "no-store" });
726
+ const resultData = await resultRes.json();
727
+ if (!resultRes.ok) {
728
+ throw new Error(resultData?.error || "Prediction failed");
729
+ }
730
+
731
+ renderModelGrid(resultData?.payload?.results || []);
732
  applyTiltEffect(".tilt-card", 5);
733
+ setStatus(`Done. ${resultData?.payload?.summary?.success_count ?? 0} models succeeded.`);
734
  } catch (err) {
735
  setStatus(err.message || "Prediction failed");
736
  } finally {
737
  el.runBtn.disabled = false;
738
  el.runBtn.querySelector("span").textContent = "Run Comparison";
739
+ stopProgressPolling();
740
  }
741
  });
742
 
 
745
  loadModels();
746
  loadQuestionSuggestions();
747
  renderModelGrid([], "", null);
748
+ refreshProgress();
749
  applyTiltEffect(".tilt-card", 5);
750
  </script>
751