File size: 9,163 Bytes
5e77923
 
 
 
 
 
 
 
 
 
 
 
 
72c410e
 
5e77923
 
72c410e
5e77923
 
72c410e
 
 
 
 
 
5728513
 
 
 
 
72c410e
5728513
 
72c410e
 
 
 
b134d23
72c410e
 
5e77923
d82866b
5e77923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d276c0e
5e77923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d82866b
 
5e77923
 
 
 
 
 
 
 
 
 
d276c0e
5e77923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d276c0e
5e77923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""ZeroGPU Gradio demo for Stable Audio 3 — Medium, Small Music, Small SFX.

All three models are preloaded at module level (per the ZeroGPU contract), and
a radio selector picks which one runs inside the ``@spaces.GPU`` infer call.
The visible UI mirrors the high-level ``stable_audio_3`` defaults (prompt +
duration); steps / CFG / sampler / seed live in an Advanced accordion.
"""

from __future__ import annotations

import spaces  # noqa: F401

import os
import subprocess
import sys
import tempfile
import time
import types
from dataclasses import dataclass

def _ensure_stable_audio_tools() -> None:
    """Make ``stable_audio_tools`` importable, installing it on first use.

    If the package already imports, nothing else happens.  Otherwise it is
    installed into the running interpreter with ``pip install --no-deps`` and
    imported again to confirm that the install took effect.

    Raises:
        subprocess.CalledProcessError: If the ``pip`` command exits non-zero.
        ImportError: If the package still cannot be imported after installing.
    """
    import importlib

    try:
        import stable_audio_tools  # noqa: F401
    except ImportError:
        pass
    else:
        return
    # stable-audio-tools 0.0.20 strict-pins torch==2.7.1 / torchaudio==2.7.1,
    # which lack sm_120 (Blackwell) kernels. Install with --no-deps; the
    # transitive deps are listed in requirements.txt and resolved against the
    # sm_120-capable torch at build time.
    print("[startup] installing stable-audio-tools (--no-deps) …", flush=True)
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
         "stable-audio-tools"],
    )
    # The failed import above may have cached the old site-packages listing;
    # drop the finder caches so the freshly installed package is visible.
    importlib.invalidate_caches()
    import stable_audio_tools  # noqa: F401
    print("[startup] stable-audio-tools installed.", flush=True)

# Must run before the ``stable_audio_tools`` imports further down, which is
# why this call sits ahead of the remaining import group.
_ensure_stable_audio_tools()


import gradio as gr
import soundfile as sf
import torch
from einops import rearrange

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond_inpaint


# ---------------------------------------------------------------------------
# Variants
# ---------------------------------------------------------------------------


@dataclass
class Variant:
    """Static description of one selectable model variant."""

    key: str  # short identifier; used as the LOADED dict key and the radio value
    repo: str  # repository id handed to ``get_pretrained_model``
    label: str  # human-readable name shown in the model selector
    default_duration: int  # duration (seconds) pre-selected for this variant
    placeholder: str  # example prompt shown as the prompt box placeholder


# Variants offered in the UI, in display order. The first entry supplies the
# initial radio selection, prompt placeholder and duration slider settings.
VARIANTS: list[Variant] = [
    Variant(
        key="medium",
        repo="stabilityai/stable-audio-3-medium",
        label="Medium — general audio (largest)",
        default_duration=60,
        placeholder="A dream-like Synthpop instrumental that would accompany a dream-sequence in a surrealist movie 120 BPM",
    ),
    Variant(
        key="small-music",
        repo="stabilityai/stable-audio-3-small-music",
        label="Small Music — 0.6B, music-focused",
        default_duration=60,
        placeholder="Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM",
    ),
    Variant(
        key="small-sfx",
        repo="stabilityai/stable-audio-3-small-sfx",
        label="Small SFX — 0.6B, sound effects",
        default_duration=7,
        placeholder="Chugging train coming into station with horn",
    ),
]


# ---------------------------------------------------------------------------
# Preload all variants at module level (ZeroGPU CUDA emulation accepts it)
# ---------------------------------------------------------------------------

@dataclass
class LoadedVariant:
    """A ``Variant`` paired with its loaded model and audio parameters."""

    variant: Variant  # the static description this entry was loaded from
    model: object  # model from ``get_pretrained_model`` after the load-time device/dtype moves
    sample_rate: int  # ``sample_rate`` value read from the model config
    sample_size: int  # ``sample_size`` value read from the model config
    max_seconds: int  # ``sample_size // sample_rate``; upper bound applied to requested durations


def _load_variant(spec: Variant) -> LoadedVariant:
    """Fetch one variant's weights, move them to CUDA in float16, and wrap them."""
    print(f"[startup] loading {spec.repo} …", flush=True)
    started = time.time()

    net, cfg = get_pretrained_model(spec.repo)
    rate = int(cfg["sample_rate"])
    length = int(cfg["sample_size"])
    longest = length // rate

    entry = LoadedVariant(
        variant=spec,
        model=net.to("cuda").to(torch.float16),
        sample_rate=rate,
        sample_size=length,
        max_seconds=longest,
    )
    print(
        f"[startup] {spec.key} ready in {time.time() - started:.1f}s · "
        f"sr={rate} · sample_size={length} (~{longest}s max)",
        flush=True,
    )
    return entry


