multimodalart HF Staff commited on
Commit
232ab2a
·
1 Parent(s): d276c0e

Add Advanced tab mirroring reference repo UI

Browse files

Wraps the existing UI in a Simple tab and adds an Advanced tab that
mirrors stable_audio_3/interface/diffusion_cond.py: negative prompt,
sampler params (sigma_max, APG, duration padding), init audio + noise
level, inpainting (audio + mask start/end), output spectrogram gallery,
and send-to-init / send-to-inpaint buttons. SAMPLERS narrowed to those
valid for rf_denoiser. Inpaint/init audio is pre-resampled to model SR
and cast to model dtype to avoid fp16/fp32 mismatches in the pretransform
encoder. matplotlib/Pillow added for the mel-spectrogram helper.

Files changed (2) hide show
  1. app.py +497 -78
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,9 +1,14 @@
1
  """ZeroGPU Gradio demo for Stable Audio 3 — Medium, Small Music, Small SFX.
2
 
3
- All three models are preloaded at module level (per the ZeroGPU contract), and
4
- a radio selector picks which one runs inside the ``@spaces.GPU`` infer call.
5
- The visible UI mirrors the high-level ``stable_audio_3`` defaults (prompt +
6
- duration); steps / CFG / sampler / seed live in an Advanced accordion.
 
 
 
 
 
7
  """
8
 
9
  from __future__ import annotations
@@ -15,8 +20,8 @@ import subprocess
15
  import sys
16
  import tempfile
17
  import time
18
- import types
19
  from dataclasses import dataclass
 
20
 
21
  def _ensure_stable_audio_tools() -> None:
22
  try:
@@ -40,9 +45,15 @@ _ensure_stable_audio_tools()
40
 
41
 
42
  import gradio as gr
 
43
  import soundfile as sf
44
  import torch
 
 
45
  from einops import rearrange
 
 
 
46
 
47
  from stable_audio_tools import get_pretrained_model
48
  from stable_audio_tools.inference.generation import generate_diffusion_cond_inpaint
@@ -122,7 +133,46 @@ for v in VARIANTS:
122
  )
123
 
124
  VARIANT_CHOICES = [(v.label, v.key) for v in VARIANTS]
