# Page header residue from the hosting site (not part of the program):
#   owenisas's picture
#   Clean up optimization status metadata
#   b493d6c verified
from __future__ import annotations
import gc
import hashlib
import importlib
import importlib.util
import json
import os
import sys
import tempfile
import threading
import time
import urllib.error
import urllib.request
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
# Default runtime settings applied through the environment. ``setdefault``
# leaves any value that is already present in the environment untouched.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import gradio as gr
import numpy as np
def _filter_known_unraisable(unraisable):
    """Swallow one known noisy unraisable exception; forward everything else.

    The suppressed case is a ``ValueError`` whose text contains
    "Invalid file descriptor", reported for ``BaseEventLoop.__del__``. Every
    other event is handed to the interpreter's original unraisable hook.
    """
    qualname = getattr(unraisable.object, "__qualname__", "")
    error = unraisable.exc_value
    is_known_noise = (
        qualname == "BaseEventLoop.__del__"
        and isinstance(error, ValueError)
        and "Invalid file descriptor" in str(error)
    )
    if not is_known_noise:
        sys.__unraisablehook__(unraisable)


# Install the filter as the process-wide unraisable hook.
sys.unraisablehook = _filter_known_unraisable
@dataclass(frozen=True)
class GenerationModel:
    """Immutable description of one text-to-audio checkpoint offered by the app."""

    label: str  # human-readable name shown in the model dropdown
    key: str  # short identifier reported in run metadata
    repo_id: str  # Hugging Face repository id used for the access check
    family: str  # checkpoint family, e.g. "post-trained" or "base"
    default_prompt: str  # prompt pre-filled when the model is selected
    default_duration: int  # initial value of the duration slider
    max_duration: int  # upper bound of the duration slider
    default_steps: int  # initial value of the steps slider
    default_cfg: float  # initial value of the CFG slider
    default_sampler: str  # initial value of the sampler dropdown
    requires_cuda: bool = False  # block non-CUDA runtimes unless overridden
    gated: bool = False  # run the per-user gated-access check before loading
    note: str = ""  # free-form remark surfaced in model info and run metadata
