Spaces:
Running on Zero
Running on Zero
"""
stable-audio — command-line interface for Stable Audio 3.

Basic usage::

    stable-audio --model small-music -p "lo-fi hip hop beat, 90 BPM" --duration 30 -o beat.wav
"""
| import argparse | |
| import os | |
| import torch | |
| import torchaudio | |
| from stable_audio_3 import StableAudioModel | |
def _save_output(audio: torch.Tensor, sample_rate: int, output: str, batch_size: int):
    """Save generated audio tensor(s) to disk.

    Args:
        audio: Generated audio; indexed along its first dimension, one entry
            per batch item.
        sample_rate: Sample rate handed to ``torchaudio.save``.
        output: Destination path. ``.wav`` is used when it has no extension.
        batch_size: Number of leading entries of ``audio`` to write.

    With ``batch_size > 1`` each item is written to ``<base>_<i><ext>``;
    otherwise a single file ``<base><ext>`` is written. Each written path is
    printed to stdout.
    """
    base, ext = os.path.splitext(output)
    if not ext:
        ext = ".wav"

    # Create the destination directory up front so a long generation run is
    # not thrown away just because "-o some/new/dir/out.wav" was given.
    parent = os.path.dirname(base)
    if parent:
        os.makedirs(parent, exist_ok=True)

    for i in range(batch_size):
        path = f"{base}_{i}{ext}" if batch_size > 1 else f"{base}{ext}"
        clip = audio[i].cpu()
        # torchaudio.save does not accept half-precision tensors.
        # NOTE(review): whether generate() can return fp16/bf16 when the model
        # runs in half precision is not visible here; the cast is a no-op for
        # every other dtype.
        if clip.dtype in (torch.float16, torch.bfloat16):
            clip = clip.to(torch.float32)
        torchaudio.save(path, clip, sample_rate)
        print(f"Saved: {path}")
