import gc
import math
import os
import re
import subprocess
import threading
import time

import gradio as gr
import numpy as np
import torch
import torchaudio
from einops import rearrange

from stable_audio_3.interface.aeiou import audio_spectrogram_image
from stable_audio_3.inference.distribution_shift import (
    LogSNRShift,
    FluxDistributionShift,
    DistributionShift,
    IdentityDistributionShift,
)
from stable_audio_3.models.lora import has_lora
from stable_audio_3.interface.reprompt import (
    reprompt as _reprompt_fn,
    get_model as _reprompt_get_model,
    is_model_cached as _reprompt_is_model_cached,
)

# Module-level state; sample_size, sample_rate and stable_audio_3_model are
# overwritten by create_diffusion_cond_ui(), n_loras by create_sampling_ui().
stable_audio_3_model = None
sample_size = 5324800
sample_rate = 44100
n_loras = 0

# Matches a trailing " Length: <N> seconds." suffix in a prompt.
_LENGTH_EXTRACT_RE = re.compile(r' Length: (\d+) seconds\.?\s*$')

# ffmpeg encoder arguments for every non-wav entry of the "File format" dropdown.
# note: the aac_he_v2 formats need ffmpeg compiled with --enable-libfdk_aac
_FFMPEG_CODEC_ARGS = {
    "m4a aac_he_v2 32k": ["-c:a", "libfdk_aac", "-profile:a", "aac_he_v2", "-b:a", "32k"],
    "m4a aac_he_v2 64k": ["-c:a", "libfdk_aac", "-profile:a", "aac_he_v2", "-b:a", "64k"],
    "flac": [],
    "mp3 320k": ["-b:a", "320k"],
    "mp3 128k": ["-b:a", "128k"],
    "mp3 v0": ["-q:a", "0"],
}


# when using a prompt in a filename
def condense_prompt(prompt):
    """Return `prompt` reduced to a string that is safe to use in a filename.

    The characters \\ / : * ? " < > | are replaced with hyphens and the result
    is truncated to 150 characters. An empty result becomes "_".
    """
    pattern = r'[\\/:*?"<>|]'
    # Replace special characters with hyphens
    prompt = re.sub(pattern, '-', prompt)
    # set a character limit
    prompt = prompt[:150]
    # zero length prompts may lead to filenames (ie ".wav") which seem cause problems with gradio
    if len(prompt) == 0:
        prompt = "_"
    return prompt


def _build_ffmpeg_command(file_format, input_path, output_path):
    """Return the ffmpeg argument list converting `input_path` to `file_format`.

    Returns None when no conversion is needed (wav, empty or unknown format).
    The command is an argument list, never a shell string: the filenames are
    derived from the user's prompt, and condense_prompt() does not strip shell
    metacharacters such as `$(...)` or backticks.
    """
    codec_args = _FFMPEG_CODEC_ARGS.get(file_format)
    if codec_args is None:
        return None

    def as_path(name):
        # A leading "./" keeps a filename that starts with "-" from being
        # parsed by ffmpeg as an option.
        return os.path.join(os.curdir, name)

    return [
        "ffmpeg",
        "-loglevel", "error",  # make output less verbose in the cmd window
        "-i", as_path(input_path),
        *codec_args,
        "-y", as_path(output_path),
    ]