# Registry of text-to-audio checkpoints. Each dictionary key matches the
# ``key`` field of its entry and is the value carried by the model dropdown.
GENERATION_MODELS: dict[str, GenerationModel] = {
    # Post-trained checkpoints: few steps, CFG 1.0, "pingpong" sampler.
    "small-music": GenerationModel(
        label="Stable Audio 3 Small Music",
        key="small-music",
        repo_id="stabilityai/stable-audio-3-small-music",
        family="post-trained",
        default_prompt=(
            "Warm lo-fi house groove, soft sidechained pads, clean drums, "
            "late-night atmosphere, 118 BPM"
        ),
        default_duration=20,
        max_duration=120,
        default_steps=8,
        default_cfg=1.0,
        default_sampler="pingpong",
        gated=True,
        note="Lightweight music checkpoint.",
    ),
    "small-sfx": GenerationModel(
        label="Stable Audio 3 Small SFX",
        key="small-sfx",
        repo_id="stabilityai/stable-audio-3-small-sfx",
        family="post-trained",
        default_prompt="Close binaural rain on a window, soft cloth movement, detailed texture",
        default_duration=8,
        max_duration=120,
        default_steps=8,
        default_cfg=1.0,
        default_sampler="pingpong",
        gated=True,
        note="Lightweight sound-effects checkpoint.",
    ),
    "medium": GenerationModel(
        label="Stable Audio 3 Medium",
        key="medium",
        repo_id="stabilityai/stable-audio-3-medium",
        family="post-trained",
        default_prompt=(
            "Cinematic ambient electronic cue, deep sub pulse, shimmering stereo texture, "
            "slow evolving melody"
        ),
        default_duration=20,
        max_duration=380,
        default_steps=8,
        default_cfg=1.0,
        default_sampler="pingpong",
        requires_cuda=True,
        gated=True,
        note="High-quality checkpoint; GPU-first.",
    ),
    # Base checkpoints: 50 steps, CFG 7.0, "euler" sampler; not marked gated.
    "small-music-base": GenerationModel(
        label="Stable Audio 3 Small Music Base",
        key="small-music-base",
        repo_id="stabilityai/stable-audio-3-small-music-base",
        family="base",
        default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM",
        default_duration=20,
        max_duration=120,
        default_steps=50,
        default_cfg=7.0,
        default_sampler="euler",
        note="Base checkpoint intended mainly for fine-tuning.",
    ),
    "small-sfx-base": GenerationModel(
        label="Stable Audio 3 Small SFX Base",
        key="small-sfx-base",
        repo_id="stabilityai/stable-audio-3-small-sfx-base",
        family="base",
        default_prompt="Chugging train coming into station with horn",
        default_duration=7,
        max_duration=120,
        default_steps=50,
        default_cfg=7.0,
        default_sampler="euler",
        note="Base checkpoint intended mainly for fine-tuning.",
    ),
    "medium-base": GenerationModel(
        label="Stable Audio 3 Medium Base",
        key="medium-base",
        repo_id="stabilityai/stable-audio-3-medium-base",
        family="base",
        default_prompt="Dreamlike synthpop instrumental, surreal film sequence, 120 BPM",
        default_duration=20,
        max_duration=380,
        default_steps=50,
        default_cfg=7.0,
        default_sampler="euler",
        requires_cuda=True,
        note="Base checkpoint intended mainly for fine-tuning; GPU-first.",
    ),
}
# Autoencoders offered on the Autoencoder tab, keyed by the dropdown value.
# ``requires_cuda`` blocks non-CUDA runtimes unless the CPU override is set.
AUTOENCODER_MODELS = {
    "same-s": {
        "label": "SAME-S",
        "repo_id": "stabilityai/SAME-S",
        "requires_cuda": False,
    },
    "same-l": {
        "label": "SAME-L",
        "repo_id": "stabilityai/SAME-L",
        "requires_cuda": True,
    },
}
# Static rows for the table on the Coverage tab. Columns, in order:
# collection entry, type, where the Space exposes it, status remark.
COLLECTION_ROWS = [
    ["stable-audio-3-small-music", "Text-to-audio", "Generate tab", "Gated post-trained small music"],
    ["stable-audio-3-small-sfx", "Text-to-audio", "Generate tab", "Gated post-trained small SFX"],
    ["stable-audio-3-medium", "Text-to-audio", "Generate tab", "Gated medium; GPU-first"],
    ["stable-audio-3-small-music-base", "Text-to-audio", "Generate tab", "Base checkpoint"],
    ["stable-audio-3-small-sfx-base", "Text-to-audio", "Generate tab", "Base checkpoint"],
    ["stable-audio-3-medium-base", "Text-to-audio", "Generate tab", "Base checkpoint; GPU-first"],
    ["stable-audio-3-optimized", "Optimized assets", "Listed only", "MLX/TensorRT artifacts, not generic PyTorch generation"],
    ["SAME-S", "Autoencoder", "Autoencoder tab", "CPU-capable round trip"],
    ["SAME-L", "Autoencoder", "Autoencoder tab", "Large autoencoder; CUDA recommended"],
]
# Single-slot caches: each holds at most one loaded model and the key it was
# loaded under.
MODEL_CACHE: dict[str, Any] = {"key": None, "model": None}
AE_CACHE: dict[str, Any] = {"key": None, "model": None}
# Successful gated-access checks: (repo_id, token digest) -> expiry timestamp.
ACCESS_CACHE: dict[tuple[str, str], float] = {}
# Lifetime of an access-cache entry in seconds (never negative). A value of 0
# means successful checks are not cached.
ACCESS_CACHE_TTL_SECONDS = max(0, int(os.getenv("SA3_ACCESS_CACHE_TTL_SECONDS", "600")))
# Shared by both loaders so only one model load runs at a time.
MODEL_LOAD_LOCK = threading.RLock()
def gpu_task(duration: int):
    """Return a decorator that runs a function as a Spaces GPU task.

    Falls back to a decorator that returns the function unchanged when
    ``SA3_USE_SPACES_GPU`` is ``0``/``false``/``no`` (case-insensitive) or
    when the ``spaces`` package cannot be imported or used.
    """

    def passthrough(fn):
        return fn

    flag = os.getenv("SA3_USE_SPACES_GPU", "1").strip().lower()
    if flag in ("0", "false", "no"):
        return passthrough
    try:
        import spaces

        decorator = spaces.GPU(duration=duration)
    except Exception:
        # Best effort: any failure here means "run without the GPU wrapper".
        return passthrough
    return decorator
def import_torch():
    """Import the ``torch`` module on demand and return it."""
    torch_module = importlib.import_module("torch")
    return torch_module
def current_device(torch_module: Any) -> str:
    """Name the preferred device: ``"cuda"``, then ``"mps"``, else ``"cpu"``.

    Args:
        torch_module: The imported ``torch`` module (or an object exposing the
            same ``cuda`` and ``backends`` attributes).
    """
    if torch_module.cuda.is_available():
        return "cuda"
    backends = torch_module.backends
    # Builds without an ``mps`` backend attribute simply fall through to CPU.
    mps_ready = hasattr(backends, "mps") and backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
def flash_attn_available() -> bool:
    """Report whether a ``flash_attn`` module can be located (without importing it)."""
    spec = importlib.util.find_spec("flash_attn")
    return spec is not None
def oauth_token_value(oauth_token: gr.OAuthToken | None) -> str | None:
    """Return the non-empty string ``token`` attribute of ``oauth_token``, else ``None``."""
    candidate = getattr(oauth_token, "token", None)
    if isinstance(candidate, str) and candidate:
        return candidate
    return None
