# Hugging Face Space running on ZeroGPU hardware ("Running on Zero").
| from __future__ import annotations | |
| import gc | |
| import hashlib | |
| import importlib | |
| import importlib.util | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import urllib.error | |
| import urllib.request | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass | |
| from typing import Any | |
# Process-wide defaults, applied before gradio is imported below.
# setdefault keeps any value the Space operator has already exported.
for _env_name, _env_default in (
    ("HF_HUB_ENABLE_HF_TRANSFER", "1"),
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
| import gradio as gr | |
| import numpy as np | |
| def _filter_known_unraisable(unraisable): | |
| object_name = getattr(unraisable.object, "__qualname__", "") | |
| if ( | |
| object_name == "BaseEventLoop.__del__" | |
| and isinstance(unraisable.exc_value, ValueError) | |
| and "Invalid file descriptor" in str(unraisable.exc_value) | |
| ): | |
| return | |
| sys.__unraisablehook__(unraisable) | |
| sys.unraisablehook = _filter_known_unraisable | |
@dataclass(frozen=True)
class GenerationModel:
    """Static description of one Stable Audio 3 text-to-audio checkpoint.

    Entries are constructed with keyword arguments in ``GENERATION_MODELS``.
    ``@dataclass`` provides that keyword ``__init__``; without it the
    construction raises ``TypeError``. ``frozen=True`` keeps the shared
    registry entries read-only.
    """

    label: str  # human-readable name shown in the Model dropdown
    key: str  # registry key; load_generation_model passes it to from_pretrained
    repo_id: str  # Hugging Face Hub repository id (checked for gated access)
    family: str  # "post-trained" or "base" in this file's registry
    default_prompt: str  # prompt pre-filled when the model is selected
    default_duration: int  # seconds; initial Duration slider value
    max_duration: int  # seconds; Duration slider maximum
    default_steps: int  # initial Steps slider value
    default_cfg: float  # initial CFG slider value
    default_sampler: str  # expected to be one of SAMPLER_CHOICES
    requires_cuda: bool = False  # blocked off-CUDA unless the CPU override is set
    gated: bool = False  # requires a Hub access check before loading
    note: str = ""  # short description surfaced in the UI / run metadata
# Text-to-audio checkpoints offered in the Generate tab. The registry is keyed
# by each entry's own ``key`` field, so the two can never drift apart; the
# tuple order below is the dropdown order.
GENERATION_MODELS: dict[str, GenerationModel] = {
    entry.key: entry
    for entry in (
        GenerationModel(
            label="Stable Audio 3 Small Music",
            key="small-music",
            repo_id="stabilityai/stable-audio-3-small-music",
            family="post-trained",
            default_prompt=(
                "Warm lo-fi house groove, soft sidechained pads, clean drums, "
                "late-night atmosphere, 118 BPM"
            ),
            default_duration=20,
            max_duration=120,
            default_steps=8,
            default_cfg=1.0,
            default_sampler="pingpong",
            gated=True,
            note="Lightweight music checkpoint.",
        ),
        GenerationModel(
            label="Stable Audio 3 Small SFX",
            key="small-sfx",
            repo_id="stabilityai/stable-audio-3-small-sfx",
            family="post-trained",
            default_prompt="Close binaural rain on a window, soft cloth movement, detailed texture",
            default_duration=8,
            max_duration=120,
            default_steps=8,
            default_cfg=1.0,
            default_sampler="pingpong",
            gated=True,
            note="Lightweight sound-effects checkpoint.",
        ),
        GenerationModel(
            label="Stable Audio 3 Medium",
            key="medium",
            repo_id="stabilityai/stable-audio-3-medium",
            family="post-trained",
            default_prompt=(
                "Cinematic ambient electronic cue, deep sub pulse, shimmering stereo texture, "
                "slow evolving melody"
            ),
            default_duration=20,
            max_duration=380,
            default_steps=8,
            default_cfg=1.0,
            default_sampler="pingpong",
            requires_cuda=True,
            gated=True,
            note="High-quality checkpoint; GPU-first.",
        ),
        GenerationModel(
            label="Stable Audio 3 Small Music Base",
            key="small-music-base",
            repo_id="stabilityai/stable-audio-3-small-music-base",
            family="base",
            default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM",
            default_duration=20,
            max_duration=120,
            default_steps=50,
            default_cfg=7.0,
            default_sampler="euler",
            note="Base checkpoint intended mainly for fine-tuning.",
        ),
        GenerationModel(
            label="Stable Audio 3 Small SFX Base",
            key="small-sfx-base",
            repo_id="stabilityai/stable-audio-3-small-sfx-base",
            family="base",
            default_prompt="Chugging train coming into station with horn",
            default_duration=7,
            max_duration=120,
            default_steps=50,
            default_cfg=7.0,
            default_sampler="euler",
            note="Base checkpoint intended mainly for fine-tuning.",
        ),
        GenerationModel(
            label="Stable Audio 3 Medium Base",
            key="medium-base",
            repo_id="stabilityai/stable-audio-3-medium-base",
            family="base",
            default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM",
            default_duration=20,
            max_duration=380,
            default_steps=50,
            default_cfg=7.0,
            default_sampler="euler",
            requires_cuda=True,
            note="Base checkpoint intended mainly for fine-tuning; GPU-first.",
        ),
    )
}
# Autoencoders offered in the Autoencoder tab, keyed by the id passed to
# AutoencoderModel.from_pretrained.
AUTOENCODER_MODELS = {
    "same-s": dict(label="SAME-S", repo_id="stabilityai/SAME-S", requires_cuda=False),
    "same-l": dict(label="SAME-L", repo_id="stabilityai/SAME-L", requires_cuda=True),
}

# Rows of the Coverage tab table: (collection entry, type, Space path, status).
# Stored as lists because that is what the Dataframe receives.
COLLECTION_ROWS = [
    list(row)
    for row in (
        ("stable-audio-3-small-music", "Text-to-audio", "Generate tab", "Gated post-trained small music"),
        ("stable-audio-3-small-sfx", "Text-to-audio", "Generate tab", "Gated post-trained small SFX"),
        ("stable-audio-3-medium", "Text-to-audio", "Generate tab", "Gated medium; GPU-first"),
        ("stable-audio-3-small-music-base", "Text-to-audio", "Generate tab", "Base checkpoint"),
        ("stable-audio-3-small-sfx-base", "Text-to-audio", "Generate tab", "Base checkpoint"),
        ("stable-audio-3-medium-base", "Text-to-audio", "Generate tab", "Base checkpoint; GPU-first"),
        ("stable-audio-3-optimized", "Optimized assets", "Listed only", "MLX/TensorRT artifacts, not generic PyTorch generation"),
        ("SAME-S", "Autoencoder", "Autoencoder tab", "CPU-capable round trip"),
        ("SAME-L", "Autoencoder", "Autoencoder tab", "Large autoencoder; CUDA recommended"),
    )
]
| MODEL_CACHE: dict[str, Any] = {"key": None, "model": None} | |
| AE_CACHE: dict[str, Any] = {"key": None, "model": None} | |
| ACCESS_CACHE: dict[tuple[str, str], float] = {} | |
| ACCESS_CACHE_TTL_SECONDS = max(0, int(os.getenv("SA3_ACCESS_CACHE_TTL_SECONDS", "600"))) | |
| MODEL_LOAD_LOCK = threading.RLock() | |
def gpu_task(duration: int):
    """Return a decorator that wraps a function with ``spaces.GPU(duration=...)``.

    An identity decorator is returned instead when ``SA3_USE_SPACES_GPU`` is
    "0", "false" or "no" (case-insensitive, whitespace-trimmed), or when
    importing/creating ``spaces.GPU`` fails for any reason.
    """
    def passthrough(fn):
        return fn

    flag = os.getenv("SA3_USE_SPACES_GPU", "1").strip().lower()
    if flag in {"0", "false", "no"}:
        return passthrough
    try:
        import spaces

        return spaces.GPU(duration=duration)
    except Exception:
        # Best effort: running outside a Space (or without the package) is fine.
        return passthrough
def import_torch():
    """Import and return the ``torch`` module on demand.

    torch is never imported at module top in this file; presumably this keeps
    startup light. ``importlib`` returns the already-loaded module on later calls.
    """
    torch_module = importlib.import_module("torch")
    return torch_module
def current_device(torch_module: Any) -> str:
    """Name the preferred device for ``torch_module``: "cuda", else "mps", else "cpu".

    ``backends`` is only inspected when CUDA is unavailable, and MPS is
    consulted only if the build exposes ``torch.backends.mps``.
    """
    cuda_ready = torch_module.cuda.is_available()
    if cuda_ready:
        return "cuda"
    mps_ready = hasattr(torch_module.backends, "mps") and torch_module.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
def flash_attn_available() -> bool:
    """Report whether the optional ``flash_attn`` package is findable (without importing it)."""
    spec = importlib.util.find_spec("flash_attn")
    return spec is not None
def oauth_token_value(oauth_token: gr.OAuthToken | None) -> str | None:
    """Return the raw token string carried by a Gradio OAuth token, or ``None``.

    Anything lacking a non-empty string ``token`` attribute (including
    ``None`` itself) yields ``None``.
    """
    raw = getattr(oauth_token, "token", None)
    if isinstance(raw, str) and raw:
        return raw
    return None
def hf_api_token_value(hf_api_token: str | None) -> str | None:
    """Return the pasted access token with surrounding whitespace removed.

    Non-strings and blank strings yield ``None``.
    """
    if isinstance(hf_api_token, str):
        cleaned = hf_api_token.strip()
        if cleaned:
            return cleaned
    return None
def request_token_value(
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str | None:
    """Pick the Hub token for this request: a pasted token wins over OAuth.

    Returns ``None`` when neither source supplies a usable token.
    """
    pasted = hf_api_token_value(hf_api_token)
    if pasted:
        return pasted
    return oauth_token_value(oauth_token)
def oauth_username(oauth_profile: gr.OAuthProfile | None) -> str | None:
    """Return the signed-in user's ``username``, or ``None`` if absent or blank."""
    name = getattr(oauth_profile, "username", None)
    if not isinstance(name, str) or not name:
        return None
    return name
def auth_source(
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str | None:
    """Label where this request's credentials come from.

    Returns "hf_token" for a pasted token (it takes precedence), "oauth" when
    a profile and a non-empty OAuth token are both present, else ``None``.
    """
    if hf_api_token_value(hf_api_token):
        return "hf_token"
    signed_in = oauth_profile is not None and bool(oauth_token_value(oauth_token))
    return "oauth" if signed_in else None
def stable_audio_token_hint(model: GenerationModel) -> str:
    """Return the sign-in hint shown for ``model``; gated models mention the terms."""
    if model.gated:
        return (
            "Sign in with Hugging Face or paste a Hugging Face access token from an "
            "account that has accepted this gated model's terms."
        )
    return "Sign in with Hugging Face or paste a Hugging Face access token before running this Space."
def access_cache_key(repo_id: str, token: str) -> tuple[str, str]:
    """Build the ``ACCESS_CACHE`` key for ``(repo_id, token)``.

    Only the first 16 hex characters of the token's SHA-256 digest are kept,
    so the raw secret is never stored in the cache.
    """
    digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
    return (repo_id, digest[:16])
def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
    """Probe whether ``token`` may fetch files from the gated ``repo_id``.

    Sends a HEAD request for ``model_config.json`` on the repo's ``main``
    branch. Returns ``(True, None)`` on success, ``(False, message)`` on
    HTTP or network errors, and ``(False, None)`` if a response arrives with
    status >= 400 without raising. Successful results are cached in
    ``ACCESS_CACHE`` for ``ACCESS_CACHE_TTL_SECONDS``; failures are not cached.
    """
    key = access_cache_key(repo_id, token)
    expires_at = ACCESS_CACHE.get(key)
    if expires_at is not None:
        if expires_at > time.time():
            return True, None
        # Stale entry: forget it and re-check against the Hub.
        ACCESS_CACHE.pop(key, None)

    probe = urllib.request.Request(
        f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
        method="HEAD",
        headers={"Authorization": f"Bearer {token}"},
    )
    try:
        with urllib.request.urlopen(probe, timeout=20) as response:
            allowed = response.status < 400
    except urllib.error.HTTPError as exc:
        if exc.code not in {401, 403}:
            return False, f"Hugging Face access check failed with HTTP {exc.code}."
        return (
            False,
            "Your Hugging Face account does not have access to this gated model yet. "
            "Open the model page while logged in, accept Stability's terms, then retry.",
        )
    except Exception as exc:
        return False, f"Hugging Face access check failed: {exc!r}"

    if allowed and ACCESS_CACHE_TTL_SECONDS:
        ACCESS_CACHE[key] = time.time() + ACCESS_CACHE_TTL_SECONDS
    return allowed, None
@contextmanager
def hub_download_token(token: str | None):
    """Temporarily route Hub downloads through the request's access token.

    Callers use this in a ``with`` statement, so it must be a context manager.
    Without ``@contextmanager`` the bare generator raises ``TypeError`` on
    entry. With a falsy ``token`` it is a no-op. Otherwise, for the duration
    of the block:

    * ``stable_audio_3.model_configs.hf_hub_download`` is wrapped so that
      calls without a token, or with ``token=None``, use ``token``. An
      explicit token, including ``False``, is left untouched.
    * ``HF_TOKEN`` and ``HUGGING_FACE_HUB_TOKEN`` are set to ``token``.

    Both changes are undone on exit, even if the block raises. The env
    variables are process-global; callers in this file hold
    ``MODEL_LOAD_LOCK`` while using this.
    """
    if not token:
        yield
        return
    import stable_audio_3.model_configs as model_configs

    original_download = model_configs.hf_hub_download
    token_env_keys = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    previous_token_env = {key: os.environ.get(key) for key in token_env_keys}

    def download_with_user_token(*args, **kwargs):
        # An explicit ``token=None`` means "use the default", so fill it too.
        if kwargs.get("token") is None:
            kwargs["token"] = token
        return original_download(*args, **kwargs)

    try:
        # Patch inside ``try`` so a failure mid-setup is still rolled back.
        model_configs.hf_hub_download = download_with_user_token
        for key in token_env_keys:
            os.environ[key] = token
        yield
    finally:
        model_configs.hf_hub_download = original_download
        for key, previous in previous_token_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
def generation_preflight_error(
    model: GenerationModel,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> tuple[str | None, str]:
    """Check whether ``model`` may run for this request, without raising.

    Returns ``(error_message, device)``; ``error_message`` is ``None`` when
    every check passes. The checks run in order:

    1. a token is present (``device`` is still "unknown" if this fails);
    2. a CUDA-only model is on CUDA, unless ``allow_cpu_medium``;
    3. for gated models, the token can access ``model.repo_id``.

    ``oauth_profile`` is accepted for signature parity but not consulted.
    """
    token = request_token_value(oauth_token, hf_api_token)
    if not token:
        return (
            "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
            "unknown",
        )

    device = current_device(import_torch())
    if model.requires_cuda and device != "cuda" and not allow_cpu_medium:
        return (
            f"{model.label} is blocked on this runtime because CUDA is not available. "
            "Use a GPU Space or enable the CPU override for a slow/debug-only attempt.",
            device,
        )

    if not model.gated:
        return None, device
    has_access, error = user_can_download_gated_model(model.repo_id, token)
    if has_access:
        return None, device
    return error or "Your Hugging Face account cannot access this gated model.", device
def assert_generation_runtime(
    model: GenerationModel,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str:
    """Raising counterpart of ``generation_preflight_error``.

    Returns the detected device name when all checks pass.

    Raises:
        gr.Error: with the preflight message when any check fails.
    """
    error, device = generation_preflight_error(
        model, allow_cpu_medium, oauth_profile, oauth_token, hf_api_token
    )
    if not error:
        return device
    raise gr.Error(error)
def normalize_audio_array(data: np.ndarray) -> np.ndarray:
    """Convert numpy audio into a finite float32 ``(channels, samples)`` array.

    Scaling by dtype:

    * Signed integer PCM is divided by its magnitude limit, e.g. 32768 for
      int16, mapping into [-1, 1].
    * Unsigned integer PCM (e.g. 8-bit WAV as uint8) is offset-binary around
      its midpoint (128 for uint8). It is re-centred before scaling, so
      silence maps to 0.0 rather than to about +0.5.
    * Float input is cast to float32 without rescaling.

    Layout: a 1-D array is treated as mono. A 2-D array is assumed to be
    ``(samples, channels)`` and is transposed. NaN and +/-inf become 0.0.

    Raises:
        gr.Error: if ``data`` is not 1-D or 2-D.
    """
    array = np.asarray(data)
    if np.issubdtype(array.dtype, np.unsignedinteger):
        midpoint = float((int(np.iinfo(array.dtype).max) + 1) // 2)
        # float64 intermediate keeps uint32/uint64 offsets exact before the cast.
        array = ((array.astype(np.float64) - midpoint) / midpoint).astype(np.float32)
    elif np.issubdtype(array.dtype, np.integer):
        info = np.iinfo(array.dtype)
        limit = max(abs(info.min), info.max)
        array = array.astype(np.float32) / float(limit)
    else:
        array = array.astype(np.float32)
    if array.ndim == 1:
        array = array[None, :]
    elif array.ndim == 2:
        array = array.T
    else:
        raise gr.Error("Audio must be mono or stereo.")
    return np.nan_to_num(array, nan=0.0, posinf=0.0, neginf=0.0)
def clear_torch_memory() -> None:
    """Best-effort release of memory held by dropped models.

    Python's cycle collector runs first, so tensors kept alive only by
    reference cycles are actually freed. Only then is PyTorch's CUDA caching
    allocator asked to return its now-unused blocks (``empty_cache`` cannot
    release memory still owned by live tensors). Any failure while probing or
    clearing torch is ignored on purpose: this is cleanup, not a requirement.
    """
    gc.collect()
    try:
        torch = import_torch()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        # Deliberately best-effort (torch missing, CUDA not initialisable, ...).
        pass
def load_generation_model(
    model_key: str,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
):
    """Return ``(model, device, cache_hit, load_seconds)`` for ``model_key``.

    The generation preflight runs first and raises ``gr.Error`` on failure.
    If the single-slot ``MODEL_CACHE`` already holds ``model_key``, that
    model is returned with ``cache_hit=True`` and ``load_seconds=0.0``.
    Otherwise, under ``MODEL_LOAD_LOCK``, the previous model is evicted and
    memory cleared. The checkpoint is then loaded with the request's token,
    in half precision when the device is CUDA.
    """
    device = assert_generation_runtime(
        GENERATION_MODELS[model_key],
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )

    def cached_model():
        if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None:
            return MODEL_CACHE["model"]
        return None

    hit = cached_model()
    if hit is not None:
        return hit, device, True, 0.0
    with MODEL_LOAD_LOCK:
        # Re-check: another request may have loaded it while we waited.
        hit = cached_model()
        if hit is not None:
            return hit, device, True, 0.0
        load_started = time.time()
        # Drop the previous checkpoint before allocating the next one.
        MODEL_CACHE["model"] = None
        MODEL_CACHE["key"] = None
        clear_torch_memory()
        from stable_audio_3 import StableAudioModel

        use_half = device == "cuda"
        with hub_download_token(request_token_value(oauth_token, hf_api_token)):
            loaded = StableAudioModel.from_pretrained(model_key, model_half=use_half)
        MODEL_CACHE["key"] = model_key
        MODEL_CACHE["model"] = loaded
        return loaded, device, False, round(time.time() - load_started, 3)
def load_autoencoder(
    model_key: str,
    allow_cpu_same_l: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
):
    """Return ``(autoencoder, device, cache_hit, load_seconds)`` for ``model_key``.

    Raises ``gr.Error`` when no token is available, or when a CUDA-only
    autoencoder would run off-CUDA without ``allow_cpu_same_l``. Uses the
    single-slot ``AE_CACHE``; on a miss, the previous entry is evicted under
    ``MODEL_LOAD_LOCK`` and the model is loaded with the request's token.
    ``oauth_profile`` is accepted for signature parity but not consulted.
    """
    if not request_token_value(oauth_token, hf_api_token):
        raise gr.Error("Sign in with Hugging Face or paste a Hugging Face access token before running this Space.")
    spec = AUTOENCODER_MODELS[model_key]
    device = current_device(import_torch())
    if spec["requires_cuda"] and device != "cuda" and not allow_cpu_same_l:
        raise gr.Error(
            f"{spec['label']} is blocked on this runtime because CUDA is not available. "
            "Use SAME-S or enable the CPU override for a slow/debug-only attempt."
        )

    def cached_autoencoder():
        if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None:
            return AE_CACHE["model"]
        return None

    hit = cached_autoencoder()
    if hit is not None:
        return hit, device, True, 0.0
    with MODEL_LOAD_LOCK:
        # Re-check after acquiring the lock.
        hit = cached_autoencoder()
        if hit is not None:
            return hit, device, True, 0.0
        load_started = time.time()
        AE_CACHE["model"] = None
        AE_CACHE["key"] = None
        clear_torch_memory()
        from stable_audio_3 import AutoencoderModel

        with hub_download_token(request_token_value(oauth_token, hf_api_token)):
            loaded = AutoencoderModel.from_pretrained(model_key)
        AE_CACHE["key"] = model_key
        AE_CACHE["model"] = loaded
        return loaded, device, False, round(time.time() - load_started, 3)
def model_changed(model_key: str):
    """Reset the Generate-tab controls to ``model_key``'s defaults.

    Returns updates, in order, for: prompt, duration (value and maximum),
    steps, CFG and sampler, followed by the info dict for the Model-info JSON.
    """
    spec = GENERATION_MODELS[model_key]
    info = {
        "repo_id": spec.repo_id,
        "family": spec.family,
        "max_duration_s": spec.max_duration,
        "default_sampler": spec.default_sampler,
        "note": spec.note,
        "token_hint": stable_audio_token_hint(spec),
    }
    prompt_update = gr.update(value=spec.default_prompt)
    duration_update = gr.update(value=spec.default_duration, maximum=spec.max_duration)
    steps_update = gr.update(value=spec.default_steps)
    cfg_update = gr.update(value=spec.default_cfg)
    sampler_update = gr.update(value=spec.default_sampler)
    return prompt_update, duration_update, steps_update, cfg_update, sampler_update, info
# NOTE(review): ``gpu_task`` (a ``spaces.GPU`` wrapper) is defined in this file
# but not applied to this handler. If the Space depends on ZeroGPU for CUDA,
# confirm whether this function should be decorated with it.
def generate_audio(
    model_key: str,
    prompt: str,
    negative_prompt: str,
    duration: float,
    steps: int,
    cfg_scale: float,
    sampler_type: str,
    seed: int,
    chunked_decode: bool,
    allow_cpu_medium: bool,
    hf_api_token: str | None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate audio from a text prompt and save it as a temporary WAV file.

    Args:
        model_key: key into ``GENERATION_MODELS``.
        prompt: text prompt; blank prompts are rejected.
        negative_prompt: optional negative prompt; blank or ``None`` means none.
        duration: requested length in seconds. Clamped to
            ``[1, model.max_duration]``, the Duration slider's range, because
            API callers are not bound by the slider.
        steps, cfg_scale, sampler_type: sampler settings for ``model.generate``.
        seed: RNG seed. A negative value, or ``None`` (a cleared Number
            field), picks a random seed in ``[0, 100000)``.
        chunked_decode: forwarded to ``model.generate``.
        allow_cpu_medium: permit CUDA-only checkpoints on a non-CUDA device.
        hf_api_token: pasted token; takes precedence over OAuth.
        oauth_profile, oauth_token: OAuth session values, not wired as inputs.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, metadata)`` on success. Returns ``(None, metadata)`` with
        ``status == "blocked"`` when the prompt is blank or a preflight check
        fails.

    Raises:
        gr.Error: if a check inside ``load_generation_model`` fails.
    """
    model_def = GENERATION_MODELS[model_key]
    if not prompt or not prompt.strip():
        return None, {
            "status": "blocked",
            "error": "Prompt is required.",
            "model": model_def.key,
            "repo_id": model_def.repo_id,
        }
    preflight_error, preflight_device = generation_preflight_error(
        model_def,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )
    if preflight_error:
        return None, {
            "status": "blocked",
            "error": preflight_error,
            "model": model_def.key,
            "repo_id": model_def.repo_id,
            "device": preflight_device,
            "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
            "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
            "oauth_signed_in": oauth_profile is not None,
            "username": oauth_username(oauth_profile),
            "oauth_token_present": bool(oauth_token_value(oauth_token)),
            "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
        }
    progress(0.05, desc="Loading model")
    started = time.time()
    # gr.Number yields None when the field is cleared; treat it like -1 (random).
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = int.from_bytes(os.urandom(4), "little") % 100000
    duration = min(max(float(duration), 1.0), float(model_def.max_duration))
    # API callers may send null for the optional negative prompt.
    negative = (negative_prompt or "").strip() or None
    model, device, cache_hit, load_elapsed = load_generation_model(
        model_key,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )
    progress(0.25, desc="Generating")
    audio = model.generate(
        prompt=prompt.strip(),
        negative_prompt=negative,
        duration=duration,
        steps=int(steps),
        cfg_scale=float(cfg_scale),
        seed=seed,
        sampler_type=sampler_type,
        chunked_decode=bool(chunked_decode),
    )
    progress(0.9, desc="Writing WAV")
    import torchaudio

    sample_rate = int(model.model_config["sample_rate"])
    # First item of the batch, moved to CPU and clipped to the valid WAV range.
    waveform = audio[0].detach().to("cpu").float().clamp(-1, 1)
    out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-", suffix=".wav", delete=False)
    out_file.close()
    torchaudio.save(out_file.name, waveform, sample_rate)
    elapsed = round(time.time() - started, 3)
    metadata = {
        "status": "ok",
        "model": model_def.key,
        "repo_id": model_def.repo_id,
        "family": model_def.family,
        "device": device,
        "duration_s": duration,
        "steps": int(steps),
        "cfg_scale": float(cfg_scale),
        "sampler_type": sampler_type,
        "seed": seed,
        "sample_rate": sample_rate,
        "elapsed_s": elapsed,
        "cache_hit": cache_hit,
        "load_elapsed_s": load_elapsed,
        "output_file": out_file.name,
        "note": model_def.note,
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
    return out_file.name, metadata
def roundtrip_autoencoder(
    model_key: str,
    audio_input: tuple[int, np.ndarray] | None,
    chunked: bool,
    allow_cpu_same_l: bool,
    hf_api_token: str | None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Encode and decode ``audio_input`` with an autoencoder, saving the result as WAV.

    Args:
        model_key: key into ``AUTOENCODER_MODELS``.
        audio_input: ``(sample_rate, samples)`` from a numpy-typed gr.Audio.
        chunked: forwarded to both ``encode`` and ``decode``.
        allow_cpu_same_l: permit CUDA-only autoencoders off-CUDA.
        hf_api_token, oauth_profile, oauth_token: request credentials.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, metadata)`` on success. Returns ``(None, metadata)`` with
        ``status == "blocked"`` when no token is available, no audio was
        given, or the CUDA requirement is not met.
    """
    model_def = AUTOENCODER_MODELS[model_key]
    if not request_token_value(oauth_token, hf_api_token):
        return None, {
            "status": "blocked",
            "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
            "autoencoder": model_key,
            "repo_id": model_def["repo_id"],
            "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
            "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
            "oauth_signed_in": oauth_profile is not None,
            "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
        }
    if audio_input is None:
        return None, {
            "status": "blocked",
            "error": "Upload or record audio first.",
            "autoencoder": model_key,
            "repo_id": model_def["repo_id"],
        }
    torch = import_torch()
    device = current_device(torch)
    if model_def["requires_cuda"] and device != "cuda" and not allow_cpu_same_l:
        return None, {
            "status": "blocked",
            "error": (
                f"{model_def['label']} is blocked on this runtime because CUDA is not available. "
                "Use SAME-S or enable the CPU override for a slow/debug-only attempt."
            ),
            "autoencoder": model_key,
            "repo_id": model_def["repo_id"],
            "device": device,
        }

    progress(0.05, desc="Loading autoencoder")
    started = time.time()
    model, device, cache_hit, load_elapsed = load_autoencoder(
        model_key,
        allow_cpu_same_l,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )

    progress(0.25, desc="Encoding")
    input_rate, raw_samples = audio_input
    waveform = torch.from_numpy(normalize_audio_array(raw_samples))
    latents = model.encode(waveform, int(input_rate), chunked=bool(chunked))

    progress(0.65, desc="Decoding")
    reconstructed = model.decode(latents, chunked=bool(chunked))
    # First item of the batch, moved to CPU and clipped to the valid WAV range.
    reconstructed = reconstructed[0].detach().to("cpu").float().clamp(-1, 1)

    import torchaudio

    output_rate = int(model.sample_rate)
    out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-roundtrip-", suffix=".wav", delete=False)
    out_file.close()
    torchaudio.save(out_file.name, reconstructed, output_rate)
    return out_file.name, {
        "status": "ok",
        "autoencoder": model_key,
        "repo_id": model_def["repo_id"],
        "device": device,
        "input_sample_rate": int(input_rate),
        "output_sample_rate": output_rate,
        "input_shape": list(waveform.shape),
        "latent_shape": list(latents.shape),
        "elapsed_s": round(time.time() - started, 3),
        "cache_hit": cache_hit,
        "load_elapsed_s": load_elapsed,
        "output_file": out_file.name,
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
def unload_models(
    hf_api_token: str | None = None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Evict the cached generation model and autoencoder, then free memory.

    Requires a token like the other actions; otherwise it returns a
    "blocked" status without touching the caches.
    """
    if not request_token_value(oauth_token, hf_api_token):
        return {
            "status": "blocked",
            "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
        }
    for cache in (MODEL_CACHE, AE_CACHE):
        cache["key"] = None
        cache["model"] = None
    clear_torch_memory()
    return {
        "status": "unloaded",
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
def runtime_status(
    hf_api_token: str | None = None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Summarise device, CUDA name, flash-attn, auth state and loaded models.

    If probing torch raises, only
    ``{"torch": repr(exc), "device": "unavailable"}`` is returned.
    """
    try:
        torch = import_torch()
        device = current_device(torch)
        cuda_name = None
        if torch.cuda.is_available():
            cuda_name = torch.cuda.get_device_name(0)
    except Exception as exc:
        return {"torch": repr(exc), "device": "unavailable"}

    flash_attn = flash_attn_available()
    return {
        "device": device,
        "cuda_name": cuda_name,
        "flash_attn": flash_attn,
        "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "oauth_signed_in": oauth_profile is not None,
        "username": oauth_username(oauth_profile),
        "oauth_token_present": bool(oauth_token_value(oauth_token)),
        "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
        "loaded_generation_model": MODEL_CACHE["key"],
        "loaded_autoencoder": AE_CACHE["key"],
    }
# (label, value) pairs for the dropdowns, in registry order.
MODEL_CHOICES = [(spec.label, name) for name, spec in GENERATION_MODELS.items()]
AE_CHOICES = [(spec["label"], name) for name, spec in AUTOENCODER_MODELS.items()]
# Sampler names offered in the Generate tab, forwarded as ``sampler_type``.
SAMPLER_CHOICES = ["pingpong", "euler", "rk4", "dpmpp", "dpmpp-3m-sde"]
# Page styling, passed to ``launch(css=...)`` in the main guard.
css = """
.gradio-container { max-width: 1160px !important; }
#run-buttons button { min-height: 42px; }
"""
# Model pre-selected when the page opens; every Generate-tab default derives from it.
_DEFAULT_MODEL_KEY = "small-sfx"
_default_model = GENERATION_MODELS[_DEFAULT_MODEL_KEY]

# Gradio UI. Handlers receive the shared token textbox as an explicit input;
# their gr.OAuthProfile / gr.OAuthToken parameters are not listed in
# ``inputs`` and are filled by Gradio from the login session.
with gr.Blocks(title="Stable Audio 3 Lab") as demo:
    gr.Markdown("# Stable Audio 3 Lab")
    gr.LoginButton(value="Sign in with Hugging Face", logout_value="Logout ({})")
    hf_api_token_box = gr.Textbox(
        label="Hugging Face access token",
        type="password",
        placeholder="hf_...",
        lines=1,
        value="",
        info="Optional fallback for API use or browsers where OAuth is unavailable. Use a read token from an account with access to the selected Stability AI model.",
    )

    with gr.Tab("Generate"):
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=MODEL_CHOICES,
                    value=_DEFAULT_MODEL_KEY,
                    interactive=True,
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value=_default_model.default_prompt,
                    lines=4,
                )
                negative_prompt_box = gr.Textbox(label="Negative prompt", lines=2)
                with gr.Row():
                    duration_slider = gr.Slider(
                        label="Duration",
                        minimum=1,
                        maximum=_default_model.max_duration,
                        value=_default_model.default_duration,
                        step=1,
                    )
                    steps_slider = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=100,
                        value=_default_model.default_steps,
                        step=1,
                    )
                    cfg_slider = gr.Slider(
                        label="CFG",
                        minimum=0,
                        maximum=12,
                        value=_default_model.default_cfg,
                        step=0.1,
                    )
                with gr.Row():
                    sampler_dropdown = gr.Dropdown(
                        label="Sampler",
                        choices=SAMPLER_CHOICES,
                        value=_default_model.default_sampler,
                    )
                    seed_number = gr.Number(label="Seed", value=-1, precision=0)
                with gr.Row():
                    chunked_decode_box = gr.Checkbox(label="Chunked decode", value=True)
                    allow_cpu_medium_box = gr.Checkbox(label="CPU override", value=False)
                with gr.Row(elem_id="run-buttons"):
                    generate_button = gr.Button("Generate", variant="primary")
                    unload_button = gr.Button("Unload")
                    status_button = gr.Button("Runtime")
            with gr.Column(scale=1):
                # Same dict model_changed emits, so the initial panel shows the
                # same keys (incl. max_duration_s / default_sampler) as after a switch.
                model_info = gr.JSON(
                    label="Model info",
                    value=model_changed(_DEFAULT_MODEL_KEY)[-1],
                )
                audio_output = gr.Audio(label="Output", type="filepath")
                metadata_output = gr.JSON(label="Run metadata")

        model_dropdown.change(
            model_changed,
            inputs=model_dropdown,
            outputs=[
                prompt_box,
                duration_slider,
                steps_slider,
                cfg_slider,
                sampler_dropdown,
                model_info,
            ],
        )
        generate_button.click(
            generate_audio,
            inputs=[
                model_dropdown,
                prompt_box,
                negative_prompt_box,
                duration_slider,
                steps_slider,
                cfg_slider,
                sampler_dropdown,
                seed_number,
                chunked_decode_box,
                allow_cpu_medium_box,
                hf_api_token_box,
            ],
            outputs=[audio_output, metadata_output],
            concurrency_limit=1,
        )
        unload_button.click(unload_models, inputs=hf_api_token_box, outputs=metadata_output)
        status_button.click(runtime_status, inputs=hf_api_token_box, outputs=metadata_output)

    with gr.Tab("Autoencoder"):
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                ae_dropdown = gr.Dropdown(label="Autoencoder", choices=AE_CHOICES, value="same-s")
                ae_audio_input = gr.Audio(label="Input", sources=["upload", "microphone"], type="numpy")
                with gr.Row():
                    ae_chunked_box = gr.Checkbox(label="Chunked", value=True)
                    ae_allow_cpu_box = gr.Checkbox(label="CPU override", value=False)
                ae_button = gr.Button("Round Trip", variant="primary")
            with gr.Column(scale=1):
                ae_output = gr.Audio(label="Decoded", type="filepath")
                ae_metadata = gr.JSON(label="Round-trip metadata")
        ae_button.click(
            roundtrip_autoencoder,
            inputs=[ae_dropdown, ae_audio_input, ae_chunked_box, ae_allow_cpu_box, hf_api_token_box],
            outputs=[ae_output, ae_metadata],
            concurrency_limit=1,
        )

    with gr.Tab("Coverage"):
        gr.Dataframe(
            value=COLLECTION_ROWS,
            headers=["Collection entry", "Type", "Space path", "Status"],
            datatype=["str", "str", "str", "str"],
            interactive=False,
            wrap=True,
        )
        # Pass the function, not its result. A callable value is re-evaluated
        # on each page load, so loaded models and device reflect the current
        # state instead of a snapshot taken at import time.
        gr.JSON(label="Runtime", value=runtime_status)
if __name__ == "__main__":
    # Queue events so that, by default, each handler runs one job at a time.
    demo.queue(default_concurrency_limit=1)
    demo.launch(css=css, ssr_mode=False)