dek924 committed on
Commit
11786ab
·
1 Parent(s): ba9cf35

feat: add global call limit & remove sim start limit

Browse files
Files changed (2) hide show
  1. app.py +16 -11
  2. rate_limiter.py +50 -25
app.py CHANGED
@@ -609,14 +609,6 @@ def start_simulation(
609
  return _setup_error("Invalid model selection.")
610
 
611
  using_own_key = bool(user_api_key.strip())
612
-
613
- # Only apply rate limiting when using the shared demo key
614
- if not using_own_key:
615
- client_key = get_client_key(request)
616
- allowed, limit_msg = _rate_limiter.check_simulation_start(client_key)
617
- if not allowed:
618
- return _setup_error(limit_msg)
619
-
620
  is_openai = "gpt" in model.lower()
621
 
622
  if using_own_key:
@@ -654,7 +646,7 @@ def start_simulation(
654
  **patient,
655
  )
656
  except Exception as e:
657
- _logger.error("Failed to initialize patient agent: %s", e, exc_info=True)
658
  return _setup_error(f"Failed to initialize patient agent: {_sanitize_error(str(e))}")
659
 
660
  recap = build_recap_html(hadm_id, model, cefr, personality, recall, confusion)
@@ -744,11 +736,16 @@ def chat(message: str, history: list, agent, sim_config: dict, request: gr.Reque
744
  raise gr.Error("Invalid input detected. Please enter a valid clinical question.")
745
 
746
  using_own_key = bool(sim_config and sim_config.get("user_api_key"))
 
747
  if not using_own_key:
748
- client_key = get_client_key(request)
749
  allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
750
  if not allowed:
751
  raise gr.Error(limit_msg)
 
 
 
 
 
752
 
753
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
754
  history = history + [
@@ -812,6 +809,14 @@ def start_auto(agent, sim_config: dict, request: gr.Request = None):
812
  gr.Warning(limit_msg)
813
  yield _auto_fallback_outputs()
814
  return
 
 
 
 
 
 
 
 
815
 
816
  try:
817
  agent.reset_history(verbose=False)
@@ -847,7 +852,7 @@ def start_auto(agent, sim_config: dict, request: gr.Request = None):
847
  )
848
 
849
  except Exception as e:
850
- _logger.error("Failed to initialize doctor agent: %s", e, exc_info=True)
851
  gr.Error(f"Failed to initialize doctor agent: {_sanitize_error(str(e))}")
852
  yield _auto_fallback_outputs()
853
  return
 
609
  return _setup_error("Invalid model selection.")
610
 
611
  using_own_key = bool(user_api_key.strip())
 
 
 
 
 
 
 
 
612
  is_openai = "gpt" in model.lower()
613
 
614
  if using_own_key:
 
646
  **patient,
647
  )
648
  except Exception as e:
649
+ _logger.error("Failed to initialize patient agent: %s", _sanitize_error(str(e)))
650
  return _setup_error(f"Failed to initialize patient agent: {_sanitize_error(str(e))}")
651
 
652
  recap = build_recap_html(hadm_id, model, cefr, personality, recall, confusion)
 
736
  raise gr.Error("Invalid input detected. Please enter a valid clinical question.")
737
 
738
  using_own_key = bool(sim_config and sim_config.get("user_api_key"))
739
+ client_key = get_client_key(request)
740
  if not using_own_key:
 
741
  allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
742
  if not allowed:
743
  raise gr.Error(limit_msg)
744
+ else:
745
+ # Own-key users bypass per-IP quotas but still respect global capacity.
746
+ allowed, limit_msg = _rate_limiter.check_global_capacity()
747
+ if not allowed:
748
+ raise gr.Error(limit_msg)
749
 
750
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
751
  history = history + [
 
809
  gr.Warning(limit_msg)
810
  yield _auto_fallback_outputs()
811
  return
812
+ else:
813
+ # Own-key users bypass per-IP quotas but still enforce the concurrent
814
+ # run cap and the hard global capacity limit.
815
+ allowed, limit_msg = _rate_limiter.check_own_key_auto_run(client_key)
816
+ if not allowed:
817
+ gr.Warning(limit_msg)
818
+ yield _auto_fallback_outputs()
819
+ return
820
 
821
  try:
822
  agent.reset_history(verbose=False)
 
852
  )
853
 
854
  except Exception as e:
855
+ _logger.error("Failed to initialize doctor agent: %s", _sanitize_error(str(e)))
856
  gr.Error(f"Failed to initialize doctor agent: {_sanitize_error(str(e))}")
857
  yield _auto_fallback_outputs()
858
  return
rate_limiter.py CHANGED
@@ -7,7 +7,6 @@ until the process is restarted (or the SQLite DB is cleared).
7
 
8
  Limits are configurable via environment variables:
9
 
10
- RATE_LIMIT_SIM_STARTS — max simulation setups total per IP (default: 5)
11
  RATE_LIMIT_CHAT_MSGS — max chat messages total per IP (default: 50)
12
  RATE_LIMIT_AUTO_RUNS — max auto simulation runs total per IP (default: 5)
13
  RATE_LIMIT_TOTAL_API_CALLS — max total LLM calls across all modes (default: 200)
@@ -38,7 +37,6 @@ import gradio as gr
38
  # ---------------------------------------------------------------------------
39
  # Configuration — overridable via environment variables
40
  # ---------------------------------------------------------------------------
41
- SIM_STARTS_LIMIT: int = int(os.environ.get("RATE_LIMIT_SIM_STARTS", "5"))
42
  CHAT_MSGS_LIMIT: int = int(os.environ.get("RATE_LIMIT_CHAT_MSGS", "50"))
43
  AUTO_RUNS_LIMIT: int = int(os.environ.get("RATE_LIMIT_AUTO_RUNS", "5"))
44
  TOTAL_API_CALLS_LIMIT: int = int(os.environ.get("RATE_LIMIT_TOTAL_API_CALLS", "200"))
@@ -137,7 +135,6 @@ class RateLimiter:
137
 
138
  Tracks four independent counters per key:
139
 
140
- * **sim_starts** — calls to ``start_simulation()``
141
  * **chat_msgs** — individual chat messages (1 LLM call each)
142
  * **auto_runs** — auto simulation runs (each reserved as
143
  ``_AUTO_RUN_CALL_RESERVATION`` LLM calls in ``total_calls``)
@@ -151,7 +148,7 @@ class RateLimiter:
151
  Example
152
  -------
153
  >>> limiter = RateLimiter()
154
- >>> allowed, msg = limiter.check_simulation_start("ip:1.2.3.4")
155
  >>> if not allowed:
156
  ... raise gr.Error(msg)
157
  """
@@ -164,7 +161,6 @@ class RateLimiter:
164
  # SQLite-backed persistent counters; fall back to in-memory on failure
165
  self._db: Optional[sqlite3.Connection] = None
166
  self._mem: Dict[str, Dict[str, int]] = {
167
- "sim_starts": defaultdict(int),
168
  "chat_msgs": defaultdict(int),
169
  "auto_runs": defaultdict(int),
170
  "total_calls": defaultdict(int),
@@ -242,24 +238,6 @@ class RateLimiter:
242
  # Public check methods
243
  # ------------------------------------------------------------------
244
 
245
- def check_simulation_start(self, key: Optional[str]) -> Tuple[bool, str]:
246
- """
247
- Check whether a new simulation setup is allowed.
248
-
249
- Called once when the user clicks **Start Simulation**.
250
- """
251
- if not key:
252
- return False, self._UNIDENTIFIED_MSG
253
- with self._lock:
254
- count = self._get("sim_starts", key) + 1
255
- if count > SIM_STARTS_LIMIT:
256
- return False, (
257
- f"Simulation setup limit reached "
258
- f"(maximum {SIM_STARTS_LIMIT} simulations per session)."
259
- )
260
- self._set("sim_starts", key, count)
261
- return True, ""
262
-
263
  def check_chat_message(self, key: Optional[str]) -> Tuple[bool, str]:
264
  """
265
  Check whether sending a chat message is allowed (= 1 LLM API call).
@@ -331,6 +309,54 @@ class RateLimiter:
331
  self._active_auto_runs[key] += 1
332
  return True, ""
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  def release_auto_slot(self, key: Optional[str]) -> None:
335
  """
336
  Release one concurrent auto run slot for *key*.
@@ -355,12 +381,11 @@ class RateLimiter:
355
 
356
  Returns
357
  -------
358
- dict with keys ``sim_starts``, ``chat_messages``, ``auto_runs``,
359
  ``total_api_calls``; each value is a dict with ``used`` and ``limit``.
360
  """
361
  with self._lock:
362
  return {
363
- "sim_starts": {"used": self._get("sim_starts", key), "limit": SIM_STARTS_LIMIT},
364
  "chat_messages": {"used": self._get("chat_msgs", key), "limit": CHAT_MSGS_LIMIT},
365
  "auto_runs": {"used": self._get("auto_runs", key), "limit": AUTO_RUNS_LIMIT},
366
  "total_api_calls": {"used": self._get("total_calls", key), "limit": TOTAL_API_CALLS_LIMIT},
 
7
 
8
  Limits are configurable via environment variables:
9
 
 
10
  RATE_LIMIT_CHAT_MSGS — max chat messages total per IP (default: 50)
11
  RATE_LIMIT_AUTO_RUNS — max auto simulation runs total per IP (default: 5)
12
  RATE_LIMIT_TOTAL_API_CALLS — max total LLM calls across all modes (default: 200)
 
37
  # ---------------------------------------------------------------------------
38
  # Configuration — overridable via environment variables
39
  # ---------------------------------------------------------------------------
 
40
  CHAT_MSGS_LIMIT: int = int(os.environ.get("RATE_LIMIT_CHAT_MSGS", "50"))
41
  AUTO_RUNS_LIMIT: int = int(os.environ.get("RATE_LIMIT_AUTO_RUNS", "5"))
42
  TOTAL_API_CALLS_LIMIT: int = int(os.environ.get("RATE_LIMIT_TOTAL_API_CALLS", "200"))
 
135
 
136
  Tracks three independent counters per key:
137
 
 
138
  * **chat_msgs** — individual chat messages (1 LLM call each)
139
  * **auto_runs** — auto simulation runs (each reserved as
140
  ``_AUTO_RUN_CALL_RESERVATION`` LLM calls in ``total_calls``)
 
148
  Example
149
  -------
150
  >>> limiter = RateLimiter()
151
+ >>> allowed, msg = limiter.check_chat_message("ip:1.2.3.4")
152
  >>> if not allowed:
153
  ... raise gr.Error(msg)
154
  """
 
161
  # SQLite-backed persistent counters; fall back to in-memory on failure
162
  self._db: Optional[sqlite3.Connection] = None
163
  self._mem: Dict[str, Dict[str, int]] = {
 
164
  "chat_msgs": defaultdict(int),
165
  "auto_runs": defaultdict(int),
166
  "total_calls": defaultdict(int),
 
238
  # Public check methods
239
  # ------------------------------------------------------------------
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def check_chat_message(self, key: Optional[str]) -> Tuple[bool, str]:
242
  """
243
  Check whether sending a chat message is allowed (= 1 LLM API call).
 
309
  self._active_auto_runs[key] += 1
310
  return True, ""
311
 
312
+ def check_global_capacity(self) -> Tuple[bool, str]:
313
+ """
314
+ Lightweight global-capacity check for users supplying their own API keys.
315
+
316
+ Per-IP quotas (chat_msgs, auto_runs, total_calls) are
317
+ intentionally skipped — own-key users are not billed against the shared
318
+ pool. However, the hard global cap still applies to prevent the server
319
+ from being overwhelmed regardless of who is calling.
320
+
321
+ Unlike the per-IP check methods, this method **does** increment
322
+ ``_global_calls`` so that the counter accurately reflects all LLM
323
+ calls, not just those made through the shared key.
324
+ """
325
+ with self._lock:
326
+ new_global = self._global_calls + 1
327
+ if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
328
+ return False, "Service capacity reached. Please try again later."
329
+ self._global_calls = new_global
330
+ return True, ""
331
+
332
+ def check_own_key_auto_run(self, key: Optional[str]) -> Tuple[bool, str]:
333
+ """
334
+ Concurrent-run and global-capacity check for own-key auto simulations.
335
+
336
+ Per-IP auto-run quota and total-call quota are intentionally skipped.
337
+ The concurrent run cap (``_MAX_CONCURRENT_AUTO``) **is** enforced to
338
+ prevent a single client from spawning many parallel simulations and
339
+ exhausting server threads. The global hard cap is also applied and the
340
+ global counter is updated.
341
+
342
+ Must be paired with a ``release_auto_slot()`` call in a ``finally``
343
+ block, just like ``check_auto_run()``.
344
+ """
345
+ if not key:
346
+ return False, self._UNIDENTIFIED_MSG
347
+ with self._lock:
348
+ if self._active_auto_runs[key] >= _MAX_CONCURRENT_AUTO:
349
+ return False, "An auto simulation is already running. Please wait."
350
+
351
+ new_global = self._global_calls + _AUTO_RUN_CALL_RESERVATION
352
+ if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
353
+ return False, "Service capacity reached. Please try again later."
354
+
355
+ # All checks passed — commit atomically
356
+ self._global_calls = new_global
357
+ self._active_auto_runs[key] += 1
358
+ return True, ""
359
+
360
  def release_auto_slot(self, key: Optional[str]) -> None:
361
  """
362
  Release one concurrent auto run slot for *key*.
 
381
 
382
  Returns
383
  -------
384
+ dict with keys ``chat_messages``, ``auto_runs``,
385
  ``total_api_calls``; each value is a dict with ``used`` and ``limit``.
386
  """
387
  with self._lock:
388
  return {
 
389
  "chat_messages": {"used": self._get("chat_msgs", key), "limit": CHAT_MSGS_LIMIT},
390
  "auto_runs": {"used": self._get("auto_runs", key), "limit": AUTO_RUNS_LIMIT},
391
  "total_api_calls": {"used": self._get("total_calls", key), "limit": TOTAL_API_CALLS_LIMIT},