# Registry of preloaded models, keyed by ``Variant.key`` and filled in the
# order the variants are declared.
LOADED: dict[str, LoadedVariant] = {}
for v in VARIANTS:
    LOADED[v.key] = _load_variant(v)

# (label, value) pairs for the model radio: the label is displayed, the key is
# the value handed to the callbacks.
VARIANT_CHOICES = [(spec.label, spec.key) for spec in VARIANTS]
# Sampler names offered in the Advanced dropdown; the selection is forwarded
# unchanged as ``sampler_type`` to the generation call.
SAMPLERS = [
    "pingpong",
    "k-dpmpp-2m",
    "k-heun",
    "dpmpp-2s-ancestral",
    "dpmpp-3m-sde",
]


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------


@spaces.GPU
def infer(
    variant_key: str,
    prompt: str,
    duration: int = 60,
    steps: int = 8,
    cfg_scale: float = 1.0,
    sampler_type: str = "pingpong",
    seed: int = 0,
    progress: gr.Progress = gr.Progress(),
):
    """Generate audio from a text prompt and return the path of a WAV file.

    Args:
        variant_key: Key of a preloaded model in ``LOADED``.
        prompt: Text prompt; surrounding whitespace is stripped.
        duration: Requested length in seconds, clamped to
            ``[1, max_seconds]`` of the selected model.
        steps: Number of sampling steps forwarded to the generation call.
        cfg_scale: CFG scale forwarded to the generation call.
        sampler_type: Sampler name forwarded to the generation call.
        seed: A positive value requests a reproducible generation; ``0``
            (or any empty / non-positive value) requests a random one.
        progress: Gradio progress tracker used for status updates.

    Returns:
        Path to a 16-bit PCM WAV file written to a fresh temporary directory.

    Raises:
        gr.Error: If the prompt is empty or ``variant_key`` is unknown.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    if variant_key not in LOADED:
        raise gr.Error(f"Unknown variant {variant_key!r}.")
    lv = LOADED[variant_key]

    duration = max(1, min(int(duration), lv.max_seconds))

    progress(0.1, desc=f"[{variant_key}] preparing conditioning")
    conditioning = [{"prompt": prompt, "seconds_total": int(duration)}]

    # The stable-audio-tools generation functions take their own ``seed``
    # argument (-1 = pick one at random) and reseed torch from it, so seeding
    # torch here alone would be overridden and the UI seed would have no
    # effect. Forward the seed explicitly.
    # NOTE(review): confirm the ``seed`` keyword against the installed
    # stable-audio-tools version.
    if seed and int(seed) > 0:
        sampling_seed = int(seed)
        torch.manual_seed(sampling_seed)
    else:
        sampling_seed = -1
        torch.seed()

    progress(0.25, desc=f"[{variant_key}] sampling {steps} steps with {sampler_type}")
    t0 = time.time()
    output = generate_diffusion_cond_inpaint(
        lv.model,
        steps=int(steps),
        cfg_scale=float(cfg_scale),
        conditioning=conditioning,
        sample_size=lv.sample_size,
        sampler_type=sampler_type,
        seed=sampling_seed,
        device="cuda",
    )
    print(f"[infer/{variant_key}] sampling done in {time.time() - t0:.1f}s", flush=True)

    progress(0.92, desc="Normalising & saving")
    # Fold the batch into the time axis: (batch, channels, samples) ->
    # (channels, batch * samples).
    output = rearrange(output, "b d n -> d (b n)")
    # Compute the peak *after* the float32 cast. The model is loaded in
    # float16, so the output is presumably float16 too (TODO confirm); in that
    # dtype the 1e-9 floor rounds to zero and a silent clip would divide by
    # zero and produce NaNs.
    output = output.to(torch.float32)
    peak = output.abs().max().clamp(min=1e-9)
    output = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    # Drop everything beyond the requested duration.
    output = output[:, : int(duration) * lv.sample_rate]

    out_path = os.path.join(tempfile.mkdtemp(), f"sa3_{variant_key}.wav")
    # soundfile expects (samples, channels); our tensor is (channels, samples).
    sf.write(out_path, output.numpy().T, lv.sample_rate, subtype="PCM_16")
    return out_path


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

# Markdown rendered at the top of the page.
DESCRIPTION = """
# 🎵 Stable Audio 3

