from __future__ import annotations import gc import hashlib import importlib import importlib.util import json import os import sys import tempfile import threading import time import urllib.error import urllib.request from contextlib import contextmanager from dataclasses import dataclass from typing import Any os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import gradio as gr import numpy as np def _filter_known_unraisable(unraisable): object_name = getattr(unraisable.object, "__qualname__", "") if ( object_name == "BaseEventLoop.__del__" and isinstance(unraisable.exc_value, ValueError) and "Invalid file descriptor" in str(unraisable.exc_value) ): return sys.__unraisablehook__(unraisable) sys.unraisablehook = _filter_known_unraisable @dataclass(frozen=True) class GenerationModel: label: str key: str repo_id: str family: str default_prompt: str default_duration: int max_duration: int default_steps: int default_cfg: float default_sampler: str requires_cuda: bool = False gated: bool = False note: str = "" GENERATION_MODELS: dict[str, GenerationModel] = { "small-music": GenerationModel( label="Stable Audio 3 Small Music", key="small-music", repo_id="stabilityai/stable-audio-3-small-music", family="post-trained", default_prompt=( "Warm lo-fi house groove, soft sidechained pads, clean drums, " "late-night atmosphere, 118 BPM" ), default_duration=20, max_duration=120, default_steps=8, default_cfg=1.0, default_sampler="pingpong", gated=True, note="Lightweight music checkpoint.", ), "small-sfx": GenerationModel( label="Stable Audio 3 Small SFX", key="small-sfx", repo_id="stabilityai/stable-audio-3-small-sfx", family="post-trained", default_prompt="Close binaural rain on a window, soft cloth movement, detailed texture", default_duration=8, max_duration=120, default_steps=8, default_cfg=1.0, default_sampler="pingpong", gated=True, note="Lightweight sound-effects checkpoint.", ), "medium": GenerationModel( label="Stable Audio 3 Medium", key="medium", repo_id="stabilityai/stable-audio-3-medium", family="post-trained", default_prompt=( "Cinematic ambient electronic cue, deep sub pulse, shimmering stereo texture, " "slow evolving melody" ), default_duration=20, max_duration=380, default_steps=8, default_cfg=1.0, default_sampler="pingpong", requires_cuda=True, gated=True, note="High-quality checkpoint; GPU-first.", ), "small-music-base": GenerationModel( label="Stable Audio 3 Small Music Base", key="small-music-base", repo_id="stabilityai/stable-audio-3-small-music-base", family="base", default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM", default_duration=20, max_duration=120, default_steps=50, default_cfg=7.0, default_sampler="euler", note="Base checkpoint intended mainly for fine-tuning.", ), "small-sfx-base": GenerationModel( label="Stable Audio 3 Small SFX Base", key="small-sfx-base", repo_id="stabilityai/stable-audio-3-small-sfx-base", family="base", default_prompt="Chugging train coming into station with horn", default_duration=7, max_duration=120, default_steps=50, default_cfg=7.0, default_sampler="euler", note="Base checkpoint intended mainly for fine-tuning.", ), "medium-base": GenerationModel( label="Stable Audio 3 Medium Base", key="medium-base", repo_id="stabilityai/stable-audio-3-medium-base", family="base", default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM", default_duration=20, max_duration=380, default_steps=50, default_cfg=7.0, default_sampler="euler", requires_cuda=True, note="Base checkpoint intended mainly for fine-tuning; GPU-first.", ), } AUTOENCODER_MODELS = { "same-s": { "label": "SAME-S", "repo_id": "stabilityai/SAME-S", "requires_cuda": False, }, "same-l": { "label": "SAME-L", "repo_id": "stabilityai/SAME-L", "requires_cuda": True, }, } COLLECTION_ROWS = [ ["stable-audio-3-small-music", "Text-to-audio", "Generate tab", "Gated post-trained small music"], ["stable-audio-3-small-sfx", "Text-to-audio", "Generate tab", "Gated post-trained small SFX"], ["stable-audio-3-medium", "Text-to-audio", "Generate tab", "Gated medium; GPU-first"], ["stable-audio-3-small-music-base", "Text-to-audio", "Generate tab", "Base checkpoint"], ["stable-audio-3-small-sfx-base", "Text-to-audio", "Generate tab", "Base checkpoint"], ["stable-audio-3-medium-base", "Text-to-audio", "Generate tab", "Base checkpoint; GPU-first"], ["stable-audio-3-optimized", "Optimized assets", "Listed only", "MLX/TensorRT artifacts, not generic PyTorch generation"], ["SAME-S", "Autoencoder", "Autoencoder tab", "CPU-capable round trip"], ["SAME-L", "Autoencoder", "Autoencoder tab", "Large autoencoder; CUDA recommended"], ] MODEL_CACHE: dict[str, Any] = {"key": None, "model": None} AE_CACHE: dict[str, Any] = {"key": None, "model": None} ACCESS_CACHE: dict[tuple[str, str], float] = {} ACCESS_CACHE_TTL_SECONDS = max(0, int(os.getenv("SA3_ACCESS_CACHE_TTL_SECONDS", "600"))) MODEL_LOAD_LOCK = threading.RLock() def gpu_task(duration: int): if os.getenv("SA3_USE_SPACES_GPU", "1").strip().lower() in {"0", "false", "no"}: return lambda fn: fn try: import spaces return spaces.GPU(duration=duration) except Exception: return lambda fn: fn def import_torch(): return importlib.import_module("torch") def current_device(torch_module: Any) -> str: if torch_module.cuda.is_available(): return "cuda" if hasattr(torch_module.backends, "mps") and torch_module.backends.mps.is_available(): return "mps" return "cpu" def flash_attn_available() -> bool: return importlib.util.find_spec("flash_attn") is not None def oauth_token_value(oauth_token: gr.OAuthToken | None) -> str | None: token = getattr(oauth_token, "token", None) return token if isinstance(token, str) and token else None def hf_api_token_value(hf_api_token: str | None) -> str | None: if not isinstance(hf_api_token, str): return None token = hf_api_token.strip() return token or None def request_token_value( oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ) -> str | None: return hf_api_token_value(hf_api_token) or oauth_token_value(oauth_token) def oauth_username(oauth_profile: gr.OAuthProfile | None) -> str | None: username = getattr(oauth_profile, "username", None) return username if isinstance(username, str) and username else None def auth_source( oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ) -> str | None: if hf_api_token_value(hf_api_token): return "hf_token" if oauth_profile is not None and oauth_token_value(oauth_token): return "oauth" return None def stable_audio_token_hint(model: GenerationModel) -> str: if not model.gated: return "Sign in with Hugging Face or paste a Hugging Face access token before running this Space." return ( "Sign in with Hugging Face or paste a Hugging Face access token from an " "account that has accepted this gated model's terms." ) def access_cache_key(repo_id: str, token: str) -> tuple[str, str]: token_digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16] return repo_id, token_digest def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]: cache_key = access_cache_key(repo_id, token) cached_until = ACCESS_CACHE.get(cache_key) now = time.time() if cached_until is not None: if cached_until > now: return True, None ACCESS_CACHE.pop(cache_key, None) request = urllib.request.Request( f"https://huggingface.co/{repo_id}/resolve/main/model_config.json", method="HEAD", headers={"Authorization": f"Bearer {token}"}, ) try: with urllib.request.urlopen(request, timeout=20) as response: has_access = response.status < 400 if has_access and ACCESS_CACHE_TTL_SECONDS: ACCESS_CACHE[cache_key] = time.time() + ACCESS_CACHE_TTL_SECONDS return has_access, None except urllib.error.HTTPError as exc: if exc.code in {401, 403}: return ( False, "Your Hugging Face account does not have access to this gated model yet. " "Open the model page while logged in, accept Stability's terms, then retry.", ) return False, f"Hugging Face access check failed with HTTP {exc.code}." except Exception as exc: return False, f"Hugging Face access check failed: {exc!r}" @contextmanager def hub_download_token(token: str | None): if not token: yield return import stable_audio_3.model_configs as model_configs original_download = model_configs.hf_hub_download token_env_keys = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN") previous_token_env = {key: os.environ.get(key) for key in token_env_keys} def download_with_user_token(*args, **kwargs): kwargs.setdefault("token", token) return original_download(*args, **kwargs) model_configs.hf_hub_download = download_with_user_token for key in token_env_keys: os.environ[key] = token try: yield finally: model_configs.hf_hub_download = original_download for key, previous in previous_token_env.items(): if previous is None: os.environ.pop(key, None) else: os.environ[key] = previous def generation_preflight_error( model: GenerationModel, allow_cpu_medium: bool, oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ) -> tuple[str | None, str]: device = "unknown" token = request_token_value(oauth_token, hf_api_token) if not token: return ( "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.", device, ) torch = import_torch() device = current_device(torch) if model.requires_cuda and device != "cuda" and not allow_cpu_medium: return ( f"{model.label} is blocked on this runtime because CUDA is not available. " "Use a GPU Space or enable the CPU override for a slow/debug-only attempt.", device, ) if model.gated: has_access, error = user_can_download_gated_model(model.repo_id, token) if not has_access: return error or "Your Hugging Face account cannot access this gated model.", device return None, device def assert_generation_runtime( model: GenerationModel, allow_cpu_medium: bool, oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ) -> str: error, device = generation_preflight_error( model, allow_cpu_medium, oauth_profile, oauth_token, hf_api_token, ) if error: raise gr.Error(error) return device def normalize_audio_array(data: np.ndarray) -> np.ndarray: array = np.asarray(data) if np.issubdtype(array.dtype, np.integer): limit = max(abs(np.iinfo(array.dtype).min), np.iinfo(array.dtype).max) array = array.astype(np.float32) / float(limit) else: array = array.astype(np.float32) if array.ndim == 1: array = array[None, :] elif array.ndim == 2: array = array.T else: raise gr.Error("Audio must be mono or stereo.") return np.nan_to_num(array, nan=0.0, posinf=0.0, neginf=0.0) def clear_torch_memory() -> None: try: torch = import_torch() if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass gc.collect() def load_generation_model( model_key: str, allow_cpu_medium: bool, oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ): model_def = GENERATION_MODELS[model_key] device = assert_generation_runtime( model_def, allow_cpu_medium, oauth_profile, oauth_token, hf_api_token, ) if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None: return MODEL_CACHE["model"], device, True, 0.0 with MODEL_LOAD_LOCK: if MODEL_CACHE["key"] == model_key and MODEL_CACHE["model"] is not None: return MODEL_CACHE["model"], device, True, 0.0 load_started = time.time() MODEL_CACHE["model"] = None MODEL_CACHE["key"] = None clear_torch_memory() from stable_audio_3 import StableAudioModel model_half = device == "cuda" with hub_download_token(request_token_value(oauth_token, hf_api_token)): model = StableAudioModel.from_pretrained(model_key, model_half=model_half) MODEL_CACHE["key"] = model_key MODEL_CACHE["model"] = model return model, device, False, round(time.time() - load_started, 3) def load_autoencoder( model_key: str, allow_cpu_same_l: bool, oauth_profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, hf_api_token: str | None, ): if not request_token_value(oauth_token, hf_api_token): raise gr.Error("Sign in with Hugging Face or paste a Hugging Face access token before running this Space.") model_def = AUTOENCODER_MODELS[model_key] torch = import_torch() device = current_device(torch) if model_def["requires_cuda"] and device != "cuda" and not allow_cpu_same_l: raise gr.Error( f"{model_def['label']} is blocked on this runtime because CUDA is not available. " "Use SAME-S or enable the CPU override for a slow/debug-only attempt." ) if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None: return AE_CACHE["model"], device, True, 0.0 with MODEL_LOAD_LOCK: if AE_CACHE["key"] == model_key and AE_CACHE["model"] is not None: return AE_CACHE["model"], device, True, 0.0 load_started = time.time() AE_CACHE["model"] = None AE_CACHE["key"] = None clear_torch_memory() from stable_audio_3 import AutoencoderModel with hub_download_token(request_token_value(oauth_token, hf_api_token)): model = AutoencoderModel.from_pretrained(model_key) AE_CACHE["key"] = model_key AE_CACHE["model"] = model return model, device, False, round(time.time() - load_started, 3) def model_changed(model_key: str): model = GENERATION_MODELS[model_key] return ( gr.update(value=model.default_prompt), gr.update(value=model.default_duration, maximum=model.max_duration), gr.update(value=model.default_steps), gr.update(value=model.default_cfg), gr.update(value=model.default_sampler), { "repo_id": model.repo_id, "family": model.family, "max_duration_s": model.max_duration, "default_sampler": model.default_sampler, "note": model.note, "token_hint": stable_audio_token_hint(model), }, ) @gpu_task(duration=int(os.getenv("SPACES_GENERATE_GPU_SECONDS", "900"))) def generate_audio( model_key: str, prompt: str, negative_prompt: str, duration: float, steps: int, cfg_scale: float, sampler_type: str, seed: int, chunked_decode: bool, allow_cpu_medium: bool, hf_api_token: str | None, oauth_profile: gr.OAuthProfile | None = None, oauth_token: gr.OAuthToken | None = None, progress=gr.Progress(track_tqdm=True), ): model_def = GENERATION_MODELS[model_key] if not prompt or not prompt.strip(): return None, { "status": "blocked", "error": "Prompt is required.", "model": model_def.key, "repo_id": model_def.repo_id, } preflight_error, preflight_device = generation_preflight_error( model_def, allow_cpu_medium, oauth_profile, oauth_token, hf_api_token, ) if preflight_error: return None, { "status": "blocked", "error": preflight_error, "model": model_def.key, "repo_id": model_def.repo_id, "device": preflight_device, "authenticated": bool(request_token_value(oauth_token, hf_api_token)), "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "oauth_signed_in": oauth_profile is not None, "username": oauth_username(oauth_profile), "oauth_token_present": bool(oauth_token_value(oauth_token)), "hf_api_token_present": bool(hf_api_token_value(hf_api_token)), } progress(0.05, desc="Loading model") started = time.time() seed = int(seed) if seed < 0: seed = int.from_bytes(os.urandom(4), "little") % 100000 model, device, cache_hit, load_elapsed = load_generation_model( model_key, allow_cpu_medium, oauth_profile, oauth_token, hf_api_token, ) progress(0.25, desc="Generating") audio = model.generate( prompt=prompt.strip(), negative_prompt=negative_prompt.strip() or None, duration=float(duration), steps=int(steps), cfg_scale=float(cfg_scale), seed=seed, sampler_type=sampler_type, chunked_decode=bool(chunked_decode), ) progress(0.9, desc="Writing WAV") import torchaudio sample_rate = int(model.model_config["sample_rate"]) waveform = audio[0].detach().to("cpu").float().clamp(-1, 1) out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-", suffix=".wav", delete=False) out_file.close() torchaudio.save(out_file.name, waveform, sample_rate) elapsed = round(time.time() - started, 3) metadata = { "status": "ok", "model": model_def.key, "repo_id": model_def.repo_id, "family": model_def.family, "device": device, "duration_s": float(duration), "steps": int(steps), "cfg_scale": float(cfg_scale), "sampler_type": sampler_type, "seed": seed, "sample_rate": sample_rate, "elapsed_s": elapsed, "cache_hit": cache_hit, "load_elapsed_s": load_elapsed, "output_file": out_file.name, "note": model_def.note, "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "username": oauth_username(oauth_profile), } return out_file.name, metadata @gpu_task(duration=int(os.getenv("SPACES_AUTOENCODER_GPU_SECONDS", "600"))) def roundtrip_autoencoder( model_key: str, audio_input: tuple[int, np.ndarray] | None, chunked: bool, allow_cpu_same_l: bool, hf_api_token: str | None, oauth_profile: gr.OAuthProfile | None = None, oauth_token: gr.OAuthToken | None = None, progress=gr.Progress(track_tqdm=True), ): if not request_token_value(oauth_token, hf_api_token): return None, { "status": "blocked", "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.", "autoencoder": model_key, "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"], "authenticated": bool(request_token_value(oauth_token, hf_api_token)), "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "oauth_signed_in": oauth_profile is not None, "hf_api_token_present": bool(hf_api_token_value(hf_api_token)), } if audio_input is None: return None, { "status": "blocked", "error": "Upload or record audio first.", "autoencoder": model_key, "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"], } model_def = AUTOENCODER_MODELS[model_key] torch = import_torch() device = current_device(torch) if model_def["requires_cuda"] and device != "cuda" and not allow_cpu_same_l: return None, { "status": "blocked", "error": ( f"{model_def['label']} is blocked on this runtime because CUDA is not available. " "Use SAME-S or enable the CPU override for a slow/debug-only attempt." ), "autoencoder": model_key, "repo_id": model_def["repo_id"], "device": device, } progress(0.05, desc="Loading autoencoder") started = time.time() model, device, cache_hit, load_elapsed = load_autoencoder( model_key, allow_cpu_same_l, oauth_profile, oauth_token, hf_api_token, ) progress(0.25, desc="Encoding") sr, data = audio_input waveform_np = normalize_audio_array(data) torch = import_torch() waveform = torch.from_numpy(waveform_np) latents = model.encode(waveform, int(sr), chunked=bool(chunked)) progress(0.65, desc="Decoding") decoded = model.decode(latents, chunked=bool(chunked)) decoded = decoded[0].detach().to("cpu").float().clamp(-1, 1) import torchaudio out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-roundtrip-", suffix=".wav", delete=False) out_file.close() torchaudio.save(out_file.name, decoded, int(model.sample_rate)) metadata = { "status": "ok", "autoencoder": model_key, "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"], "device": device, "input_sample_rate": int(sr), "output_sample_rate": int(model.sample_rate), "input_shape": list(waveform.shape), "latent_shape": list(latents.shape), "elapsed_s": round(time.time() - started, 3), "cache_hit": cache_hit, "load_elapsed_s": load_elapsed, "output_file": out_file.name, "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "username": oauth_username(oauth_profile), } return out_file.name, metadata def unload_models( hf_api_token: str | None = None, oauth_profile: gr.OAuthProfile | None = None, oauth_token: gr.OAuthToken | None = None, ): if not request_token_value(oauth_token, hf_api_token): return { "status": "blocked", "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.", } MODEL_CACHE["key"] = None MODEL_CACHE["model"] = None AE_CACHE["key"] = None AE_CACHE["model"] = None clear_torch_memory() return { "status": "unloaded", "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "username": oauth_username(oauth_profile), } def runtime_status( hf_api_token: str | None = None, oauth_profile: gr.OAuthProfile | None = None, oauth_token: gr.OAuthToken | None = None, ): try: torch = import_torch() device = current_device(torch) cuda_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None except Exception as exc: device = "unavailable" cuda_name = None return {"torch": repr(exc), "device": device} return { "device": device, "cuda_name": cuda_name, "flash_attn": flash_attn_available(), "authenticated": bool(request_token_value(oauth_token, hf_api_token)), "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token), "oauth_signed_in": oauth_profile is not None, "username": oauth_username(oauth_profile), "oauth_token_present": bool(oauth_token_value(oauth_token)), "hf_api_token_present": bool(hf_api_token_value(hf_api_token)), "loaded_generation_model": MODEL_CACHE["key"], "loaded_autoencoder": AE_CACHE["key"], } MODEL_CHOICES = [(model.label, key) for key, model in GENERATION_MODELS.items()] AE_CHOICES = [(value["label"], key) for key, value in AUTOENCODER_MODELS.items()] SAMPLER_CHOICES = ["pingpong", "euler", "rk4", "dpmpp", "dpmpp-3m-sde"] css = """ .gradio-container { max-width: 1160px !important; } #run-buttons button { min-height: 42px; } """ with gr.Blocks(title="Stable Audio 3 Lab") as demo: gr.Markdown("# Stable Audio 3 Lab") gr.LoginButton(value="Sign in with Hugging Face", logout_value="Logout ({})") hf_api_token_box = gr.Textbox( label="Hugging Face access token", type="password", placeholder="hf_...", lines=1, value="", info="Optional fallback for API use or browsers where OAuth is unavailable. Use a read token from an account with access to the selected Stability AI model.", ) with gr.Tab("Generate"): with gr.Row(equal_height=False): with gr.Column(scale=2): model_dropdown = gr.Dropdown( label="Model", choices=MODEL_CHOICES, value="small-sfx", interactive=True, ) prompt_box = gr.Textbox( label="Prompt", value=GENERATION_MODELS["small-sfx"].default_prompt, lines=4, ) negative_prompt_box = gr.Textbox(label="Negative prompt", lines=2) with gr.Row(): duration_slider = gr.Slider( label="Duration", minimum=1, maximum=GENERATION_MODELS["small-sfx"].max_duration, value=GENERATION_MODELS["small-sfx"].default_duration, step=1, ) steps_slider = gr.Slider( label="Steps", minimum=1, maximum=100, value=GENERATION_MODELS["small-sfx"].default_steps, step=1, ) cfg_slider = gr.Slider( label="CFG", minimum=0, maximum=12, value=GENERATION_MODELS["small-sfx"].default_cfg, step=0.1, ) with gr.Row(): sampler_dropdown = gr.Dropdown( label="Sampler", choices=SAMPLER_CHOICES, value=GENERATION_MODELS["small-sfx"].default_sampler, ) seed_number = gr.Number(label="Seed", value=-1, precision=0) with gr.Row(): chunked_decode_box = gr.Checkbox(label="Chunked decode", value=True) allow_cpu_medium_box = gr.Checkbox(label="CPU override", value=False) with gr.Row(elem_id="run-buttons"): generate_button = gr.Button("Generate", variant="primary") unload_button = gr.Button("Unload") status_button = gr.Button("Runtime") with gr.Column(scale=1): model_info = gr.JSON( label="Model info", value={ "repo_id": GENERATION_MODELS["small-sfx"].repo_id, "family": GENERATION_MODELS["small-sfx"].family, "note": GENERATION_MODELS["small-sfx"].note, "token_hint": stable_audio_token_hint(GENERATION_MODELS["small-sfx"]), }, ) audio_output = gr.Audio(label="Output", type="filepath") metadata_output = gr.JSON(label="Run metadata") model_dropdown.change( model_changed, inputs=model_dropdown, outputs=[ prompt_box, duration_slider, steps_slider, cfg_slider, sampler_dropdown, model_info, ], ) generate_button.click( generate_audio, inputs=[ model_dropdown, prompt_box, negative_prompt_box, duration_slider, steps_slider, cfg_slider, sampler_dropdown, seed_number, chunked_decode_box, allow_cpu_medium_box, hf_api_token_box, ], outputs=[audio_output, metadata_output], concurrency_limit=1, ) unload_button.click(unload_models, inputs=hf_api_token_box, outputs=metadata_output) status_button.click(runtime_status, inputs=hf_api_token_box, outputs=metadata_output) with gr.Tab("Autoencoder"): with gr.Row(equal_height=False): with gr.Column(scale=2): ae_dropdown = gr.Dropdown(label="Autoencoder", choices=AE_CHOICES, value="same-s") ae_audio_input = gr.Audio(label="Input", sources=["upload", "microphone"], type="numpy") with gr.Row(): ae_chunked_box = gr.Checkbox(label="Chunked", value=True) ae_allow_cpu_box = gr.Checkbox(label="CPU override", value=False) ae_button = gr.Button("Round Trip", variant="primary") with gr.Column(scale=1): ae_output = gr.Audio(label="Decoded", type="filepath") ae_metadata = gr.JSON(label="Round-trip metadata") ae_button.click( roundtrip_autoencoder, inputs=[ae_dropdown, ae_audio_input, ae_chunked_box, ae_allow_cpu_box, hf_api_token_box], outputs=[ae_output, ae_metadata], concurrency_limit=1, ) with gr.Tab("Coverage"): gr.Dataframe( value=COLLECTION_ROWS, headers=["Collection entry", "Type", "Space path", "Status"], datatype=["str", "str", "str", "str"], interactive=False, wrap=True, ) gr.JSON(label="Runtime", value=runtime_status()) if __name__ == "__main__": demo.queue(default_concurrency_limit=1).launch(css=css, ssr_mode=False)