def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``stable-audio`` command."""
    parser = argparse.ArgumentParser(
        prog="stable-audio",
        description="Stable Audio 3 — CLI for text-to-audio, audio-to-audio, and inpainting",
    )

    # Model
    parser.add_argument(
        "--model",
        default="medium",
        choices=[
            "medium",
            "small-music",
            "small-sfx",
            "medium-base",
            "small-music-base",
            "small-sfx-base",
        ],
        help="Model to load (default: medium)",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device: cuda / mps / cpu (auto-detected if omitted)",
    )
    parser.add_argument(
        "--no-half", action="store_true", help="Disable half-precision (fp16) on CUDA"
    )

    # Generation
    parser.add_argument(
        "-p",
        "--prompt",
        required=True,
        nargs="+",
        help="Text prompt(s). Pass multiple for per-batch prompts",
    )
    parser.add_argument(
        "--negative-prompt", nargs="+", default=None, help="Negative prompt(s)"
    )
    parser.add_argument(
        "--duration",
        type=float,
        nargs="+",
        default=[120.0],
        help="Duration in seconds (default: 120). Pass multiple for per-batch durations",
    )
    parser.add_argument(
        "--steps", type=int, default=8, help="Diffusion steps (default: 8)"
    )
    parser.add_argument(
        "--cfg-scale",
        type=float,
        default=1.0,
        help="CFG scale (default: 1.0; try 7.0 for base models)",
    )
    parser.add_argument(
        "--seed", type=int, default=-1, help="Random seed (-1 = random, default: -1)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Batch size (default: inferred from number of prompts, or 1)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="output.wav",
        help="Output file path (default: output.wav)",
    )

    # Audio-to-Audio
    parser.add_argument(
        "--init-audio",
        default=None,
        metavar="PATH",
        help="Source audio file for audio-to-audio generation",
    )
    parser.add_argument(
        "--init-noise-level",
        type=float,
        default=0.9,
        help="Noise level for audio-to-audio (0.0–1.0, default: 0.9)",
    )

    # Inpainting / Continuation
    parser.add_argument(
        "--inpaint-audio",
        default=None,
        metavar="PATH",
        help="Source audio file for inpainting or continuation",
    )
    parser.add_argument(
        "--inpaint-start",
        type=float,
        action="append",
        dest="inpaint_starts",
        metavar="SECONDS",
        help="Start of inpaint region in seconds. Repeat for multiple regions.",
    )
    parser.add_argument(
        "--inpaint-end",
        type=float,
        action="append",
        dest="inpaint_ends",
        metavar="SECONDS",
        help="End of inpaint region in seconds. Repeat for multiple regions.",
    )

    # Chunked decode: both flags default to None so "neither given" can be
    # told apart from an explicit choice.
    decode_group = parser.add_mutually_exclusive_group()
    decode_group.add_argument(
        "--chunked-decode",
        action="store_true",
        default=None,
        help="Force chunked decoding on",
    )
    decode_group.add_argument(
        "--no-chunked-decode",
        action="store_true",
        default=None,
        help="Force chunked decoding off",
    )

    # LoRA
    parser.add_argument(
        "--lora-ckpt-path",
        action="append",
        dest="loras",
        metavar="PATH",
        help="LoRA checkpoint path. Repeat to stack multiple LoRAs.",
    )
    parser.add_argument(
        "--lora-strength",
        type=float,
        default=None,
        help="LoRA strength (applied to all LoRAs)",
    )
    parser.add_argument(
        "--lora-index",
        type=int,
        default=None,
        help="Target a specific LoRA index when setting strength",
    )
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> int:
    """Reject inconsistent or out-of-range options and return the batch size.

    Every failure goes through ``parser.error``, which prints the usage
    message and exits with status 2. All checks run before any model or audio
    file is touched, so mistakes surface immediately.
    """
    # --- Validate inpaint args ---
    if (args.inpaint_starts is None) != (args.inpaint_ends is None):
        parser.error("--inpaint-start and --inpaint-end must both be provided together")
    if args.inpaint_starts and len(args.inpaint_starts) != len(args.inpaint_ends):
        parser.error(
            "--inpaint-start and --inpaint-end must be specified the same number of times"
        )
    if args.inpaint_starts and not args.inpaint_audio:
        parser.error("--inpaint-start/--inpaint-end require --inpaint-audio")
    if args.inpaint_audio and not args.inpaint_starts:
        parser.error("--inpaint-audio requires --inpaint-start and --inpaint-end")
    for start, end in zip(args.inpaint_starts or [], args.inpaint_ends or []):
        if start < 0 or end <= start:
            parser.error(
                f"--inpaint-start ({start}) must be non-negative and less than "
                f"--inpaint-end ({end})"
            )

    # --- Validate numeric ranges ---
    if args.steps < 1:
        parser.error(f"--steps must be at least 1 (got {args.steps})")
    if any(seconds <= 0 for seconds in args.duration):
        parser.error("--duration values must be greater than 0")
    if not 0.0 <= args.init_noise_level <= 1.0:
        parser.error(
            f"--init-noise-level must be between 0.0 and 1.0 (got {args.init_noise_level})"
        )
    if args.batch_size is not None and args.batch_size < 1:
        parser.error(f"--batch-size must be at least 1 (got {args.batch_size})")
    # --lora-index is only consulted when a strength is being set; without
    # one it would be silently ignored.
    if args.lora_index is not None and args.lora_strength is None:
        parser.error("--lora-index has no effect without --lora-strength")

    # --- Resolve batch size ---
    n_prompts = len(args.prompt)
    if args.batch_size is None:
        batch_size = n_prompts
    elif n_prompts > 1 and args.batch_size != n_prompts:
        parser.error(
            f"--batch-size {args.batch_size} does not match the number of prompts "
            f"({n_prompts}); omit --batch-size to have it inferred automatically"
        )
    else:
        batch_size = args.batch_size

    # --- Validate list-flag lengths against batch size ---
    if (
        args.negative_prompt
        and len(args.negative_prompt) > 1
        and len(args.negative_prompt) != batch_size
    ):
        parser.error(
            f"Got {len(args.negative_prompt)} --negative-prompt values but batch size is {batch_size}"
        )
    if len(args.duration) > 1 and len(args.duration) != batch_size:
        parser.error(
            f"Got {len(args.duration)} --duration values but batch size is {batch_size}"
        )
    return batch_size


def _scalar_or_list(values):
    """Return ``None`` for no values, the sole item for one, else the list."""
    if not values:
        return None
    return values[0] if len(values) == 1 else values


def _load_audio(path):
    """Load an audio file and return it as ``(sample_rate, waveform)``."""
    # torchaudio.load returns (waveform, sample_rate); model.generate expects
    # (sample_rate, waveform)
    waveform, sample_rate = torchaudio.load(path)
    return sample_rate, waveform


def main():
    """Entry point of the ``stable-audio`` command-line tool.

    Parses and validates the command line, loads any input audio, loads the
    selected model (plus optional LoRAs), runs generation and writes the
    result via ``_save_output``. Invalid options terminate the process through
    ``parser.error`` (exit status 2).
    """
    parser = _build_parser()
    args = parser.parse_args()
    batch_size = _validate_args(parser, args)

    # --- Build scalar / list args ---
    # A single value is passed as a scalar, several values as a list.
    prompt = _scalar_or_list(args.prompt)
    negative_prompt = _scalar_or_list(args.negative_prompt)
    duration = _scalar_or_list(args.duration)
    inpaint_start = _scalar_or_list(args.inpaint_starts)
    inpaint_end = _scalar_or_list(args.inpaint_ends)

    # --- chunked_decode flag ---
    # None means "let the model decide"; True / False force the behavior.
    chunked_decode = None
    if args.chunked_decode:
        chunked_decode = True
    elif args.no_chunked_decode:
        chunked_decode = False

    # --- Load audio inputs ---
    # Done before the (slow) model load so an unreadable file fails fast.
    init_audio = _load_audio(args.init_audio) if args.init_audio else None
    inpaint_audio = _load_audio(args.inpaint_audio) if args.inpaint_audio else None

    # --- Load model ---
    print(f"Loading model '{args.model}'…")
    model = StableAudioModel.from_pretrained(
        args.model, device=args.device, model_half=not args.no_half
    )

    # --- LoRA ---
    if args.loras:
        print(f"Loading LoRA(s): {args.loras}")
        model.load_lora(args.loras)
    if args.lora_strength is not None:
        model.set_lora_strength(args.lora_strength, lora_index=args.lora_index)

    # --- Generate ---
    print("Generating…")
    audio = model.generate(
        prompt=prompt,
        negative_prompt=negative_prompt,
        duration=duration,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        seed=args.seed,
        batch_size=batch_size,
        init_audio=init_audio,
        init_noise_level=args.init_noise_level,
        inpaint_audio=inpaint_audio,
        inpaint_mask_start_seconds=inpaint_start,
        inpaint_mask_end_seconds=inpaint_end,
        chunked_decode=chunked_decode,
    )
    _save_output(audio, model.model.sample_rate, args.output, batch_size)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()