Spaces:
Running on Zero
Running on Zero
Add runtime cache metadata and access check cache
Browse files
README.md
CHANGED
|
@@ -56,3 +56,15 @@ Stability AI's public MIT-licensed repository because its package metadata pins
|
|
| 56 |
Torch 2.7.1. ZeroGPU currently provides Torch 2.8.0, so installing the upstream
|
| 57 |
package through normal dependency resolution would downgrade Torch and break the
|
| 58 |
ZeroGPU runtime.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
Torch 2.7.1. ZeroGPU currently provides Torch 2.8.0, so installing the upstream
|
| 57 |
package through normal dependency resolution would downgrade Torch and break the
|
| 58 |
ZeroGPU runtime.
|
| 59 |
+
|
| 60 |
+
## Optimization notes
|
| 61 |
+
|
| 62 |
+
- Repeated runs with the same selected model reuse the loaded model inside the
|
| 63 |
+
ZeroGPU worker when the worker stays warm. Run metadata includes `cache_hit`
|
| 64 |
+
and `load_elapsed_s` so this is visible.
|
| 65 |
+
- Successful gated-repo access checks are cached briefly per token digest and
|
| 66 |
+
repo ID to avoid a Hugging Face `HEAD` request on every generation.
|
| 67 |
+
- The `stable-audio-3-optimized` repo currently provides MLX, ONNX, and
|
| 68 |
+
TensorRT assets. This Space keeps the portable PyTorch path because the
|
| 69 |
+
TensorRT engines are prebuilt for `sm_90`, while the current ZeroGPU host is
|
| 70 |
+
a Blackwell GPU, and MLX is Apple-only.
|
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import gc
|
|
|
|
| 4 |
import importlib
|
| 5 |
import importlib.util
|
| 6 |
import json
|
|
@@ -15,11 +16,13 @@ from contextlib import contextmanager
|
|
| 15 |
from dataclasses import dataclass
|
| 16 |
from typing import Any
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
import gradio as gr
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
-
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 22 |
-
|
| 23 |
|
| 24 |
def _filter_known_unraisable(unraisable):
|
| 25 |
object_name = getattr(unraisable.object, "__qualname__", "")
|
|
@@ -171,6 +174,8 @@ COLLECTION_ROWS = [
|
|
| 171 |
|
| 172 |
MODEL_CACHE: dict[str, Any] = {"key": None, "model": None}
|
| 173 |
AE_CACHE: dict[str, Any] = {"key": None, "model": None}
|
|
|
|
|
|
|
| 174 |
MODEL_LOAD_LOCK = threading.RLock()
|
| 175 |
|
| 176 |
|
|
@@ -246,7 +251,20 @@ def stable_audio_token_hint(model: GenerationModel) -> str:
|
|
| 246 |
)
|
| 247 |
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
request = urllib.request.Request(
|
| 251 |
f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
|
| 252 |
method="HEAD",
|
|
@@ -254,7 +272,10 @@ def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str |
|
|
| 254 |
)
|
| 255 |
try:
|
| 256 |
with urllib.request.urlopen(request, timeout=20) as response:
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
| 258 |
except urllib.error.HTTPError as exc:
|
| 259 |
if exc.code in {401, 403}:
|
| 260 |
return (
|
|
@@ -389,12 +410,13 @@ def load_generation_model(
|
|
| 389 |
)
|
| 390 |
|
| 391 |
if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
|
| 392 |
-
return MODEL_CACHE["model"], device
|
| 393 |
|
| 394 |
with MODEL_LOAD_LOCK:
|
| 395 |
if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
|
| 396 |
-
return MODEL_CACHE["model"], device
|
| 397 |
|
|
|
|
| 398 |
MODEL_CACHE["model"] = None
|
| 399 |
MODEL_CACHE["key"] = None
|
| 400 |
clear_torch_memory()
|
|
@@ -406,7 +428,7 @@ def load_generation_model(
|
|
| 406 |
model = StableAudioModel.from_pretrained(model_key, model_half=model_half)
|
| 407 |
MODEL_CACHE["key"] = model_key
|
| 408 |
MODEL_CACHE["model"] = model
|
| 409 |
-
return model, device
|
| 410 |
|
| 411 |
|
| 412 |
def load_autoencoder(
|
|
@@ -429,12 +451,13 @@ def load_autoencoder(
|
|
| 429 |
)
|
| 430 |
|
| 431 |
if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
|
| 432 |
-
return AE_CACHE["model"], device
|
| 433 |
|
| 434 |
with MODEL_LOAD_LOCK:
|
| 435 |
if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
|
| 436 |
-
return AE_CACHE["model"], device
|
| 437 |
|
|
|
|
| 438 |
AE_CACHE["model"] = None
|
| 439 |
AE_CACHE["key"] = None
|
| 440 |
clear_torch_memory()
|
|
@@ -445,7 +468,7 @@ def load_autoencoder(
|
|
| 445 |
model = AutoencoderModel.from_pretrained(model_key)
|
| 446 |
AE_CACHE["key"] = model_key
|
| 447 |
AE_CACHE["model"] = model
|
| 448 |
-
return model, device
|
| 449 |
|
| 450 |
|
| 451 |
def model_changed(model_key: str):
|
|
@@ -521,7 +544,7 @@ def generate_audio(
|
|
| 521 |
if seed < 0:
|
| 522 |
seed = int.from_bytes(os.urandom(4), "little") % 100000
|
| 523 |
|
| 524 |
-
model, device = load_generation_model(
|
| 525 |
model_key,
|
| 526 |
allow_cpu_medium,
|
| 527 |
oauth_profile,
|
|
@@ -563,6 +586,8 @@ def generate_audio(
|
|
| 563 |
"seed": seed,
|
| 564 |
"sample_rate": sample_rate,
|
| 565 |
"elapsed_s": elapsed,
|
|
|
|
|
|
|
| 566 |
"output_file": out_file.name,
|
| 567 |
"note": model_def.note,
|
| 568 |
"auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
|
|
@@ -619,7 +644,7 @@ def roundtrip_autoencoder(
|
|
| 619 |
|
| 620 |
progress(0.05, desc="Loading autoencoder")
|
| 621 |
started = time.time()
|
| 622 |
-
model, device = load_autoencoder(
|
| 623 |
model_key,
|
| 624 |
allow_cpu_same_l,
|
| 625 |
oauth_profile,
|
|
@@ -655,6 +680,8 @@ def roundtrip_autoencoder(
|
|
| 655 |
"input_shape": list(waveform.shape),
|
| 656 |
"latent_shape": list(latents.shape),
|
| 657 |
"elapsed_s": round(time.time() - started, 3),
|
|
|
|
|
|
|
| 658 |
"output_file": out_file.name,
|
| 659 |
"auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
|
| 660 |
"username": oauth_username(oauth_profile),
|
|
@@ -710,6 +737,7 @@ def runtime_status(
|
|
| 710 |
"hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
|
| 711 |
"loaded_generation_model": MODEL_CACHE["key"],
|
| 712 |
"loaded_autoencoder": AE_CACHE["key"],
|
|
|
|
| 713 |
}
|
| 714 |
|
| 715 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import gc
|
| 4 |
+
import hashlib
|
| 5 |
import importlib
|
| 6 |
import importlib.util
|
| 7 |
import json
|
|
|
|
| 16 |
from dataclasses import dataclass
|
| 17 |
from typing import Any
|
| 18 |
|
| 19 |
+
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 20 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 21 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 22 |
+
|
| 23 |
import gradio as gr
|
| 24 |
import numpy as np
|
| 25 |
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def _filter_known_unraisable(unraisable):
|
| 28 |
object_name = getattr(unraisable.object, "__qualname__", "")
|
|
|
|
| 174 |
|
| 175 |
MODEL_CACHE: dict[str, Any] = {"key": None, "model": None}
|
| 176 |
AE_CACHE: dict[str, Any] = {"key": None, "model": None}
|
| 177 |
+
ACCESS_CACHE: dict[tuple[str, str], float] = {}
|
| 178 |
+
ACCESS_CACHE_TTL_SECONDS = max(0, int(os.getenv("SA3_ACCESS_CACHE_TTL_SECONDS", "600")))
|
| 179 |
MODEL_LOAD_LOCK = threading.RLock()
|
| 180 |
|
| 181 |
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
|
| 254 |
+
def access_cache_key(repo_id: str, token: str) -> tuple[str, str]:
|
| 255 |
+
token_digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]
|
| 256 |
+
return repo_id, token_digest
|
| 257 |
+
|
| 258 |
+
|
| 259 |
def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
|
| 260 |
+
cache_key = access_cache_key(repo_id, token)
|
| 261 |
+
cached_until = ACCESS_CACHE.get(cache_key)
|
| 262 |
+
now = time.time()
|
| 263 |
+
if cached_until is not None:
|
| 264 |
+
if cached_until > now:
|
| 265 |
+
return True, None
|
| 266 |
+
ACCESS_CACHE.pop(cache_key, None)
|
| 267 |
+
|
| 268 |
request = urllib.request.Request(
|
| 269 |
f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
|
| 270 |
method="HEAD",
|
|
|
|
| 272 |
)
|
| 273 |
try:
|
| 274 |
with urllib.request.urlopen(request, timeout=20) as response:
|
| 275 |
+
has_access = response.status < 400
|
| 276 |
+
if has_access and ACCESS_CACHE_TTL_SECONDS:
|
| 277 |
+
ACCESS_CACHE[cache_key] = time.time() + ACCESS_CACHE_TTL_SECONDS
|
| 278 |
+
return has_access, None
|
| 279 |
except urllib.error.HTTPError as exc:
|
| 280 |
if exc.code in {401, 403}:
|
| 281 |
return (
|
|
|
|
| 410 |
)
|
| 411 |
|
| 412 |
if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
|
| 413 |
+
return MODEL_CACHE["model"], device, True, 0.0
|
| 414 |
|
| 415 |
with MODEL_LOAD_LOCK:
|
| 416 |
if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
|
| 417 |
+
return MODEL_CACHE["model"], device, True, 0.0
|
| 418 |
|
| 419 |
+
load_started = time.time()
|
| 420 |
MODEL_CACHE["model"] = None
|
| 421 |
MODEL_CACHE["key"] = None
|
| 422 |
clear_torch_memory()
|
|
|
|
| 428 |
model = StableAudioModel.from_pretrained(model_key, model_half=model_half)
|
| 429 |
MODEL_CACHE["key"] = model_key
|
| 430 |
MODEL_CACHE["model"] = model
|
| 431 |
+
return model, device, False, round(time.time() - load_started, 3)
|
| 432 |
|
| 433 |
|
| 434 |
def load_autoencoder(
|
|
|
|
| 451 |
)
|
| 452 |
|
| 453 |
if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
|
| 454 |
+
return AE_CACHE["model"], device, True, 0.0
|
| 455 |
|
| 456 |
with MODEL_LOAD_LOCK:
|
| 457 |
if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
|
| 458 |
+
return AE_CACHE["model"], device, True, 0.0
|
| 459 |
|
| 460 |
+
load_started = time.time()
|
| 461 |
AE_CACHE["model"] = None
|
| 462 |
AE_CACHE["key"] = None
|
| 463 |
clear_torch_memory()
|
|
|
|
| 468 |
model = AutoencoderModel.from_pretrained(model_key)
|
| 469 |
AE_CACHE["key"] = model_key
|
| 470 |
AE_CACHE["model"] = model
|
| 471 |
+
return model, device, False, round(time.time() - load_started, 3)
|
| 472 |
|
| 473 |
|
| 474 |
def model_changed(model_key: str):
|
|
|
|
| 544 |
if seed < 0:
|
| 545 |
seed = int.from_bytes(os.urandom(4), "little") % 100000
|
| 546 |
|
| 547 |
+
model, device, cache_hit, load_elapsed = load_generation_model(
|
| 548 |
model_key,
|
| 549 |
allow_cpu_medium,
|
| 550 |
oauth_profile,
|
|
|
|
| 586 |
"seed": seed,
|
| 587 |
"sample_rate": sample_rate,
|
| 588 |
"elapsed_s": elapsed,
|
| 589 |
+
"cache_hit": cache_hit,
|
| 590 |
+
"load_elapsed_s": load_elapsed,
|
| 591 |
"output_file": out_file.name,
|
| 592 |
"note": model_def.note,
|
| 593 |
"auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
|
|
|
|
| 644 |
|
| 645 |
progress(0.05, desc="Loading autoencoder")
|
| 646 |
started = time.time()
|
| 647 |
+
model, device, cache_hit, load_elapsed = load_autoencoder(
|
| 648 |
model_key,
|
| 649 |
allow_cpu_same_l,
|
| 650 |
oauth_profile,
|
|
|
|
| 680 |
"input_shape": list(waveform.shape),
|
| 681 |
"latent_shape": list(latents.shape),
|
| 682 |
"elapsed_s": round(time.time() - started, 3),
|
| 683 |
+
"cache_hit": cache_hit,
|
| 684 |
+
"load_elapsed_s": load_elapsed,
|
| 685 |
"output_file": out_file.name,
|
| 686 |
"auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
|
| 687 |
"username": oauth_username(oauth_profile),
|
|
|
|
| 737 |
"hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
|
| 738 |
"loaded_generation_model": MODEL_CACHE["key"],
|
| 739 |
"loaded_autoencoder": AE_CACHE["key"],
|
| 740 |
+
"access_cache_entries": len(ACCESS_CACHE),
|
| 741 |
}
|
| 742 |
|
| 743 |
|