Text-to-audio generation with <a href="https://huggingface.co/collections/stabilityai/stable-audio-3" target="_blank" rel="noopener noreferrer">Stable Audio 3</a>. Pick a variant, write a prompt, hit Generate.
"""

# Example rows are [variant key, prompt, duration in seconds], in the same
# order as the ``inputs`` list given to ``gr.Examples``.
EXAMPLES = [
    ["medium",      "House music that encapsulates the feeling of being at a festival in the sunny weather with all your friends 124 BPM", 60],
    ["small-music", "Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM", 45],
    ["small-music", "Driving techno track with rolling 16th-note hats, deep sub bass, acid arpeggios building tension 132 BPM", 60],
    ["small-sfx",   "Chugging train coming into station with horn", 7],
    ["small-sfx",   "Heavy rain on a tin roof with distant thunder rolls", 10],
    ["medium",      "Rainy night, lo-fi hip-hop beat with vinyl crackle, mellow piano chords, soft kick and snare 80 BPM", 30],
]


def _on_variant_change(variant_key: str):
    """Build the UI updates that follow a model switch.

    Returns a 2-tuple: an update for the duration slider (new maximum, the
    variant's default duration capped at that maximum, refreshed label) and an
    update for the prompt box (the variant's placeholder text).
    """
    loaded = LOADED[variant_key]
    spec = loaded.variant
    limit = loaded.max_seconds

    slider_update = gr.update(
        maximum=limit,
        value=min(spec.default_duration, limit),
        label=f"Duration (s) · model max {limit}s",
    )
    prompt_update = gr.update(placeholder=spec.placeholder)
    return slider_update, prompt_update


# Page layout: model selector on top, prompt/controls on the left, audio
# output on the right, examples underneath.
with gr.Blocks(theme=gr.themes.Citrus(), title="Stable Audio 3") as demo:
    gr.Markdown(DESCRIPTION)

    # The first declared variant is the initial selection.
    variant = gr.Radio(
        choices=VARIANT_CHOICES,
        value=VARIANTS[0].key,
        label="Model",
    )

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder=VARIANTS[0].placeholder,
                lines=3,
            )
            # Cap the initial value at the model maximum, exactly as
            # ``_on_variant_change`` does on later switches, so the slider
            # never starts outside its own range.
            duration = gr.Slider(
                1, LOADED[VARIANTS[0].key].max_seconds,
                value=min(VARIANTS[0].default_duration,
                          LOADED[VARIANTS[0].key].max_seconds),
                step=1,
                label=f"Duration (s) · model max {LOADED[VARIANTS[0].key].max_seconds}s",
            )
            with gr.Accordion("Advanced settings", open=False):
                steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
                cfg_scale = gr.Slider(0.5, 8.0, value=1.0, step=0.1, label="CFG scale")
                sampler_type = gr.Dropdown(SAMPLERS, value="pingpong", label="Sampler")
                seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            run_btn = gr.Button("🎼 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            audio_out = gr.Audio(label="Output", type="filepath", autoplay=True)

    # Examples only supply variant, prompt and duration; ``infer`` falls back
    # to its keyword defaults for the advanced settings.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[variant, prompt, duration],
        outputs=[audio_out],
        fn=infer,
        cache_examples=True,
        cache_mode="lazy",
        label="Examples (lazy-cached on first click)",
    )

    # Switching models refreshes the slider range/default and the placeholder.
    # NOTE(review): ``change`` may also fire when an example row sets the
    # radio, which would replace that example's duration with the variant
    # default — confirm, and consider ``input`` if that is undesired.
    variant.change(
        fn=_on_variant_change,
        inputs=[variant],
        outputs=[duration, prompt],
    )

    run_btn.click(
        fn=infer,
        inputs=[variant, prompt, duration, steps, cfg_scale, sampler_type, seed],
        outputs=[audio_out],
    )


# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()