125
- SAMPLERS = ["pingpong", "k-dpmpp-2m", "k-heun", "dpmpp-2s-ancestral", "dpmpp-3m-sde"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  # ---------------------------------------------------------------------------
@@ -130,63 +180,254 @@ SAMPLERS = ["pingpong", "k-dpmpp-2m", "k-heun", "dpmpp-2s-ancestral", "dpmpp-3m-
130
  # ---------------------------------------------------------------------------
131
 
132
 
133
- @spaces.GPU
134
- def infer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  variant_key: str,
136
  prompt: str,
 
137
  duration: int = 60,
138
  steps: int = 8,
139
  cfg_scale: float = 1.0,
140
  sampler_type: str = "pingpong",
141
  seed: int = 0,
 
 
 
 
 
 
 
 
 
 
 
142
  progress: gr.Progress = gr.Progress(),
143
  ):
 
 
144
  prompt = (prompt or "").strip()
145
  if not prompt:
146
  raise gr.Error("Please enter a prompt.")
147
  if variant_key not in LOADED:
148
  raise gr.Error(f"Unknown variant {variant_key!r}.")
149
  lv = LOADED[variant_key]
150
-
151
  duration = max(1, min(int(duration), lv.max_seconds))
152
 
153
- progress(0.1, desc=f"[{variant_key}] preparing conditioning")
154
  conditioning = [{"prompt": prompt, "seconds_total": int(duration)}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- if seed and int(seed) > 0:
157
- torch.manual_seed(int(seed))
158
- else:
159
- torch.seed()
160
-
161
- progress(0.25, desc=f"[{variant_key}] sampling {steps} steps with {sampler_type}")
162
- t0 = time.time()
163
- output = generate_diffusion_cond_inpaint(
164
- lv.model,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  steps=int(steps),
166
  cfg_scale=float(cfg_scale),
167
  conditioning=conditioning,
 
168
  sample_size=lv.sample_size,
169
  sampler_type=sampler_type,
 
170
  device="cuda",
 
 
 
171
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  print(f"[infer/{variant_key}] sampling done in {time.time() - t0:.1f}s", flush=True)
173
 
174
  progress(0.92, desc="Normalising & saving")
175
- output = rearrange(output, "b d n -> d (b n)")
176
- output = (
177
- output.to(torch.float32)
178
- .div(torch.max(torch.abs(output)).clamp(min=1e-9))
179
- .clamp(-1, 1)
180
- .mul(32767)
181
- .to(torch.int16)
182
- .cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  )
184
- output = output[:, : int(duration) * lv.sample_rate]
185
 
186
- out_path = os.path.join(tempfile.mkdtemp(), f"sa3_{variant_key}.wav")
187
- # soundfile expects (samples, channels); our tensor is (channels, samples).
188
- sf.write(out_path, output.numpy().T, lv.sample_rate, subtype="PCM_16")
189
- return out_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
 
192
  # ---------------------------------------------------------------------------
@@ -196,7 +437,7 @@ def infer(
196
  DESCRIPTION = """
197
  # 🎵 Stable Audio 3
198
 
199
- Text-to-audio generation with <a href="https://huggingface.co/collections/stabilityai/stable-audio-3" target="_blank" rel="noopener noreferrer">Stable Audio 3</a>. Pick a variant, write a prompt, hit Generate.
200
  """
201
 
202
  EXAMPLES = [
@@ -209,7 +450,7 @@ EXAMPLES = [
209
  ]
210
 
211
 
212
- def _on_variant_change(variant_key: str):
213
  lv = LOADED[variant_key]
214
  return (
215
  gr.update(maximum=lv.max_seconds, value=min(lv.variant.default_duration, lv.max_seconds),
@@ -218,58 +459,236 @@ def _on_variant_change(variant_key: str):
218
  )
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  with gr.Blocks(theme=gr.themes.Citrus(), title="Stable Audio 3") as demo:
222
  gr.Markdown(DESCRIPTION)
223
 
224
- variant = gr.Radio(
225
- choices=VARIANT_CHOICES,
226
- value=VARIANTS[0].key,
227
- label="Model",
228
- )
 
 
 
 
 
229
 
230
- with gr.Row():
231
- with gr.Column(scale=2):
232
- prompt = gr.Textbox(
233
- label="Prompt",
234
- placeholder=VARIANTS[0].placeholder,
235
- lines=3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  )
237
- duration = gr.Slider(
238
- 1, LOADED[VARIANTS[0].key].max_seconds,
239
- value=VARIANTS[0].default_duration, step=1,
240
- label=f"Duration (s) · model max {LOADED[VARIANTS[0].key].max_seconds}s",
 
241
  )
242
- with gr.Accordion("Advanced settings", open=False):
243
- steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
244
- cfg_scale = gr.Slider(0.5, 8.0, value=1.0, step=0.1, label="CFG scale")
245
- sampler_type = gr.Dropdown(SAMPLERS, value="pingpong", label="Sampler")
246
- seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
247
- run_btn = gr.Button("🎼 Generate", variant="primary", size="lg")
248
-
249
- with gr.Column(scale=1):
250
- audio_out = gr.Audio(label="Output", type="filepath", autoplay=True)
251
-
252
- gr.Examples(
253
- examples=EXAMPLES,
254
- inputs=[variant, prompt, duration],
255
- outputs=[audio_out],
256
- fn=infer,
257
- cache_examples=True,
258
- cache_mode="lazy",
259
- label="Examples (lazy-cached on first click)",
260
- )
261
 
262
- variant.change(
263
- fn=_on_variant_change,
264
- inputs=[variant],
265
- outputs=[duration, prompt],
266
- )
267
 
268
- run_btn.click(
269
- fn=infer,
270
- inputs=[variant, prompt, duration, steps, cfg_scale, sampler_type, seed],
271
- outputs=[audio_out],
272
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
 
275
  if __name__ == "__main__":
 
1
  """ZeroGPU Gradio demo for Stable Audio 3 — Medium, Small Music, Small SFX.
2
 
3
+ Two tabs:
4
+
5
+ * **Simple** prompt + duration with a slim Advanced accordion (steps/CFG/seed
6
+ /sampler). Mirrors the original tiny UI.
7
+ * **Advanced** — replicates the reference repo's
8
+ ``stable_audio_3/interface/diffusion_cond.py`` controls: negative prompt,
9
+ sampler params (sigma_max, APG, duration padding), init audio + noise level,
10
+ inpainting with mask start/end, spectrogram gallery, send-to-init /
11
+ send-to-inpaint buttons.
12
  """
13
 
14
  from __future__ import annotations
 
20
  import sys
21
  import tempfile
22
  import time
 
23
  from dataclasses import dataclass
24
+ from typing import Optional, Tuple
25
 
26
  def _ensure_stable_audio_tools() -> None:
27
  try:
 
45
 
46
 
47
  import gradio as gr
48
+ import numpy as np
49
  import soundfile as sf
50
  import torch
51
+ import torchaudio
52
+ import torchaudio.transforms as T
53
  from einops import rearrange
54
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
55
+ from matplotlib.figure import Figure
56
+ from PIL import Image
57
 
58
  from stable_audio_tools import get_pretrained_model
59
  from stable_audio_tools.inference.generation import generate_diffusion_cond_inpaint
 
133
  )
134
 
135
  VARIANT_CHOICES = [(v.label, v.key) for v in VARIANTS]
136
+ # Samplers valid for rf_denoiser diffusion objective (the SA3 family).
137
+ SAMPLERS = ["pingpong", "euler", "rk4", "dpmpp"]
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Spectrogram helper (Mel; adapted from the reference repo's aeiou.py)
142
+ # ---------------------------------------------------------------------------
143
+
144
+
145
+ def _power_to_db(spec: np.ndarray, amin: float = 1e-10) -> np.ndarray:
146
+ return 10.0 * np.log10(np.maximum(amin, spec))
147
+
148
+
149
+ def audio_spectrogram_image(
150
+ waveform: torch.Tensor,
151
+ sample_rate: int,
152
+ db_range=(35, 120),
153
+ figsize=(5, 4),
154
+ ) -> Image.Image:
155
+ """Render a Mel spectrogram (left channel) as a PIL image."""
156
+ if waveform.dim() == 1:
157
+ waveform = waveform.unsqueeze(0)
158
+ n_fft = 1024
159
+ hop_length = n_fft // 2
160
+ mel_op = T.MelSpectrogram(
161
+ sample_rate=sample_rate, n_fft=n_fft, win_length=None,
162
+ hop_length=hop_length, center=True, pad_mode="reflect", power=2.0,
163
+ norm="slaney", onesided=True, n_mels=128, mel_scale="htk",
164
+ )
165
+ melspec = mel_op(waveform.float())[0] # left channel
166
+ fig = Figure(figsize=figsize, dpi=100)
167
+ canvas = FigureCanvasAgg(fig)
168
+ ax = fig.add_subplot()
169
+ ax.imshow(_power_to_db(melspec.numpy()), origin="lower", aspect="auto",
170
+ vmin=db_range[0], vmax=db_range[1])
171
+ ax.set_ylabel("mel bins (log freq)")
172
+ ax.set_xlabel("frame")
173
+ ax.set_title("MelSpectrogram")
174
+ canvas.draw()
175
+ return Image.fromarray(np.asarray(canvas.buffer_rgba()))
176
 
177
 
178
  # ---------------------------------------------------------------------------
 
180
  # ---------------------------------------------------------------------------
181
 
182
 
183
+ def _gradio_audio_to_tensor(
184
+ audio_in: Optional[Tuple[int, np.ndarray]],
185
+ ) -> Optional[Tuple[int, torch.Tensor]]:
186
+ """Convert a gr.Audio (numpy) value to the (sr, torch.Tensor[C,N]) tuple
187
+ that ``generate_diffusion_cond_inpaint`` expects. Accepts mono or stereo."""
188
+ if audio_in is None:
189
+ return None
190
+ sr, arr = audio_in
191
+ if arr is None or (hasattr(arr, "size") and arr.size == 0):
192
+ return None
193
+ arr = np.asarray(arr)
194
+ if arr.dtype.kind in ("i", "u"):
195
+ max_val = float(np.iinfo(arr.dtype).max)
196
+ arr = arr.astype(np.float32) / max_val
197
+ else:
198
+ arr = arr.astype(np.float32)
199
+ if arr.ndim == 1:
200
+ arr = arr[None, :] # (1, N)
201
+ else:
202
+ # gr.Audio returns (N, C); transpose to (C, N)
203
+ arr = arr.T if arr.shape[0] > arr.shape[1] else arr
204
+ return int(sr), torch.from_numpy(arr)
205
+
206
+
207
+ def _tensor_to_wav(
208
+ output: torch.Tensor,
209
+ sample_rate: int,
210
+ duration_seconds: Optional[int],
211
+ out_dir: Optional[str] = None,
212
+ ) -> Tuple[str, torch.Tensor]:
213
+ """Pack a (B, C, N) generation tensor to int16, optionally cut to duration,
214
+ write to disk, and return (path, int16-tensor)."""
215
+ output = rearrange(output, "b d n -> d (b n)")
216
+ output = (
217
+ output.to(torch.float32)
218
+ .div(torch.max(torch.abs(output)).clamp(min=1e-9))
219
+ .clamp(-1, 1)
220
+ .mul(32767)
221
+ .to(torch.int16)
222
+ .cpu()
223
+ )
224
+ if duration_seconds is not None:
225
+ output = output[:, : int(duration_seconds) * sample_rate]
226
+ out_dir = out_dir or tempfile.mkdtemp()
227
+ out_path = os.path.join(out_dir, "sa3.wav")
228
+ sf.write(out_path, output.numpy().T, sample_rate, subtype="PCM_16")
229
+ return out_path, output
230
+
231
+
232
+ def _run_inference(
233
  variant_key: str,
234
  prompt: str,
235
+ negative_prompt: str = "",
236
  duration: int = 60,
237
  steps: int = 8,
238
  cfg_scale: float = 1.0,
239
  sampler_type: str = "pingpong",
240
  seed: int = 0,
241
+ sigma_max: float = 1.0,
242
+ apg_scale: float = 1.0,
243
+ duration_padding_sec: float = 6.0,
244
+ cut_to_seconds_total: bool = True,
245
+ init_audio: Optional[Tuple[int, np.ndarray]] = None,
246
+ init_noise_level: float = 0.9,
247
+ inpaint_audio: Optional[Tuple[int, np.ndarray]] = None,
248
+ mask_start_sec: float = 0.0,
249
+ mask_end_sec: float = 0.0,
250
+ preview_every: int = 0,
251
+ return_spectrogram: bool = True,
252
  progress: gr.Progress = gr.Progress(),
253
  ):
254
+ """Full-featured generation. Returns (audio_path, [spectrogram_img, *previews])
255
+ when ``return_spectrogram`` is True, else just ``audio_path``."""
256
  prompt = (prompt or "").strip()
257
  if not prompt:
258
  raise gr.Error("Please enter a prompt.")
259
  if variant_key not in LOADED:
260
  raise gr.Error(f"Unknown variant {variant_key!r}.")
261
  lv = LOADED[variant_key]
 
262
  duration = max(1, min(int(duration), lv.max_seconds))
263
 
264
+ progress(0.05, desc=f"[{variant_key}] preparing conditioning")
265
  conditioning = [{"prompt": prompt, "seconds_total": int(duration)}]
266
+ negative_conditioning = None
267
+ neg = (negative_prompt or "").strip()
268
+ if neg:
269
+ negative_conditioning = [{"prompt": neg, "seconds_total": int(duration)}]
270
+
271
+ # The pretransform encoder is fp16 (we cast the whole model at startup),
272
+ # but prepare_audio's torchaudio Resample uses an fp32 kernel. Pre-resample
273
+ # in fp32 here so prepare_audio's resample is a no-op, then cast to the
274
+ # model dtype so the encoder doesn't see a dtype mismatch.
275
+ model_dtype = next(lv.model.parameters()).dtype
276
+
277
+ def _prep(tup):
278
+ if tup is None:
279
+ return None
280
+ sr, t = tup
281
+ t = t.float()
282
+ if sr != lv.sample_rate:
283
+ t = torchaudio.functional.resample(t, sr, lv.sample_rate)
284
+ return lv.sample_rate, t.to(model_dtype)
285
+
286
+ init_audio_t = _prep(_gradio_audio_to_tensor(init_audio))
287
+ inpaint_audio_t = _prep(_gradio_audio_to_tensor(inpaint_audio))
288
+
289
+ # Inpaint mask: only enable if mask_end > mask_start AND we have either
290
+ # inpaint_audio or init_audio (otherwise the mask wraps zero content).
291
+ mask_start = max(0.0, float(mask_start_sec))
292
+ mask_end = min(float(duration), float(mask_end_sec))
293
+ use_mask = (
294
+ inpaint_audio_t is not None
295
+ and mask_end > mask_start
296
+ )
297
 
298
+ seed_val = int(seed) if seed and int(seed) > 0 else -1
299
+
300
+ preview_images: list = []
301
+ callback = None
302
+ if preview_every and int(preview_every) > 0:
303
+ every = int(preview_every)
304
+
305
+ def _cb(info):
306
+ i = info["i"]
307
+ if i % every != 0:
308
+ return
309
+ denoised = info["denoised"]
310
+ try:
311
+ if lv.model.pretransform is not None:
312
+ denoised = lv.model.pretransform.decode(denoised)
313
+ d = rearrange(denoised, "b d n -> d (b n)")
314
+ d = d.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
315
+ img = audio_spectrogram_image(d, sample_rate=lv.sample_rate)
316
+ preview_images.append((img, f"Step {i + 1}"))
317
+ except Exception as e:
318
+ print(f"[preview] skipped step {i}: {e}", flush=True)
319
+ callback = _cb
320
+
321
+ gen_kwargs: dict = dict(
322
  steps=int(steps),
323
  cfg_scale=float(cfg_scale),
324
  conditioning=conditioning,
325
+ negative_conditioning=negative_conditioning,
326
  sample_size=lv.sample_size,
327
  sampler_type=sampler_type,
328
+ seed=seed_val,
329
  device="cuda",
330
+ sigma_max=float(sigma_max),
331
+ apg_scale=float(apg_scale),
332
+ duration_padding_sec=float(duration_padding_sec),
333
  )
334
+ if init_audio_t is not None:
335
+ gen_kwargs["init_audio"] = init_audio_t
336
+ gen_kwargs["init_noise_level"] = float(init_noise_level)
337
+ if inpaint_audio_t is not None:
338
+ gen_kwargs["inpaint_audio"] = inpaint_audio_t
339
+ if use_mask:
340
+ gen_kwargs["inpaint_mask_start_seconds"] = mask_start
341
+ gen_kwargs["inpaint_mask_end_seconds"] = mask_end
342
+ if callback is not None:
343
+ gen_kwargs["callback"] = callback
344
+
345
+ progress(0.25, desc=f"[{variant_key}] sampling {steps} steps with {sampler_type}")
346
+ t0 = time.time()
347
+ output = generate_diffusion_cond_inpaint(lv.model, **gen_kwargs)
348
  print(f"[infer/{variant_key}] sampling done in {time.time() - t0:.1f}s", flush=True)
349
 
350
  progress(0.92, desc="Normalising & saving")
351
+ cut_dur = int(duration) if cut_to_seconds_total else None
352
+ out_path, int16_audio = _tensor_to_wav(output, lv.sample_rate, cut_dur)
353
+
354
+ if not return_spectrogram:
355
+ return out_path
356
+
357
+ spec_img = audio_spectrogram_image(int16_audio, sample_rate=lv.sample_rate)
358
+ return out_path, [spec_img, *preview_images]
359
+
360
+
361
+ @spaces.GPU
362
+ def infer(
363
+ variant_key: str,
364
+ prompt: str,
365
+ duration: int = 60,
366
+ steps: int = 8,
367
+ cfg_scale: float = 1.0,
368
+ sampler_type: str = "pingpong",
369
+ seed: int = 0,
370
+ progress: gr.Progress = gr.Progress(),
371
+ ):
372
+ """Slim handler used by the Simple tab and the Examples cache."""
373
+ return _run_inference(
374
+ variant_key=variant_key,
375
+ prompt=prompt,
376
+ duration=duration,
377
+ steps=steps,
378
+ cfg_scale=cfg_scale,
379
+ sampler_type=sampler_type,
380
+ seed=seed,
381
+ return_spectrogram=False,
382
+ progress=progress,
383
  )
 
384
 
385
+
386
+ @spaces.GPU
387
+ def infer_advanced(
388
+ variant_key: str,
389
+ prompt: str,
390
+ negative_prompt: str,
391
+ duration: int,
392
+ steps: int,
393
+ cfg_scale: float,
394
+ sampler_type: str,
395
+ seed: int,
396
+ sigma_max: float,
397
+ apg_scale: float,
398
+ duration_padding_sec: float,
399
+ cut_to_seconds_total: bool,
400
+ init_audio: Optional[Tuple[int, np.ndarray]],
401
+ init_noise_level: float,
402
+ inpaint_audio: Optional[Tuple[int, np.ndarray]],
403
+ mask_start_sec: float,
404
+ mask_end_sec: float,
405
+ preview_every: int,
406
+ progress: gr.Progress = gr.Progress(),
407
+ ):
408
+ """Full-featured handler used by the Advanced tab."""
409
+ return _run_inference(
410
+ variant_key=variant_key,
411
+ prompt=prompt,
412
+ negative_prompt=negative_prompt,
413
+ duration=duration,
414
+ steps=steps,
415
+ cfg_scale=cfg_scale,
416
+ sampler_type=sampler_type,
417
+ seed=seed,
418
+ sigma_max=sigma_max,
419
+ apg_scale=apg_scale,
420
+ duration_padding_sec=duration_padding_sec,
421
+ cut_to_seconds_total=cut_to_seconds_total,
422
+ init_audio=init_audio,
423
+ init_noise_level=init_noise_level,
424
+ inpaint_audio=inpaint_audio,
425
+ mask_start_sec=mask_start_sec,
426
+ mask_end_sec=mask_end_sec,
427
+ preview_every=preview_every,
428
+ return_spectrogram=True,
429
+ progress=progress,
430
+ )
431
 
432
 
433
  # ---------------------------------------------------------------------------
 
437
  DESCRIPTION = """
438
  # 🎵 Stable Audio 3
439
 
440
+ Text-to-audio generation with <a href="https://huggingface.co/collections/stabilityai/stable-audio-3" target="_blank" rel="noopener noreferrer">Stable Audio 3</a>. Pick a variant, write a prompt, hit Generate. Switch to **Advanced** for the full sampler / init-audio / inpainting controls.
441
  """
442
 
443
  EXAMPLES = [
 
450
  ]
451
 
452
 
453
+ def _variant_change_simple(variant_key: str):
454
  lv = LOADED[variant_key]
455
  return (
456
  gr.update(maximum=lv.max_seconds, value=min(lv.variant.default_duration, lv.max_seconds),
 
459
  )
460
 
461
 
462
+ def _variant_change_advanced(variant_key: str):
463
+ lv = LOADED[variant_key]
464
+ dur = min(lv.variant.default_duration, lv.max_seconds)
465
+ return (
466
+ gr.update(maximum=lv.max_seconds, value=dur,
467
+ label=f"Seconds total · model max {lv.max_seconds}s"),
468
+ gr.update(placeholder=lv.variant.placeholder),
469
+ gr.update(maximum=float(lv.max_seconds), value=0.0),
470
+ gr.update(maximum=float(lv.max_seconds), value=float(dur)),
471
+ )
472
+
473
+
474
  with gr.Blocks(theme=gr.themes.Citrus(), title="Stable Audio 3") as demo:
475
  gr.Markdown(DESCRIPTION)
476
 
477
+ with gr.Tabs():
478
+ # -----------------------------------------------------------------
479
+ # Simple tab
480
+ # -----------------------------------------------------------------
481
+ with gr.Tab("Simple"):
482
+ variant = gr.Radio(
483
+ choices=VARIANT_CHOICES,
484
+ value=VARIANTS[0].key,
485
+ label="Model",
486
+ )
487
 
488
+ with gr.Row():
489
+ with gr.Column(scale=2):
490
+ prompt = gr.Textbox(
491
+ label="Prompt",
492
+ placeholder=VARIANTS[0].placeholder,
493
+ lines=3,
494
+ )
495
+ duration = gr.Slider(
496
+ 1, LOADED[VARIANTS[0].key].max_seconds,
497
+ value=VARIANTS[0].default_duration, step=1,
498
+ label=f"Duration (s) · model max {LOADED[VARIANTS[0].key].max_seconds}s",
499
+ )
500
+ with gr.Accordion("Advanced settings", open=False):
501
+ steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
502
+ cfg_scale = gr.Slider(0.5, 8.0, value=1.0, step=0.1, label="CFG scale")
503
+ sampler_type = gr.Dropdown(SAMPLERS, value="pingpong", label="Sampler")
504
+ seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
505
+ run_btn = gr.Button("🎼 Generate", variant="primary", size="lg")
506
+
507
+ with gr.Column(scale=1):
508
+ audio_out = gr.Audio(label="Output", type="filepath", autoplay=True)
509
+
510
+ gr.Examples(
511
+ examples=EXAMPLES,
512
+ inputs=[variant, prompt, duration],
513
+ outputs=[audio_out],
514
+ fn=infer,
515
+ cache_examples=True,
516
+ cache_mode="lazy",
517
+ label="Examples (lazy-cached on first click)",
518
  )
519
+
520
+ variant.change(
521
+ fn=_variant_change_simple,
522
+ inputs=[variant],
523
+ outputs=[duration, prompt],
524
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
+ run_btn.click(
527
+ fn=infer,
528
+ inputs=[variant, prompt, duration, steps, cfg_scale, sampler_type, seed],
529
+ outputs=[audio_out],
530
+ )
531
 
532
+ # -----------------------------------------------------------------
533
+ # Advanced tab — mirrors stable_audio_3/interface/diffusion_cond.py
534
+ # -----------------------------------------------------------------
535
+ with gr.Tab("Advanced"):
536
+ adv_variant = gr.Radio(
537
+ choices=VARIANT_CHOICES,
538
+ value=VARIANTS[0].key,
539
+ label="Model",
540
+ )
541
+
542
+ with gr.Row():
543
+ with gr.Column(scale=6):
544
+ adv_prompt = gr.Textbox(
545
+ show_label=False,
546
+ placeholder=VARIANTS[0].placeholder,
547
+ )
548
+ adv_negative = gr.Textbox(
549
+ show_label=False, placeholder="Negative prompt"
550
+ )
551
+ adv_generate = gr.Button("Generate", variant="primary", scale=1)
552
+
553
+ with gr.Row(equal_height=False):
554
+ with gr.Column():
555
+ adv_seconds_total = gr.Slider(
556
+ minimum=1,
557
+ maximum=LOADED[VARIANTS[0].key].max_seconds,
558
+ step=1,
559
+ value=VARIANTS[0].default_duration,
560
+ label=f"Seconds total · model max {LOADED[VARIANTS[0].key].max_seconds}s",
561
+ )
562
+
563
+ with gr.Row():
564
+ adv_steps = gr.Slider(
565
+ minimum=1, maximum=500, step=1, value=8, label="Steps"
566
+ )
567
+ adv_cfg = gr.Slider(
568
+ minimum=0.0, maximum=25.0, step=0.1, value=1.0,
569
+ label="CFG scale",
570
+ )
571
+
572
+ with gr.Accordion("Sampler params", open=False):
573
+ with gr.Row():
574
+ adv_seed = gr.Number(
575
+ label="Seed (set to -1 for random seed)",
576
+ value=-1, precision=0,
577
+ )
578
+ adv_sampler = gr.Dropdown(
579
+ SAMPLERS, label="Sampler type", value="pingpong",
580
+ )
581
+ adv_sigma_max = gr.Slider(
582
+ minimum=0.0, maximum=1.0, step=0.01, value=1.0,
583
+ label="Sigma max",
584
+ )
585
+ with gr.Row():
586
+ adv_apg = gr.Slider(
587
+ minimum=0.0, maximum=1.0, step=0.1, value=1.0,
588
+ label="APG scale", info="1.0=full APG, 0.0=vanilla CFG",
589
+ )
590
+ adv_dur_padding = gr.Slider(
591
+ minimum=0.0, maximum=30.0, step=0.5, value=6.0,
592
+ label="Duration padding (sec)",
593
+ )
594
+
595
+ with gr.Accordion("Output params", open=False):
596
+ with gr.Row():
597
+ adv_preview_every = gr.Slider(
598
+ minimum=0, maximum=100, step=1, value=0,
599
+ label="Spec preview every N steps (0 = off)",
600
+ )
601
+ adv_cut_to_total = gr.Checkbox(
602
+ label="Cut to seconds total", value=True,
603
+ )
604
+
605
+ with gr.Accordion("Init audio", open=False):
606
+ adv_init_audio = gr.Audio(
607
+ label="Init audio",
608
+ type="numpy",
609
+ )
610
+ adv_init_noise = gr.Slider(
611
+ minimum=0.01, maximum=1.0, step=0.01, value=0.9,
612
+ label="Init noise level",
613
+ )
614
+
615
+ with gr.Accordion("Inpainting", open=False):
616
+ adv_inpaint_audio = gr.Audio(
617
+ label="Inpaint audio",
618
+ type="numpy",
619
+ )
620
+ adv_mask_start = gr.Slider(
621
+ minimum=0.0,
622
+ maximum=float(LOADED[VARIANTS[0].key].max_seconds),
623
+ step=0.1, value=0.0, label="Mask start (sec)",
624
+ )
625
+ adv_mask_end = gr.Slider(
626
+ minimum=0.0,
627
+ maximum=float(LOADED[VARIANTS[0].key].max_seconds),
628
+ step=0.1, value=0.0, label="Mask end (sec)",
629
+ )
630
+
631
+ with gr.Column():
632
+ adv_audio_out = gr.Audio(
633
+ label="Output audio", type="filepath", autoplay=False,
634
+ sources=[],
635
+ )
636
+ adv_spec_gallery = gr.Gallery(
637
+ label="Output spectrogram", show_label=True, columns=2,
638
+ )
639
+ send_to_init_btn = gr.Button("Send to init audio")
640
+ send_to_inpaint_btn = gr.Button("Send to inpaint audio")
641
+
642
+ send_to_init_btn.click(
643
+ fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_init_audio]
644
+ )
645
+ send_to_inpaint_btn.click(
646
+ fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_inpaint_audio]
647
+ )
648
+
649
+ # Keep the inpaint mask bounded by the current duration.
650
+ def _update_mask_max(seconds_total):
651
+ m = max(float(seconds_total), 1.0)
652
+ return (
653
+ gr.update(maximum=m),
654
+ gr.update(maximum=m, value=m),
655
+ )
656
+ adv_seconds_total.change(
657
+ _update_mask_max,
658
+ inputs=[adv_seconds_total],
659
+ outputs=[adv_mask_start, adv_mask_end],
660
+ )
661
+
662
+ adv_variant.change(
663
+ fn=_variant_change_advanced,
664
+ inputs=[adv_variant],
665
+ outputs=[adv_seconds_total, adv_prompt, adv_mask_start, adv_mask_end],
666
+ )
667
+
668
+ adv_generate.click(
669
+ fn=infer_advanced,
670
+ inputs=[
671
+ adv_variant,
672
+ adv_prompt,
673
+ adv_negative,
674
+ adv_seconds_total,
675
+ adv_steps,
676
+ adv_cfg,
677
+ adv_sampler,
678
+ adv_seed,
679
+ adv_sigma_max,
680
+ adv_apg,
681
+ adv_dur_padding,
682
+ adv_cut_to_total,
683
+ adv_init_audio,
684
+ adv_init_noise,
685
+ adv_inpaint_audio,
686
+ adv_mask_start,
687
+ adv_mask_end,
688
+ adv_preview_every,
689
+ ],
690
+ outputs=[adv_audio_out, adv_spec_gallery],
691
+ )
692
 
693
 
694
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -2,6 +2,8 @@
2
  einops
3
  soundfile
4
  numpy<2
 
 
5
  pytorch_lightning
6
  torch
7
  torchaudio
 
2
  einops
3
  soundfile
4
  numpy<2
5
+ matplotlib
6
+ Pillow
7
  pytorch_lightning
8
  torch
9
  torchaudio