owenisas commited on
Commit
68da4a5
·
verified ·
1 Parent(s): 04f9cd3

Add runtime cache metadata and access check cache

Browse files
Files changed (2) hide show
  1. README.md +12 -0
  2. app.py +39 -11
README.md CHANGED
@@ -56,3 +56,15 @@ Stability AI's public MIT-licensed repository because its package metadata pins
56
  Torch 2.7.1. ZeroGPU currently provides Torch 2.8.0, so installing the upstream
57
  package through normal dependency resolution would downgrade Torch and break the
58
  ZeroGPU runtime.
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  Torch 2.7.1. ZeroGPU currently provides Torch 2.8.0, so installing the upstream
57
  package through normal dependency resolution would downgrade Torch and break the
58
  ZeroGPU runtime.
59
+
60
+ ## Optimization notes
61
+
62
+ - Repeated runs with the same selected model reuse the loaded model inside the
63
+ ZeroGPU worker when the worker stays warm. Run metadata includes `cache_hit`
64
+ and `load_elapsed_s` so this is visible.
65
+ - Successful gated-repo access checks are cached briefly per token digest and
66
+ repo ID to avoid a Hugging Face `HEAD` request on every generation.
67
+ - The `stable-audio-3-optimized` repo currently provides MLX, ONNX, and
68
+ TensorRT assets. This Space keeps the portable PyTorch path because the
69
+ TensorRT engines are prebuilt for `sm_90`, while the current ZeroGPU host is
70
+ a Blackwell GPU, and MLX is Apple-only.
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import gc
 
4
  import importlib
5
  import importlib.util
6
  import json
@@ -15,11 +16,13 @@ from contextlib import contextmanager
15
  from dataclasses import dataclass
16
  from typing import Any
17
 
 
 
 
 
18
  import gradio as gr
19
  import numpy as np
20
 
21
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
22
-
23
 
24
  def _filter_known_unraisable(unraisable):
25
  object_name = getattr(unraisable.object, "__qualname__", "")
@@ -171,6 +174,8 @@ COLLECTION_ROWS = [
171
 
172
  MODEL_CACHE: dict[str, Any] = {"key": None, "model": None}
173
  AE_CACHE: dict[str, Any] = {"key": None, "model": None}
 
 
174
  MODEL_LOAD_LOCK = threading.RLock()
175
 
176
 
@@ -246,7 +251,20 @@ def stable_audio_token_hint(model: GenerationModel) -> str:
246
  )
247
 
248
 
 
 
 
 
 
249
  def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
 
 
 
 
 
 
 
 
250
  request = urllib.request.Request(
251
  f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
252
  method="HEAD",
@@ -254,7 +272,10 @@ def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str |
254
  )
255
  try:
256
  with urllib.request.urlopen(request, timeout=20) as response:
257
- return response.status < 400, None
 
 
 
258
  except urllib.error.HTTPError as exc:
259
  if exc.code in {401, 403}:
260
  return (
@@ -389,12 +410,13 @@ def load_generation_model(
389
  )
390
 
391
  if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
392
- return MODEL_CACHE["model"], device
393
 
394
  with MODEL_LOAD_LOCK:
395
  if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
396
- return MODEL_CACHE["model"], device
397
 
 
398
  MODEL_CACHE["model"] = None
399
  MODEL_CACHE["key"] = None
400
  clear_torch_memory()
@@ -406,7 +428,7 @@ def load_generation_model(
406
  model = StableAudioModel.from_pretrained(model_key, model_half=model_half)
407
  MODEL_CACHE["key"] = model_key
408
  MODEL_CACHE["model"] = model
409
- return model, device
410
 
411
 
412
  def load_autoencoder(
@@ -429,12 +451,13 @@ def load_autoencoder(
429
  )
430
 
431
  if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
432
- return AE_CACHE["model"], device
433
 
434
  with MODEL_LOAD_LOCK:
435
  if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
436
- return AE_CACHE["model"], device
437
 
 
438
  AE_CACHE["model"] = None
439
  AE_CACHE["key"] = None
440
  clear_torch_memory()
@@ -445,7 +468,7 @@ def load_autoencoder(
445
  model = AutoencoderModel.from_pretrained(model_key)
446
  AE_CACHE["key"] = model_key
447
  AE_CACHE["model"] = model
448
- return model, device
449
 
450
 
451
  def model_changed(model_key: str):
@@ -521,7 +544,7 @@ def generate_audio(
521
  if seed < 0:
522
  seed = int.from_bytes(os.urandom(4), "little") % 100000
523
 
524
- model, device = load_generation_model(
525
  model_key,
526
  allow_cpu_medium,
527
  oauth_profile,
@@ -563,6 +586,8 @@ def generate_audio(
563
  "seed": seed,
564
  "sample_rate": sample_rate,
565
  "elapsed_s": elapsed,
 
 
566
  "output_file": out_file.name,
567
  "note": model_def.note,
568
  "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
@@ -619,7 +644,7 @@ def roundtrip_autoencoder(
619
 
620
  progress(0.05, desc="Loading autoencoder")
621
  started = time.time()
622
- model, device = load_autoencoder(
623
  model_key,
624
  allow_cpu_same_l,
625
  oauth_profile,
@@ -655,6 +680,8 @@ def roundtrip_autoencoder(
655
  "input_shape": list(waveform.shape),
656
  "latent_shape": list(latents.shape),
657
  "elapsed_s": round(time.time() - started, 3),
 
 
658
  "output_file": out_file.name,
659
  "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
660
  "username": oauth_username(oauth_profile),
@@ -710,6 +737,7 @@ def runtime_status(
710
  "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
711
  "loaded_generation_model": MODEL_CACHE["key"],
712
  "loaded_autoencoder": AE_CACHE["key"],
 
713
  }
714
 
715
 
 
1
  from __future__ import annotations
2
 
3
  import gc
4
+ import hashlib
5
  import importlib
6
  import importlib.util
7
  import json
 
16
  from dataclasses import dataclass
17
  from typing import Any
18
 
19
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
20
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
21
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
22
+
23
  import gradio as gr
24
  import numpy as np
25
 
 
 
26
 
27
  def _filter_known_unraisable(unraisable):
28
  object_name = getattr(unraisable.object, "__qualname__", "")
 
174
 
175
  MODEL_CACHE: dict[str, Any] = {"key": None, "model": None}
176
  AE_CACHE: dict[str, Any] = {"key": None, "model": None}
177
+ ACCESS_CACHE: dict[tuple[str, str], float] = {}
178
+ ACCESS_CACHE_TTL_SECONDS = max(0, int(os.getenv("SA3_ACCESS_CACHE_TTL_SECONDS", "600")))
179
  MODEL_LOAD_LOCK = threading.RLock()
180
 
181
 
 
251
  )
252
 
253
 
254
+ def access_cache_key(repo_id: str, token: str) -> tuple[str, str]:
255
+ token_digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]
256
+ return repo_id, token_digest
257
+
258
+
259
  def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
260
+ cache_key = access_cache_key(repo_id, token)
261
+ cached_until = ACCESS_CACHE.get(cache_key)
262
+ now = time.time()
263
+ if cached_until is not None:
264
+ if cached_until > now:
265
+ return True, None
266
+ ACCESS_CACHE.pop(cache_key, None)
267
+
268
  request = urllib.request.Request(
269
  f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
270
  method="HEAD",
 
272
  )
273
  try:
274
  with urllib.request.urlopen(request, timeout=20) as response:
275
+ has_access = response.status < 400
276
+ if has_access and ACCESS_CACHE_TTL_SECONDS:
277
+ ACCESS_CACHE[cache_key] = time.time() + ACCESS_CACHE_TTL_SECONDS
278
+ return has_access, None
279
  except urllib.error.HTTPError as exc:
280
  if exc.code in {401, 403}:
281
  return (
 
410
  )
411
 
412
  if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
413
+ return MODEL_CACHE["model"], device, True, 0.0
414
 
415
  with MODEL_LOAD_LOCK:
416
  if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
417
+ return MODEL_CACHE["model"], device, True, 0.0
418
 
419
+ load_started = time.time()
420
  MODEL_CACHE["model"] = None
421
  MODEL_CACHE["key"] = None
422
  clear_torch_memory()
 
428
  model = StableAudioModel.from_pretrained(model_key, model_half=model_half)
429
  MODEL_CACHE["key"] = model_key
430
  MODEL_CACHE["model"] = model
431
+ return model, device, False, round(time.time() - load_started, 3)
432
 
433
 
434
  def load_autoencoder(
 
451
  )
452
 
453
  if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
454
+ return AE_CACHE["model"], device, True, 0.0
455
 
456
  with MODEL_LOAD_LOCK:
457
  if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
458
+ return AE_CACHE["model"], device, True, 0.0
459
 
460
+ load_started = time.time()
461
  AE_CACHE["model"] = None
462
  AE_CACHE["key"] = None
463
  clear_torch_memory()
 
468
  model = AutoencoderModel.from_pretrained(model_key)
469
  AE_CACHE["key"] = model_key
470
  AE_CACHE["model"] = model
471
+ return model, device, False, round(time.time() - load_started, 3)
472
 
473
 
474
  def model_changed(model_key: str):
 
544
  if seed < 0:
545
  seed = int.from_bytes(os.urandom(4), "little") % 100000
546
 
547
+ model, device, cache_hit, load_elapsed = load_generation_model(
548
  model_key,
549
  allow_cpu_medium,
550
  oauth_profile,
 
586
  "seed": seed,
587
  "sample_rate": sample_rate,
588
  "elapsed_s": elapsed,
589
+ "cache_hit": cache_hit,
590
+ "load_elapsed_s": load_elapsed,
591
  "output_file": out_file.name,
592
  "note": model_def.note,
593
  "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
 
644
 
645
  progress(0.05, desc="Loading autoencoder")
646
  started = time.time()
647
+ model, device, cache_hit, load_elapsed = load_autoencoder(
648
  model_key,
649
  allow_cpu_same_l,
650
  oauth_profile,
 
680
  "input_shape": list(waveform.shape),
681
  "latent_shape": list(latents.shape),
682
  "elapsed_s": round(time.time() - started, 3),
683
+ "cache_hit": cache_hit,
684
+ "load_elapsed_s": load_elapsed,
685
  "output_file": out_file.name,
686
  "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
687
  "username": oauth_username(oauth_profile),
 
737
  "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
738
  "loaded_generation_model": MODEL_CACHE["key"],
739
  "loaded_autoencoder": AE_CACHE["key"],
740
+ "access_cache_entries": len(ACCESS_CACHE),
741
  }
742
 
743