def generate_cond(
        prompt,
        negative_prompt=None,
        seconds_total=30,
        cfg_scale=6.0,
        steps=250,
        preview_every=None,
        seed=-1,
        sampler_type="dpmpp-3m-sde",
        sigma_max=1000,
        cfg_interval_min=0.0,
        cfg_interval_max=1.0,
        cfg_rescale=0.0,
        cfg_norm_threshold=0.0,
        apg_scale=1.0,
        file_format="wav",
        file_naming="verbose",
        cut_to_seconds_total=False,
        init_audio=None,
        init_noise_level=1.0,
        mask_maskstart=None,
        mask_maskend=None,
        inpaint_audio=None,
        init_audio_type="Init audio",
        inversion_steps=100,
        inversion_gamma=0.3,
        inversion_unconditional=False,
        duration_padding_sec=6.0,
        batch_size=1,
        dist_shift=None,
        *lora_args
    ):
    """Generate audio with the module-level model and write it to disk.

    Most arguments are forwarded unchanged to `stable_audio_3_model.generate`.
    Arguments handled in this function:

    preview_every: emit a spectrogram preview every N sampler steps; 0 or None
        disables previews.
    file_format: an entry of the "File format" dropdown; anything other than
        wav is produced by converting the wav file with ffmpeg.
    file_naming: "verbose", "prompt", or anything else for "output".
    cut_to_seconds_total: trim the result to `seconds_total` seconds.
    mask_maskstart, mask_maskend: only sent to the model when `inpaint_audio`
        is given.
    init_audio_type: when "RF-Inversion", the inversion_* values are collected
        into `inversion_params`. NOTE(review): `inversion_params` is built but
        is not passed to `generate` in this function — confirm whether that is
        intended.
    *lora_args: 4 values per loaded LoRA, in order: strength, interval_min,
        interval_max, layer_filter.

    Returns a tuple `(output_filename, images)` where `images` is the final
    spectrogram followed by any preview spectrograms. The written files are
    scheduled for deletion 30 seconds after the call.

    Raises subprocess.CalledProcessError if the ffmpeg conversion fails.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    global preview_images
    preview_images = []
    if preview_every == 0:
        preview_every = None

    # Parse per-LoRA controls from trailing args
    # Each LoRA has 4 controls: strength, interval_min, interval_max, layer_filter
    lora_configs = None
    if n_loras > 0 and len(lora_args) >= n_loras * 4:
        lora_configs = []
        for i in range(n_loras):
            off = i * 4
            strength = lora_args[off]
            interval_min = lora_args[off + 1]
            interval_max = lora_args[off + 2]
            layer_filter = lora_args[off + 3]
            stable_audio_3_model.set_lora_strength(strength, lora_index=i)
            lora_configs.append({
                "lora_index": i,
                "interval": (interval_min, interval_max),
                "layer_filter": layer_filter,
            })

    input_sample_size = sample_size

    def progress_callback(callback_info):
        global preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        t = callback_info["t"]
        sigma = callback_info["sigma"]

        # Extract scalar from tensor if needed (samplers pass tensors to avoid GPU sync)
        if isinstance(t, torch.Tensor):
            t = t[0].item() if t.dim() > 0 else t.item()
        if isinstance(sigma, torch.Tensor):
            sigma = sigma[0].item() if sigma.dim() > 0 else sigma.item()

        # The formula is only defined for 0 < sigma <= 1. Other values (sigma
        # of 0, or sigma above 1 as used by the non-RF samplers) must not
        # abort sampling just because the preview caption cannot be computed.
        try:
            log_snr = math.log(((1 - sigma) / sigma) + 1e-6)
        except (ValueError, ZeroDivisionError):
            log_snr = float("nan")

        if (current_step - 1) % preview_every == 0:
            if stable_audio_3_model.model.pretransform is not None:
                denoised = stable_audio_3_model.model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f} logSNR={log_snr:.3f}"))

    if init_audio_type == "RF-Inversion":
        inversion_params = {
            "inversion_steps": inversion_steps,
            "inversion_gamma": inversion_gamma,
            "inversion_unconditional": inversion_unconditional,
            "inversion_cfg_scale": 1.0,
            "inversion_sigma_max": 1.0
        }
    else:
        inversion_params = None

    generate_args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "duration": seconds_total,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "cfg_interval": (cfg_interval_min, cfg_interval_max),
        "lora_configs": lora_configs,
        "batch_size": batch_size,
        "sample_size": input_sample_size,
        "seed": seed,
        "sampler_type": sampler_type,
        "sigma_max": sigma_max,
        "init_audio": init_audio,
        "init_noise_level": init_noise_level,
        "callback": progress_callback if preview_every is not None else None,
        "scale_phi": cfg_rescale,
        "cfg_norm_threshold": cfg_norm_threshold,
        "apg_scale": apg_scale,
        "duration_padding_sec": duration_padding_sec,
        "dist_shift": dist_shift,
    }

    # If inpainting, send mask args
    # This will definitely change in the future
    if inpaint_audio is not None:
        generate_args.update({
            "inpaint_audio": inpaint_audio,
            "inpaint_mask_start_seconds": mask_maskstart,
            "inpaint_mask_end_seconds": mask_maskend,
        })

    audio = stable_audio_3_model.generate(**generate_args)

    # Filenaming convention
    prompt_condensed = condense_prompt(prompt)
    if file_naming == "verbose":
        basename = prompt_condensed
        if negative_prompt:
            basename += ".neg-%s" % condense_prompt(negative_prompt)
        basename += ".cfg%s" % (cfg_scale)
        if sigma_max not in [1.0, 100.0]:
            # this is a common parameter to tweak, if it's not a default value, put it in the verbose filename
            basename += ".smx%s" % sigma_max
        basename += ".%s" % seed
    elif file_naming == "prompt":
        basename = prompt_condensed
    else:
        # simple e.g. "output.wav"
        basename = "output"

    if file_format:
        filename_extension = file_format.split(" ")[0].lower()
    else:
        filename_extension = "wav"

    output_filename = "%s.%s" % (basename, filename_extension)
    output_wav = "%s.wav" % basename

    # Cut the extra silence off the end, if the user requested a smaller seconds_total
    if cut_to_seconds_total:
        # int(): seconds_total may arrive as a float, which is not a valid slice index
        audio = audio[:, :, :int(seconds_total * sample_rate)]

    # Encode the audio to WAV format
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # save as wav file
    torchaudio.save(output_wav, audio, sample_rate)

    # If file_format is other than wav, convert to other file format
    cmd = _build_ffmpeg_command(file_format, output_wav, output_filename)
    if cmd:
        subprocess.run(cmd, check=True)

    # Let's look at a nice spectrogram too
    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    # Asynchronously delete the files after returning the output file, so as to prevent clutter in the directory
    delete_files_async([output_wav, output_filename], 30)

    return (output_filename, [audio_spectrogram, *preview_images])


# Asynchronously delete the given list of filenames after delay seconds.
# Sets up thread that sleeps for delay then deletes.
def delete_files_async(filenames, delay):
    """Delete `filenames` from a background thread after `delay` seconds.

    Returns immediately. Files that no longer exist when the thread wakes up
    (including duplicate entries in `filenames`) are skipped.
    """
    def delete_files_after_delay(filenames, delay):
        time.sleep(delay)  # Wait for the specified delay
        for filename in filenames:
            # Attempt the removal instead of checking first, so a file that
            # vanishes in between does not kill the thread.
            try:
                os.remove(filename)  # Delete the file
            except FileNotFoundError:
                pass

    threading.Thread(target=delete_files_after_delay, args=(filenames, delay)).start()


def create_sampling_ui(stable_audio_3_model, default_prompt=None):
    """Build the generation controls and wire them to `generate_cond`.

    Must be called inside an active `gr.Blocks` context. Reads the model's
    `diffusion_objective`, `sampling_dist_shift` and `lora_names` to choose
    defaults, and sets the module-level `n_loras`.

    stable_audio_3_model: wrapper object exposing `.model`.
    default_prompt: initial text of the prompt box; None means empty.
    """
    global n_loras

    diffusion_objective = stable_audio_3_model.model.diffusion_objective
    is_rf = diffusion_objective == "rectified_flow"
    is_rf_denoiser = diffusion_objective == "rf_denoiser"  # includes ARC models

    # Extract default dist_shift params from model's sampling_dist_shift
    default_sampling_dist_shift = getattr(stable_audio_3_model.model, 'sampling_dist_shift', None)
    default_dist_shift_type = "LogSNR"
    default_logsnr_params = {"anchor_length": 2000, "anchor_logsnr": -6.2, "rate": 0.0, "logsnr_end": 2.0}
    default_flux_params = {"min_length": 256, "max_length": 4096, "alpha_min": 6.93, "alpha_max": 6.93}
    default_full_params = {"base_shift": 0.5, "max_shift": 1.15, "min_length": 256, "max_length": 4096}

    if isinstance(default_sampling_dist_shift, LogSNRShift):
        default_dist_shift_type = "LogSNR"
        default_logsnr_params = {
            "anchor_length": getattr(default_sampling_dist_shift, 'anchor_length', 2000),
            "anchor_logsnr": getattr(default_sampling_dist_shift, 'anchor_logsnr', -6.2),
            "rate": getattr(default_sampling_dist_shift, 'rate', 0.0),
            "logsnr_end": getattr(default_sampling_dist_shift, 'logsnr_end', 2.0),
        }
    elif isinstance(default_sampling_dist_shift, FluxDistributionShift):
        default_dist_shift_type = "Flux"
        default_flux_params = {
            "min_length": default_sampling_dist_shift.min_length,
            "max_length": default_sampling_dist_shift.max_length,
            "alpha_min": default_sampling_dist_shift.alpha_min,
            "alpha_max": default_sampling_dist_shift.alpha_max,
        }
    elif isinstance(default_sampling_dist_shift, DistributionShift):
        default_dist_shift_type = "Full"
        default_full_params = {
            "base_shift": default_sampling_dist_shift.base_shift,
            "max_shift": default_sampling_dist_shift.max_shift,
            "min_length": default_sampling_dist_shift.min_length,
            "max_length": default_sampling_dist_shift.max_length,
        }
    elif default_sampling_dist_shift is None:
        default_dist_shift_type = "None"

    has_seconds_total = True

    use_lora = has_lora(stable_audio_3_model.model)
    lora_names = getattr(stable_audio_3_model.model, 'lora_names', [])
    n_loras = len(lora_names)

    if default_prompt is None:
        default_prompt = ""

    _reprompt_model_id = "Qwen/Qwen3.5-2B"
    _reprompt_cached = _reprompt_is_model_cached(_reprompt_model_id)

    with gr.Row():
        with gr.Column(scale=6):
            prompt = gr.Textbox(show_label=False, placeholder="Prompt", value=default_prompt)
            negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
        prompt_assistant_button = gr.Button(
            "Prompt Assistant" if _reprompt_cached else "Download Prompt Assistant (~4.2 GB)",
            scale=1
        )
        generate_button = gr.Button("Generate", variant='primary', scale=1)

    with gr.Row(equal_height=False):
        with gr.Column():
            with gr.Row(visible=True):
                # Timing controls
                seconds_total_slider = gr.Slider(minimum=0, maximum=sample_size//sample_rate, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)

            with gr.Row():
                # Steps slider
                if is_rf:
                    default_steps = 50
                elif is_rf_denoiser:
                    default_steps = 8
                else:
                    # Other objectives are supported by the sampler list below,
                    # so they need a default too; without this branch
                    # default_steps was unbound for them.
                    default_steps = 100
                steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=default_steps, label="Steps")

                # CFG scale
                default_cfg_scale = 1.0 if is_rf_denoiser else 7.0
                cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=default_cfg_scale, label="CFG scale")

            # Per-LoRA controls (dynamic based on number of loaded LoRAs)
            lora_ui_inputs = []
            if use_lora and lora_names:
                for i, lora_name in enumerate(lora_names):
                    with gr.Accordion("LoRA {}: {}".format(i + 1, lora_name), open=(i == 0)):
                        with gr.Row():
                            strength = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.0, label="strength")
                        with gr.Row():
                            int_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Interval min")
                            int_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Interval max")
                        lyr_filt = gr.Textbox(label="Layer filter", placeholder="")
                    # Order must match the 4-per-LoRA unpacking in generate_cond
                    lora_ui_inputs.extend([strength, int_min, int_max, lyr_filt])

            with gr.Accordion("Sampler params", open=False):
                with gr.Row():
                    # Seed
                    seed_textbox = gr.Number(label="Seed (set to -1 for random seed)", value=-1, precision=0)
                    cfg_interval_min_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG interval min")
                    cfg_interval_max_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=1.0, label="CFG interval max")
                with gr.Row():
                    cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount")
                    cfg_norm_threshold = gr.Slider(minimum=0.0, maximum=100, step=0.1, value=0.0, label="CFG norm threshold")
                    apg_scale_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="APG scale", info="1.0=full APG, 0.0=vanilla CFG")
                with gr.Row():
                    # Sampler params
                    if is_rf:
                        sampler_types = ["euler", "rk4", "dpmpp"]
                        default_sampler_type = "euler"
                        sigma_max_max = 1.0
                        sigma_max_default = 1.0
                    elif is_rf_denoiser:
                        sampler_types = ["pingpong"]
                        default_sampler_type = "pingpong"
                        sigma_max_max = 1.0
                        sigma_max_default = 1.0
                    else:
                        sampler_types = ["dpmpp-2m-sde", "dpmpp-3m-sde", "dpmpp-2m", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-adaptive", "k-dpm-fast", "v-ddim", "v-ddim-cfgpp"]
                        default_sampler_type = "dpmpp-3m-sde"
                        sigma_max_max = 1000.0
                        sigma_max_default = 100.0

                    sampler_type_dropdown = gr.Dropdown(sampler_types, label="Sampler type", value=default_sampler_type)
                    sigma_max_slider = gr.Slider(minimum=0.0, maximum=sigma_max_max, step=0.1, value=sigma_max_default, label="Sigma max", visible=True)
                with gr.Row():
                    duration_padding_slider = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=6.0, label="Duration padding (sec)")

                def build_dist_shift(shift_type, p1, p2, p3, p4):
                    """Build dist_shift from type + 4 params (meaning depends on type)."""
                    if shift_type == "LogSNR":
                        return LogSNRShift(anchor_length=int(p1), anchor_logsnr=p2, rate=p3, logsnr_end=p4)
                    elif shift_type == "Flux":
                        return FluxDistributionShift(min_length=int(p1), max_length=int(p2), alpha_min=p3, alpha_max=p4)
                    elif shift_type == "Full":
                        return DistributionShift(base_shift=p1, max_shift=p2, min_length=int(p3), max_length=int(p4))
                    return IdentityDistributionShift()  # "None" = no shift

                dist_shift_state = gr.State(value=default_sampling_dist_shift)

                with gr.Row(visible=is_rf or is_rf_denoiser):
                    dist_shift_type_dropdown = gr.Dropdown(
                        ["LogSNR", "Flux", "Full", "None"],
                        label="Sampling schedule shift",
                        value=default_dist_shift_type,
                        info="Distribution shift applied to sampling timesteps"
                    )
                with gr.Row(visible=(is_rf or is_rf_denoiser) and default_dist_shift_type == "LogSNR") as logsnr_params_row:
                    logsnr_anchor_length_slider = gr.Slider(minimum=100, maximum=10000, step=100, value=default_logsnr_params["anchor_length"], label="Anchor length")
                    logsnr_anchor_logsnr_slider = gr.Slider(minimum=-12.0, maximum=0.0, step=0.1, value=default_logsnr_params["anchor_logsnr"], label="Anchor log-SNR")
                    logsnr_rate_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=default_logsnr_params["rate"], label="Rate")
                    logsnr_end_slider = gr.Slider(minimum=-2.0, maximum=6.0, step=0.1, value=default_logsnr_params["logsnr_end"], label="log-SNR end")
                with gr.Row(visible=(is_rf or is_rf_denoiser) and default_dist_shift_type == "Flux") as flux_params_row:
                    flux_min_length_slider = gr.Slider(minimum=1, maximum=10000, step=1, value=default_flux_params["min_length"], label="Min seq len")
                    flux_max_length_slider = gr.Slider(minimum=1, maximum=10000, step=1, value=default_flux_params["max_length"], label="Max seq len")
                    flux_alpha_min_slider = gr.Slider(minimum=0.1, maximum=20.0, step=0.1, value=default_flux_params["alpha_min"], label="Alpha min")
                    flux_alpha_max_slider = gr.Slider(minimum=0.1, maximum=20.0, step=0.1, value=default_flux_params["alpha_max"], label="Alpha max")
                with gr.Row(visible=(is_rf or is_rf_denoiser) and default_dist_shift_type == "Full") as full_params_row:
                    full_base_shift_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.05, value=default_full_params["base_shift"], label="Base shift")
                    full_max_shift_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.05, value=default_full_params["max_shift"], label="Max shift")
                    full_min_length_slider = gr.Slider(minimum=1, maximum=10000, step=1, value=default_full_params["min_length"], label="Min length")
                    full_max_length_slider = gr.Slider(minimum=1, maximum=10000, step=1, value=default_full_params["max_length"], label="Max length")

                # Per-type slider groups for wiring to state
                logsnr_sliders = [logsnr_anchor_length_slider, logsnr_anchor_logsnr_slider, logsnr_rate_slider, logsnr_end_slider]
                flux_sliders = [flux_min_length_slider, flux_max_length_slider, flux_alpha_min_slider, flux_alpha_max_slider]
                full_sliders = [full_base_shift_slider, full_max_shift_slider, full_min_length_slider, full_max_length_slider]
                all_dist_shift_inputs = [dist_shift_type_dropdown] + logsnr_sliders + flux_sliders + full_sliders

                def update_dist_shift_state(shift_type, *params):
                    """Route the 4 relevant params to build_dist_shift based on type."""
                    type_to_slice = {"LogSNR": params[0:4], "Flux": params[4:8], "Full": params[8:12]}
                    p = type_to_slice.get(shift_type, (0, 0, 0, 0))
                    return (
                        build_dist_shift(shift_type, *p),
                        gr.update(visible=((is_rf or is_rf_denoiser) and (shift_type == "LogSNR"))),
                        gr.update(visible=((is_rf or is_rf_denoiser) and (shift_type == "Flux"))),
                        gr.update(visible=((is_rf or is_rf_denoiser) and (shift_type == "Full"))),
                    )

                for component in all_dist_shift_inputs:
                    component.change(
                        update_dist_shift_state,
                        inputs=all_dist_shift_inputs,
                        outputs=[dist_shift_state, logsnr_params_row, flux_params_row, full_params_row],
                    )

                # Hidden state for batch_size (no UI control, but needed for function signature)
                batch_size_state = gr.State(value=1)

            with gr.Accordion("Output params", open=False):
                # Output params
                with gr.Row():
                    file_format_dropdown = gr.Dropdown(["wav", "flac", "mp3 320k", "mp3 v0", "mp3 128k", "m4a aac_he_v2 64k", "m4a aac_he_v2 32k"], label="File format", value="wav")
                    file_naming_dropdown = gr.Dropdown(["verbose", "prompt", "output.wav"], label="File naming", value="verbose")
                    preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Spec Preview Every")
                    cut_to_seconds_total_checkbox = gr.Checkbox(label="Cut to seconds total", value=True)
                    # The elem_ids below are looked up by the page-load JavaScript
                    autoplay_checkbox = gr.Checkbox(label="Autoplay", value=False, elem_id="autoplay")
                    infinite_radio_checkbox = gr.Checkbox(label="Infinite Radio", value=False, elem_id="infinite-radio")
                    automatic_download_checkbox = gr.Checkbox(label="Auto Download", value=False, elem_id="automatic-download")

            # Default generation tab
            with gr.Accordion("Init audio", open=False):
                init_audio_input = gr.Audio(label="Init audio", waveform_options=gr.WaveformOptions(show_recording_waveform=False))
                min_noise_level = 0.01
                max_noise_level = 1.0
                default_noise_level = 0.9  # roughly halfway style transfer values

                if is_rf:
                    choices = ["Init audio", "RF-Inversion"]
                else:
                    choices = ["Init audio"]
                init_audio_type_radio = gr.Radio(label="Techniques", choices=choices, value=choices[0], visible=len(choices) > 1)

                with gr.Column(visible=True) as interface_a:
                    init_noise_level_slider = gr.Slider(minimum=min_noise_level, maximum=max_noise_level, step=0.01, value=default_noise_level, label="Init noise level")

                with gr.Column(visible=False) as interface_b:
                    inversion_steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Inversion Steps")
                    inversion_gamma_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Gamma", visible=True)
                    inversion_unconditional_checkbox = gr.Checkbox(label="Unconditional", value=False)
                    # NOTE(review): the literal was split across lines with its
                    # markup lost; the tags are reconstructed, the visible text
                    # is unchanged.
                    gr.HTML(
                        "<p>For reproduction, try empty prompt, cfg 1, gamma .3<br>"
                        "For prompt re-stylization, try cfg 1-7, gamma 0-.15, unconditional</p>"
                    )

                def init_audio_type_switch(choice):
                    return (
                        gr.update(visible=(choice == "Init audio")),
                        gr.update(visible=(choice == "RF-Inversion"))
                    )

                init_audio_type_radio.change(init_audio_type_switch, inputs=init_audio_type_radio, outputs=[interface_a, interface_b])

            with gr.Accordion("Inpainting", open=False):
                inpaint_audio_input = gr.Audio(label="Inpaint audio", waveform_options=gr.WaveformOptions(show_recording_waveform=False))
                mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=sample_size//sample_rate, step=0.1, value=0, label="Mask Start (sec)")
                mask_maskend_slider = gr.Slider(minimum=0.0, maximum=sample_size//sample_rate, step=0.1, value=sample_size//sample_rate, label="Mask End (sec)")

                # Update inpainting slider ranges when seconds_total changes.
                # Only seconds_total is an input — reading the mask sliders here would cause
                # validation errors since their values may exceed the about-to-be-reduced maximum.
                def update_inpaint_sliders(seconds_total):
                    max_val = max(seconds_total, 1)
                    return (
                        gr.update(maximum=max_val),
                        gr.update(maximum=max_val, value=max_val),
                    )

                seconds_total_slider.change(update_inpaint_sliders, inputs=[seconds_total_slider], outputs=[mask_maskstart_slider, mask_maskend_slider])

            # Positional order must match the generate_cond signature
            inputs = [
                prompt,
                negative_prompt,
                seconds_total_slider,
                cfg_scale_slider,
                steps_slider,
                preview_every_slider,
                seed_textbox,
                sampler_type_dropdown,
                sigma_max_slider,
                cfg_interval_min_slider,
                cfg_interval_max_slider,
                cfg_rescale_slider,
                cfg_norm_threshold,
                apg_scale_slider,
                file_format_dropdown,
                file_naming_dropdown,
                cut_to_seconds_total_checkbox,
                init_audio_input,
                init_noise_level_slider,
                mask_maskstart_slider,
                mask_maskend_slider,
                inpaint_audio_input,
                init_audio_type_radio,
                inversion_steps_slider,
                inversion_gamma_slider,
                inversion_unconditional_checkbox,
                duration_padding_slider,
                batch_size_state,
                dist_shift_state,
            ] + lora_ui_inputs

        with gr.Column():
            audio_output = gr.Audio(label="Output audio", interactive=False, waveform_options=gr.WaveformOptions(show_recording_waveform=False))
            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
            send_to_init_button = gr.Button("Send to init audio", scale=1)
            send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
            send_to_inpaint_button = gr.Button("Send to inpaint audio", scale=1)
            send_to_inpaint_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[inpaint_audio_input])

    generate_button.click(
        fn=generate_cond,
        inputs=inputs,
        outputs=[
            audio_output,
            audio_spectrogram_output
        ],
        api_name="generate"
    )

    def _prompt_assistant_or_download(text, progress=gr.Progress(track_tqdm=True)):
        """Download the reprompt model on first use, otherwise rewrite the prompt.

        Returns values for (prompt, seconds_total_slider, prompt_assistant_button).
        """
        if not _reprompt_is_model_cached(_reprompt_model_id):
            # First click only fetches the model and relabels the button
            _reprompt_get_model(_reprompt_model_id)
            return text, gr.update(), gr.update(value="Prompt Assistant")
        _, result, _ = _reprompt_fn(text, "Auto", "", _reprompt_model_id, 128, 1.11)
        m = _LENGTH_EXTRACT_RE.search(result)
        if m:
            # Move the trailing " Length: N seconds" suffix into the duration
            # slider, capped at the longest duration the slider allows.
            max_seconds = sample_size // sample_rate
            seconds = min(int(m.group(1)), max_seconds)
            result = result[:m.start()]
        else:
            seconds = gr.update()
        return result, seconds, gr.update()

    prompt_assistant_button.click(
        fn=_prompt_assistant_or_download,
        inputs=[prompt],
        outputs=[prompt, seconds_total_slider, prompt_assistant_button],
        concurrency_limit=1,
    )


def create_diffusion_cond_ui(model, gradio_title="", default_prompt=None):
    """Create the top-level `gr.Blocks` UI for `model` and return it.

    Sets the module-level `sample_size`, `sample_rate` and
    `stable_audio_3_model` from `model`. The page-load JavaScript and the theme
    are stored on the returned object as `_sao_js` and `_sao_theme`; this
    function does not apply them itself.

    model: object exposing `model_config["sample_size"]` and
        `model_config["sample_rate"]`.
    gradio_title: optional heading rendered above the tabs.
    default_prompt: passed through to create_sampling_ui.
    """
    global sample_size, sample_rate, stable_audio_3_model

    sample_size = model.model_config["sample_size"]
    sample_rate = model.model_config["sample_rate"]
    stable_audio_3_model = model

    # JavaScript to autoplay audio immediately after generation (if autoplay enabled)
    # The script contains `//` line comments, so its line breaks are significant.
    js = """function run_javascript_on_page_load(){
    const generateBtn = Array.from(document.querySelectorAll('button'))
        .find(btn => btn.innerText.trim() === 'Generate');

    function getAudioOutputPlayer () {
        return [...document.querySelectorAll('label')].find(label => label.textContent.trim() === 'Output audio')?.parentElement.querySelector('audio');
    }

    const infiniteRadio = document.querySelector('#infinite-radio input[type="checkbox"]');
    const autoplay = document.querySelector('#autoplay input[type="checkbox"]');
    const automaticDownload = document.querySelector('#automatic-download input[type="checkbox"]');

    let radioAutoStart = false;
    let listenersSetup = false;

    const setupListeners = () => {
        const audioEl = getAudioOutputPlayer();
        if (!audioEl) return;

        audioEl.addEventListener('loadedmetadata', () => {
            if(automaticDownload.checked){
                downloadAudio(audioEl);
            }
            if(autoplay.checked || radioAutoStart){
                audioEl.play();
                radioAutoStart = false;
            }
            if(infiniteRadio.checked){
                audioEl.addEventListener('timeupdate', function checkAudioEnd() {
                    // Can set window.headstart (seconds) in the dev console if you want to start generating before the song is over
                    let headstart = 1;
                    if(window.headstart) headstart = window.headstart;
                    if (audioEl.duration - audioEl.currentTime <= headstart) {
                        generateBtn.click();
                        radioAutoStart = true;
                        audioEl.removeEventListener('timeupdate', checkAudioEnd);
                    }
                });
            }
        });
        listenersSetup = true;
    };

    generateBtn.addEventListener('click', () => {
        if(listenersSetup) return;
        const interval = setInterval(() => {
            console.log("...")
            const audioEl = document.querySelector('audio');
            if (audioEl?.src && audioEl.src !== window.location.href) {
                setupListeners();
                clearInterval(interval);
            }
        }, 100);
    });

    // Respond to >> button on MacBookPro and on steering wheel during CarPlay
    if ('mediaSession' in navigator) {
        navigator.mediaSession.setActionHandler('nexttrack', () => generateBtn.click());
        navigator.mediaSession.setActionHandler('play', () => getAudioOutputPlayer()?.play());
        navigator.mediaSession.setActionHandler('pause', () => getAudioOutputPlayer()?.pause());
    }

    // Automatic Download
    function downloadAudio(audioEl) {
        const audioSrc = audioEl.src;
        const link = document.createElement('a');
        link.href = audioSrc;
        link.download = audioSrc.substring(audioSrc.lastIndexOf('/') + 1);
        document.body.appendChild(link);
        link.click();
        document.body.removeChild(link);
    }
}
"""

    with gr.Blocks() as ui:
        ui._sao_js = js
        ui._sao_theme = gr.themes.Base()
        if gradio_title:
            gr.Markdown("### %s" % gradio_title)
        with gr.Tab("Generation"):
            create_sampling_ui(model, default_prompt=default_prompt)

    return ui