SpringWang08 committed on
Commit
0ee23ae
·
1 Parent(s): ad0cfad

Release models after each prediction for lower memory

Browse files
Files changed (1) hide show
  1. web/main.py +21 -0
web/main.py CHANGED
@@ -129,6 +129,8 @@ class VQAServerState:
129
  self.question_suggestions: list[dict[str, Any]] = []
130
  # Giữ mặc định là không preload để tránh ngốn RAM/VRAM khi Space khởi động.
131
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
 
 
132
 
133
  @property
134
  def phobert_model(self) -> str:
@@ -147,6 +149,20 @@ def _artifact_exists(path: Path) -> bool:
147
  return path.exists()
148
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def _download_hub_snapshot(repo_id: str, cache_subdir: str, allow_patterns: Optional[list[str]] = None) -> Path:
151
  target_dir = state.artifact_cache_dir / cache_subdir
152
  target_dir.mkdir(parents=True, exist_ok=True)
@@ -906,6 +922,9 @@ async def predict_variant(variant: str, question: str, image: Image.Image) -> di
906
  "checkpoint": "",
907
  "latency_ms": round((time.perf_counter() - start) * 1000, 2),
908
  }
 
 
 
909
 
910
 
911
  def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optional[str]) -> list[str]:
@@ -993,6 +1012,8 @@ async def predict(
993
  results = []
994
  async with load_lock:
995
  for variant in selected_models:
 
 
996
  results.append(await predict_variant(variant, question, pil_img))
997
 
998
  predictions = {item["variant"]: item["prediction"] for item in results if item.get("status") == "ok"}
 
129
  self.question_suggestions: list[dict[str, Any]] = []
130
  # Giữ mặc định là không preload để tránh ngốn RAM/VRAM khi Space khởi động.
131
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
132
+ # Chạy lần lượt và giải phóng model sau mỗi lượt để giảm đỉnh RAM/VRAM.
133
+ self.release_after_predict = os.getenv("WEB_RELEASE_AFTER_PREDICT", "1") == "1"
134
 
135
  @property
136
  def phobert_model(self) -> str:
 
149
  return path.exists()
150
 
151
 
152
+ def _release_variant_cache(variant: str) -> None:
153
+ if variant in {"A1", "A2"}:
154
+ bundle = state.a_models.pop(variant, None)
155
+ if bundle is not None:
156
+ bundle["model"] = None
157
+ else:
158
+ if state.llava_bundle is not None:
159
+ state.llava_bundle["model"] = None
160
+ state.llava_bundle = None
161
+ gc.collect()
162
+ if torch.cuda.is_available():
163
+ torch.cuda.empty_cache()
164
+
165
+
166
  def _download_hub_snapshot(repo_id: str, cache_subdir: str, allow_patterns: Optional[list[str]] = None) -> Path:
167
  target_dir = state.artifact_cache_dir / cache_subdir
168
  target_dir.mkdir(parents=True, exist_ok=True)
 
922
  "checkpoint": "",
923
  "latency_ms": round((time.perf_counter() - start) * 1000, 2),
924
  }
925
+ finally:
926
+ if state.release_after_predict:
927
+ _release_variant_cache(variant)
928
 
929
 
930
  def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optional[str]) -> list[str]:
 
1012
  results = []
1013
  async with load_lock:
1014
  for variant in selected_models:
1015
+ if state.release_after_predict:
1016
+ _release_variant_cache(variant)
1017
  results.append(await predict_variant(variant, question, pil_img))
1018
 
1019
  predictions = {item["variant"]: item["prediction"] for item in results if item.get("status") == "ok"}