def hf_api_token_value(hf_api_token: str | None) -> str | None:
    """Return the pasted token with surrounding whitespace removed.

    Non-string input and blank/whitespace-only strings yield ``None``.
    """
    if isinstance(hf_api_token, str):
        stripped = hf_api_token.strip()
        if stripped:
            return stripped
    return None
def request_token_value(
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str | None:
    """Pick the token to use for a request.

    A pasted access token takes precedence; the OAuth token is the fallback.
    Returns ``None`` when neither is usable.
    """
    pasted = hf_api_token_value(hf_api_token)
    if pasted:
        return pasted
    return oauth_token_value(oauth_token)
def oauth_username(oauth_profile: gr.OAuthProfile | None) -> str | None:
    """Return the non-empty string ``username`` of the profile, else ``None``."""
    candidate = getattr(oauth_profile, "username", None)
    if isinstance(candidate, str) and candidate:
        return candidate
    return None
def auth_source(
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str | None:
    """Describe which credential is in effect.

    Returns ``"hf_token"`` when a pasted token is usable, ``"oauth"`` when a
    profile is present together with a usable OAuth token, otherwise ``None``.
    """
    if hf_api_token_value(hf_api_token):
        return "hf_token"
    signed_in = oauth_profile is not None and oauth_token_value(oauth_token)
    return "oauth" if signed_in else None
def stable_audio_token_hint(model: GenerationModel) -> str:
    """Return the sign-in hint for ``model``; gated models get the terms reminder."""
    if model.gated:
        return (
            "Sign in with Hugging Face or paste a Hugging Face access token from an "
            "account that has accepted this gated model's terms."
        )
    return "Sign in with Hugging Face or paste a Hugging Face access token before running this Space."
def access_cache_key(repo_id: str, token: str) -> tuple[str, str]:
    """Build the access-cache key for a repository and token.

    The token itself is not stored: only the first 16 hex characters of its
    SHA-256 digest appear in the key.
    """
    digest = hashlib.sha256(token.encode("utf-8"))
    return (repo_id, digest.hexdigest()[:16])
def user_can_download_gated_model(repo_id: str, token: str) -> tuple[bool, str | None]:
    """Check whether ``token`` may download files from the gated ``repo_id``.

    A HEAD request is sent for the repository's ``model_config.json``.
    Successful checks are remembered for ``ACCESS_CACHE_TTL_SECONDS``.

    Returns:
        ``(True, None)`` when access is granted, otherwise ``(False, message)``
        with a user-facing explanation.
    """
    entry = access_cache_key(repo_id, token)
    expires_at = ACCESS_CACHE.get(entry)
    checked_at = time.time()
    if expires_at is not None:
        if expires_at > checked_at:
            return True, None
        # The remembered grant has expired; forget it and re-check.
        ACCESS_CACHE.pop(entry, None)

    probe = urllib.request.Request(
        f"https://huggingface.co/{repo_id}/resolve/main/model_config.json",
        method="HEAD",
        headers={"Authorization": f"Bearer {token}"},
    )
    try:
        with urllib.request.urlopen(probe, timeout=20) as response:
            granted = response.status < 400
            if granted and ACCESS_CACHE_TTL_SECONDS:
                ACCESS_CACHE[entry] = time.time() + ACCESS_CACHE_TTL_SECONDS
            return granted, None
    except urllib.error.HTTPError as exc:
        if exc.code not in {401, 403}:
            return False, f"Hugging Face access check failed with HTTP {exc.code}."
        return (
            False,
            "Your Hugging Face account does not have access to this gated model yet. "
            "Open the model page while logged in, accept Stability's terms, then retry.",
        )
    except Exception as exc:
        return False, f"Hugging Face access check failed: {exc!r}"
@contextmanager
def hub_download_token(token: str | None):
    """Temporarily make Hub downloads use ``token``.

    While the context is active, ``stable_audio_3.model_configs.hf_hub_download``
    is wrapped so ``token`` is supplied when the caller gave none, and the
    ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` environment variables are set to
    it. Both are restored on exit. With a falsy ``token`` nothing is changed.
    """
    if not token:
        yield
        return

    import stable_audio_3.model_configs as model_configs

    wrapped = model_configs.hf_hub_download
    env_names = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    saved_env = {name: os.environ.get(name) for name in env_names}

    def download_with_user_token(*args, **kwargs):
        kwargs.setdefault("token", token)
        return wrapped(*args, **kwargs)

    model_configs.hf_hub_download = download_with_user_token
    for name in env_names:
        os.environ[name] = token
    try:
        yield
    finally:
        model_configs.hf_hub_download = wrapped
        for name, old_value in saved_env.items():
            if old_value is not None:
                os.environ[name] = old_value
            else:
                # The variable was unset before; unset it again.
                os.environ.pop(name, None)
def generation_preflight_error(
    model: GenerationModel,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> tuple[str | None, str]:
    """Run the checks that must pass before ``model`` may be loaded.

    Checks, in order: a credential is present, the runtime satisfies the
    model's CUDA requirement (unless overridden), and the credential can
    access a gated model.

    Returns:
        ``(error_message_or_None, device)``. ``device`` is ``"unknown"`` when
        the credential check fails before the device is detected.
    """
    token = request_token_value(oauth_token, hf_api_token)
    if not token:
        return (
            "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
            "unknown",
        )

    device = current_device(import_torch())
    cuda_missing = model.requires_cuda and device != "cuda"
    if cuda_missing and not allow_cpu_medium:
        return (
            f"{model.label} is blocked on this runtime because CUDA is not available. "
            "Use a GPU Space or enable the CPU override for a slow/debug-only attempt.",
            device,
        )

    if model.gated:
        allowed, reason = user_can_download_gated_model(model.repo_id, token)
        if not allowed:
            return reason or "Your Hugging Face account cannot access this gated model.", device
    return None, device
def assert_generation_runtime(
    model: GenerationModel,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
) -> str:
    """Return the detected device, or raise if the preflight checks fail.

    Raises:
        gr.Error: Carrying the preflight error message.
    """
    problem, device = generation_preflight_error(
        model,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )
    if not problem:
        return device
    raise gr.Error(problem)
def normalize_audio_array(data: np.ndarray) -> np.ndarray:
    """Convert raw audio samples to a 2-D ``float32`` array.

    Integer input is divided by the largest magnitude its dtype can hold.
    A 1-D input gains a leading axis of length 1; a 2-D input is transposed.
    NaN and infinite samples are replaced with ``0.0``.

    Raises:
        gr.Error: If the input is neither 1-D nor 2-D.
    """
    samples = np.asarray(data)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        full_scale = max(abs(info.min), info.max)
        samples = samples.astype(np.float32) / float(full_scale)
    else:
        samples = samples.astype(np.float32)

    if samples.ndim not in (1, 2):
        raise gr.Error("Audio must be mono or stereo.")
    samples = samples[np.newaxis, :] if samples.ndim == 1 else samples.T
    return np.nan_to_num(samples, nan=0.0, posinf=0.0, neginf=0.0)
def clear_torch_memory() -> None:
    """Release cached CUDA memory when possible, then run the garbage collector."""
    try:
        cuda = import_torch().cuda
        if cuda.is_available():
            cuda.empty_cache()
    except Exception:
        # Best effort only: a missing or broken torch must not stop cleanup.
        pass
    gc.collect()
def load_generation_model(
    model_key: str,
    allow_cpu_medium: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
):
    """Return the generation model for ``model_key``, loading it if needed.

    Only one generation model is kept in ``MODEL_CACHE``; loading a different
    one drops the previous model first.

    Returns:
        ``(model, device, cache_hit, load_seconds)``; ``load_seconds`` is
        ``0.0`` on a cache hit.

    Raises:
        gr.Error: If the runtime/credential preflight fails.
    """
    spec = GENERATION_MODELS[model_key]
    device = assert_generation_runtime(
        spec,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )

    def cached_model():
        if MODEL_CACHE["key"] == model_key:
            return MODEL_CACHE["model"]
        return None

    # Fast path without the lock, then re-check once the lock is held.
    hit = cached_model()
    if hit is not None:
        return hit, device, True, 0.0

    with MODEL_LOAD_LOCK:
        hit = cached_model()
        if hit is not None:
            return hit, device, True, 0.0

        load_started = time.time()
        # Drop the old model before loading so its memory can be reclaimed.
        MODEL_CACHE["model"] = None
        MODEL_CACHE["key"] = None
        clear_torch_memory()

        from stable_audio_3 import StableAudioModel

        token = request_token_value(oauth_token, hf_api_token)
        with hub_download_token(token):
            model = StableAudioModel.from_pretrained(
                model_key,
                model_half=(device == "cuda"),
            )
        MODEL_CACHE["key"] = model_key
        MODEL_CACHE["model"] = model
        load_seconds = round(time.time() - load_started, 3)
        return model, device, False, load_seconds
def load_autoencoder(
    model_key: str,
    allow_cpu_same_l: bool,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    hf_api_token: str | None,
):
    """Return the autoencoder for ``model_key``, loading it if needed.

    Only one autoencoder is kept in ``AE_CACHE``; loading a different one
    drops the previous model first.

    Returns:
        ``(model, device, cache_hit, load_seconds)``; ``load_seconds`` is
        ``0.0`` on a cache hit.

    Raises:
        gr.Error: If no credential is supplied, or the autoencoder requires
            CUDA, CUDA is unavailable and the CPU override is off.
    """
    token = request_token_value(oauth_token, hf_api_token)
    if not token:
        raise gr.Error("Sign in with Hugging Face or paste a Hugging Face access token before running this Space.")

    spec = AUTOENCODER_MODELS[model_key]
    device = current_device(import_torch())
    cuda_missing = spec["requires_cuda"] and device != "cuda"
    if cuda_missing and not allow_cpu_same_l:
        raise gr.Error(
            f"{spec['label']} is blocked on this runtime because CUDA is not available. "
            "Use SAME-S or enable the CPU override for a slow/debug-only attempt."
        )

    def cached_model():
        if AE_CACHE["key"] == model_key:
            return AE_CACHE["model"]
        return None

    # Fast path without the lock, then re-check once the lock is held.
    hit = cached_model()
    if hit is not None:
        return hit, device, True, 0.0

    with MODEL_LOAD_LOCK:
        hit = cached_model()
        if hit is not None:
            return hit, device, True, 0.0

        load_started = time.time()
        # Drop the old autoencoder before loading so its memory can be reclaimed.
        AE_CACHE["model"] = None
        AE_CACHE["key"] = None
        clear_torch_memory()

        from stable_audio_3 import AutoencoderModel

        with hub_download_token(token):
            model = AutoencoderModel.from_pretrained(model_key)
        AE_CACHE["key"] = model_key
        AE_CACHE["model"] = model
        load_seconds = round(time.time() - load_started, 3)
        return model, device, False, load_seconds
def model_changed(model_key: str):
    """Build the UI updates applied when a different model is selected.

    Returns:
        A 6-tuple: updates for the prompt, duration slider (value and
        maximum), steps, CFG and sampler controls, followed by the model-info
        dictionary.
    """
    spec = GENERATION_MODELS[model_key]
    info = {
        "repo_id": spec.repo_id,
        "family": spec.family,
        "max_duration_s": spec.max_duration,
        "default_sampler": spec.default_sampler,
        "note": spec.note,
        "token_hint": stable_audio_token_hint(spec),
    }
    prompt_update = gr.update(value=spec.default_prompt)
    duration_update = gr.update(value=spec.default_duration, maximum=spec.max_duration)
    steps_update = gr.update(value=spec.default_steps)
    cfg_update = gr.update(value=spec.default_cfg)
    sampler_update = gr.update(value=spec.default_sampler)
    return (
        prompt_update,
        duration_update,
        steps_update,
        cfg_update,
        sampler_update,
        info,
    )
@gpu_task(duration=int(os.getenv("SPACES_GENERATE_GPU_SECONDS", "900")))
def generate_audio(
    model_key: str,
    prompt: str,
    negative_prompt: str,
    duration: float,
    steps: int,
    cfg_scale: float,
    sampler_type: str,
    seed: int,
    chunked_decode: bool,
    allow_cpu_medium: bool,
    hf_api_token: str | None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate audio from a text prompt and write it to a temporary WAV file.

    Args:
        model_key: Key into ``GENERATION_MODELS``.
        prompt: Text prompt; must contain non-whitespace text.
        negative_prompt: Optional negative prompt. ``None`` or blank text is
            passed to the model as ``None``.
        duration: Requested duration, forwarded to the model as a float.
        steps: Number of sampling steps.
        cfg_scale: Classifier-free guidance scale.
        sampler_type: Sampler name forwarded to the model.
        seed: Seed; a negative or missing (``None``) value selects a random
            seed in ``[0, 100000)``.
        chunked_decode: Forwarded to the model's ``generate`` call.
        allow_cpu_medium: Allow CUDA-only models to run without CUDA.
        hf_api_token: Pasted Hugging Face token, if any.
        oauth_profile: OAuth profile of the signed-in user, if any.
        oauth_token: OAuth token of the signed-in user, if any.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, metadata)`` on success. When the request cannot run,
        ``(None, metadata)`` where ``metadata["status"] == "blocked"`` and
        ``metadata["error"]`` explains why.

    Raises:
        gr.Error: If the preflight repeated inside ``load_generation_model``
            fails.
    """
    model_def = GENERATION_MODELS[model_key]
    if not prompt or not prompt.strip():
        return None, {
            "status": "blocked",
            "error": "Prompt is required.",
            "model": model_def.key,
            "repo_id": model_def.repo_id,
        }

    preflight_error, preflight_device = generation_preflight_error(
        model_def,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )
    if preflight_error:
        return None, {
            "status": "blocked",
            "error": preflight_error,
            "model": model_def.key,
            "repo_id": model_def.repo_id,
            "device": preflight_device,
            "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
            "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
            "oauth_signed_in": oauth_profile is not None,
            "username": oauth_username(oauth_profile),
            "oauth_token_present": bool(oauth_token_value(oauth_token)),
            "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
        }

    progress(0.05, desc="Loading model")
    started = time.time()

    # A cleared seed field arrives as None; treat it like a negative seed.
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = int.from_bytes(os.urandom(4), "little") % 100000

    model, device, cache_hit, load_elapsed = load_generation_model(
        model_key,
        allow_cpu_medium,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )

    progress(0.25, desc="Generating")
    # The negative prompt may be None (e.g. from an API caller); guard before
    # stripping so a missing value behaves like an empty one.
    cleaned_negative = (negative_prompt or "").strip() or None
    audio = model.generate(
        prompt=prompt.strip(),
        negative_prompt=cleaned_negative,
        duration=float(duration),
        steps=int(steps),
        cfg_scale=float(cfg_scale),
        seed=seed,
        sampler_type=sampler_type,
        chunked_decode=bool(chunked_decode),
    )

    progress(0.9, desc="Writing WAV")
    import torchaudio

    sample_rate = int(model.model_config["sample_rate"])
    # Only the first item of the generated batch is saved.
    waveform = audio[0].detach().to("cpu").float().clamp(-1, 1)
    # delete=False: the file must outlive this call so the UI can serve it.
    out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-", suffix=".wav", delete=False)
    out_file.close()
    torchaudio.save(out_file.name, waveform, sample_rate)

    elapsed = round(time.time() - started, 3)
    metadata = {
        "status": "ok",
        "model": model_def.key,
        "repo_id": model_def.repo_id,
        "family": model_def.family,
        "device": device,
        "duration_s": float(duration),
        "steps": int(steps),
        "cfg_scale": float(cfg_scale),
        "sampler_type": sampler_type,
        "seed": seed,
        "sample_rate": sample_rate,
        "elapsed_s": elapsed,
        "cache_hit": cache_hit,
        "load_elapsed_s": load_elapsed,
        "output_file": out_file.name,
        "note": model_def.note,
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
    return out_file.name, metadata
@gpu_task(duration=int(os.getenv("SPACES_AUTOENCODER_GPU_SECONDS", "600")))
def roundtrip_autoencoder(
    model_key: str,
    audio_input: tuple[int, np.ndarray] | None,
    chunked: bool,
    allow_cpu_same_l: bool,
    hf_api_token: str | None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Encode audio with an autoencoder, decode it again, and save the result.

    Args:
        model_key: Key into ``AUTOENCODER_MODELS``.
        audio_input: ``(sample_rate, samples)`` from the audio component, or
            ``None`` when nothing was provided.
        chunked: Forwarded to the model's ``encode`` and ``decode`` calls.
        allow_cpu_same_l: Allow a CUDA-only autoencoder to run without CUDA.
        hf_api_token: Pasted Hugging Face token, if any.
        oauth_profile: OAuth profile of the signed-in user, if any.
        oauth_token: OAuth token of the signed-in user, if any.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, metadata)`` on success. When the request cannot run,
        ``(None, metadata)`` where ``metadata["status"] == "blocked"`` and
        ``metadata["error"]`` explains why.

    Raises:
        gr.Error: If the samples are neither 1-D nor 2-D.
    """
    if not request_token_value(oauth_token, hf_api_token):
        return None, {
            "status": "blocked",
            "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
            "autoencoder": model_key,
            "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"],
            "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
            "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
            "oauth_signed_in": oauth_profile is not None,
            "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
        }
    if audio_input is None:
        return None, {
            "status": "blocked",
            "error": "Upload or record audio first.",
            "autoencoder": model_key,
            "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"],
        }

    model_def = AUTOENCODER_MODELS[model_key]
    torch = import_torch()
    device = current_device(torch)
    if model_def["requires_cuda"] and device != "cuda" and not allow_cpu_same_l:
        return None, {
            "status": "blocked",
            "error": (
                f"{model_def['label']} is blocked on this runtime because CUDA is not available. "
                "Use SAME-S or enable the CPU override for a slow/debug-only attempt."
            ),
            "autoencoder": model_key,
            "repo_id": model_def["repo_id"],
            "device": device,
        }

    # Validate the input before paying for a model load: an empty recording
    # has nothing to encode.
    sr, data = audio_input
    waveform_np = normalize_audio_array(data)
    if waveform_np.size == 0:
        return None, {
            "status": "blocked",
            "error": "The provided audio is empty.",
            "autoencoder": model_key,
            "repo_id": model_def["repo_id"],
            "device": device,
        }

    progress(0.05, desc="Loading autoencoder")
    started = time.time()
    model, device, cache_hit, load_elapsed = load_autoencoder(
        model_key,
        allow_cpu_same_l,
        oauth_profile,
        oauth_token,
        hf_api_token,
    )

    progress(0.25, desc="Encoding")
    waveform = torch.from_numpy(waveform_np)
    latents = model.encode(waveform, int(sr), chunked=bool(chunked))

    progress(0.65, desc="Decoding")
    decoded = model.decode(latents, chunked=bool(chunked))
    # Only the first item of the decoded batch is saved.
    decoded = decoded[0].detach().to("cpu").float().clamp(-1, 1)

    import torchaudio

    # delete=False: the file must outlive this call so the UI can serve it.
    out_file = tempfile.NamedTemporaryFile(prefix=f"{model_key}-roundtrip-", suffix=".wav", delete=False)
    out_file.close()
    torchaudio.save(out_file.name, decoded, int(model.sample_rate))

    metadata = {
        "status": "ok",
        "autoencoder": model_key,
        "repo_id": AUTOENCODER_MODELS[model_key]["repo_id"],
        "device": device,
        "input_sample_rate": int(sr),
        "output_sample_rate": int(model.sample_rate),
        "input_shape": list(waveform.shape),
        "latent_shape": list(latents.shape),
        "elapsed_s": round(time.time() - started, 3),
        "cache_hit": cache_hit,
        "load_elapsed_s": load_elapsed,
        "output_file": out_file.name,
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
    return out_file.name, metadata
def unload_models(
    hf_api_token: str | None = None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Drop both cached models and release their memory.

    The caches are cleared while holding ``MODEL_LOAD_LOCK`` so an unload
    cannot interleave with a load that is rebuilding a cache entry.

    Returns:
        A status dictionary: ``"blocked"`` with an error when no credential is
        supplied, otherwise ``"unloaded"`` with the credential source and
        username.
    """
    if not request_token_value(oauth_token, hf_api_token):
        return {
            "status": "blocked",
            "error": "Sign in with Hugging Face or paste a Hugging Face access token before running this Space.",
        }

    with MODEL_LOAD_LOCK:
        MODEL_CACHE["key"] = None
        MODEL_CACHE["model"] = None
        AE_CACHE["key"] = None
        AE_CACHE["model"] = None
        clear_torch_memory()

    return {
        "status": "unloaded",
        "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
        "username": oauth_username(oauth_profile),
    }
def runtime_status(
    hf_api_token: str | None = None,
    oauth_profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Summarise the runtime: device, credentials, and which models are loaded.

    When torch cannot be imported or queried, ``device`` is ``"unavailable"``
    and the ``"torch"`` key holds the ``repr`` of the exception; the remaining
    fields are still reported.

    Returns:
        A JSON-serialisable status dictionary.
    """
    try:
        torch = import_torch()
        device = current_device(torch)
        cuda_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
        status = {"device": device, "cuda_name": cuda_name}
    except Exception as exc:
        # Status reporting must not fail just because torch is unusable.
        status = {"torch": repr(exc), "device": "unavailable", "cuda_name": None}

    status.update(
        {
            "flash_attn": flash_attn_available(),
            "authenticated": bool(request_token_value(oauth_token, hf_api_token)),
            "auth_source": auth_source(oauth_profile, oauth_token, hf_api_token),
            "oauth_signed_in": oauth_profile is not None,
            "username": oauth_username(oauth_profile),
            "oauth_token_present": bool(oauth_token_value(oauth_token)),
            "hf_api_token_present": bool(hf_api_token_value(hf_api_token)),
            "loaded_generation_model": MODEL_CACHE["key"],
            "loaded_autoencoder": AE_CACHE["key"],
        }
    )
    return status
# Dropdown choices as (display label, value) pairs derived from the registries.
MODEL_CHOICES = [(model.label, key) for key, model in GENERATION_MODELS.items()]
AE_CHOICES = [(value["label"], key) for key, value in AUTOENCODER_MODELS.items()]
# Sampler names offered in the sampler dropdown.
SAMPLER_CHOICES = ["pingpong", "euler", "rk4", "dpmpp", "dpmpp-3m-sde"]
# Custom CSS handed to ``launch``: caps the page width and sets a minimum
# height for the buttons in the "run-buttons" row.
css = """
.gradio-container { max-width: 1160px !important; }
#run-buttons button { min-height: 42px; }
"""
# UI definition: a sign-in/token header shared by all tabs, followed by the
# Generate, Autoencoder and Coverage tabs.
with gr.Blocks(title="Stable Audio 3 Lab") as demo:
    gr.Markdown("# Stable Audio 3 Lab")
    gr.LoginButton(value="Sign in with Hugging Face", logout_value="Logout ({})")
    # Pasted-token fallback; passed as an input to every handler below.
    hf_api_token_box = gr.Textbox(
        label="Hugging Face access token",
        type="password",
        placeholder="hf_...",
        lines=1,
        value="",
        info="Optional fallback for API use or browsers where OAuth is unavailable. Use a read token from an account with access to the selected Stability AI model.",
    )
    with gr.Tab("Generate"):
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                # Initial control values come from the "small-sfx" entry.
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=MODEL_CHOICES,
                    value="small-sfx",
                    interactive=True,
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value=GENERATION_MODELS["small-sfx"].default_prompt,
                    lines=4,
                )
                negative_prompt_box = gr.Textbox(label="Negative prompt", lines=2)
                with gr.Row():
                    duration_slider = gr.Slider(
                        label="Duration",
                        minimum=1,
                        maximum=GENERATION_MODELS["small-sfx"].max_duration,
                        value=GENERATION_MODELS["small-sfx"].default_duration,
                        step=1,
                    )
                    steps_slider = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=100,
                        value=GENERATION_MODELS["small-sfx"].default_steps,
                        step=1,
                    )
                    cfg_slider = gr.Slider(
                        label="CFG",
                        minimum=0,
                        maximum=12,
                        value=GENERATION_MODELS["small-sfx"].default_cfg,
                        step=0.1,
                    )
                with gr.Row():
                    sampler_dropdown = gr.Dropdown(
                        label="Sampler",
                        choices=SAMPLER_CHOICES,
                        value=GENERATION_MODELS["small-sfx"].default_sampler,
                    )
                    # -1 asks generate_audio to pick a random seed.
                    seed_number = gr.Number(label="Seed", value=-1, precision=0)
                with gr.Row():
                    chunked_decode_box = gr.Checkbox(label="Chunked decode", value=True)
                    allow_cpu_medium_box = gr.Checkbox(label="CPU override", value=False)
                with gr.Row(elem_id="run-buttons"):
                    generate_button = gr.Button("Generate", variant="primary")
                    unload_button = gr.Button("Unload")
                    status_button = gr.Button("Runtime")
            with gr.Column(scale=1):
                model_info = gr.JSON(
                    label="Model info",
                    value={
                        "repo_id": GENERATION_MODELS["small-sfx"].repo_id,
                        "family": GENERATION_MODELS["small-sfx"].family,
                        "note": GENERATION_MODELS["small-sfx"].note,
                        "token_hint": stable_audio_token_hint(GENERATION_MODELS["small-sfx"]),
                    },
                )
                audio_output = gr.Audio(label="Output", type="filepath")
                metadata_output = gr.JSON(label="Run metadata")
        # Selecting a model resets the controls to that model's defaults.
        model_dropdown.change(
            model_changed,
            inputs=model_dropdown,
            outputs=[
                prompt_box,
                duration_slider,
                steps_slider,
                cfg_slider,
                sampler_dropdown,
                model_info,
            ],
        )
        # Input order must match the leading positional parameters of
        # generate_audio.
        generate_button.click(
            generate_audio,
            inputs=[
                model_dropdown,
                prompt_box,
                negative_prompt_box,
                duration_slider,
                steps_slider,
                cfg_slider,
                sampler_dropdown,
                seed_number,
                chunked_decode_box,
                allow_cpu_medium_box,
                hf_api_token_box,
            ],
            outputs=[audio_output, metadata_output],
            concurrency_limit=1,
        )
        unload_button.click(unload_models, inputs=hf_api_token_box, outputs=metadata_output)
        status_button.click(runtime_status, inputs=hf_api_token_box, outputs=metadata_output)
    with gr.Tab("Autoencoder"):
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                ae_dropdown = gr.Dropdown(label="Autoencoder", choices=AE_CHOICES, value="same-s")
                ae_audio_input = gr.Audio(label="Input", sources=["upload", "microphone"], type="numpy")
                with gr.Row():
                    ae_chunked_box = gr.Checkbox(label="Chunked", value=True)
                    ae_allow_cpu_box = gr.Checkbox(label="CPU override", value=False)
                ae_button = gr.Button("Round Trip", variant="primary")
            with gr.Column(scale=1):
                ae_output = gr.Audio(label="Decoded", type="filepath")
                ae_metadata = gr.JSON(label="Round-trip metadata")
        # Input order must match the leading positional parameters of
        # roundtrip_autoencoder.
        ae_button.click(
            roundtrip_autoencoder,
            inputs=[ae_dropdown, ae_audio_input, ae_chunked_box, ae_allow_cpu_box, hf_api_token_box],
            outputs=[ae_output, ae_metadata],
            concurrency_limit=1,
        )
    with gr.Tab("Coverage"):
        gr.Dataframe(
            value=COLLECTION_ROWS,
            headers=["Collection entry", "Type", "Space path", "Status"],
            datatype=["str", "str", "str", "str"],
            interactive=False,
            wrap=True,
        )
        # runtime_status() is evaluated once, while the UI is being built, with
        # no credentials supplied.
        gr.JSON(label="Runtime", value=runtime_status())
# Script entry point: enable the request queue with a default concurrency
# limit of 1, then launch the app with the custom CSS and SSR mode disabled.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch(css=css, ssr_mode=False)