import json import numpy as np import torch import typing as tp from torch.nn.functional import interpolate from stable_audio_3.inference.audio_utils import prepare_audio, numpy_audio_to_tensor from stable_audio_3.inference.sampling import sample_diffusion from stable_audio_3.loading_utils import load_autoencoder, load_diffusion_cond from stable_audio_3.model_configs import ae_models, all_models from stable_audio_3.models.lora import ( set_lora_strength as _set_lora_strength, load_and_apply_loras, ) class StableAudioModel: def __init__(self, model, model_config, device, model_half): self.model = model self.model_config = model_config self.device = device self.model_half = model_half self.same = self.model.pretransform self.dit = self.model.model torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cudnn.benchmark = False @staticmethod def from_pretrained(model_name_or_path, device=None, model_half=True): # Load the model and any necessary components here if device is None and torch.cuda.is_available(): device = "cuda" elif device is None and torch.backends.mps.is_available(): device = "mps" elif device is None: device = "cpu" if not torch.cuda.is_available(): if model_name_or_path in ("medium", "medium-base"): print( f"Warning: You are loading the {model_name_or_path} model without a GPU. This model is not designed to run on cpu" ) model_half = False if model_name_or_path not in all_models: raise ValueError( f"Unknown model '{model_name_or_path}'. Valid models: {list(all_models)}" ) model_cfg = all_models[model_name_or_path] local_config, local_ckpt = model_cfg.resolve() with open(local_config) as f: model_config = json.load(f) model = load_diffusion_cond( model_config, local_ckpt, device=device, model_half=model_half ) model.use_lora = False model.lora_names = [] return StableAudioModel(model, model_config, device, model_half) def load_lora(self, lora_ckpt_paths): """Load LoRA checkpoints onto the model after construction.""" model_type = self.model_config["model_type"] svd_bases_path = self.model_config.get("svd_bases_path") load_and_apply_loras( self.model, lora_ckpt_paths, model_type, svd_bases_path=svd_bases_path ) def set_lora_strength(self, strength: float, lora_index: int | None = None): _set_lora_strength(self.model.model, strength, lora_index=lora_index) _set_lora_strength(self.model.conditioner, strength, lora_index=lora_index) @torch.inference_mode() def generate( self, # Simple path: pass a prompt string and duration prompt: str | list = None, negative_prompt: str | list = None, duration: float | list = 120, # Generation parameters steps: int = 8, cfg_scale: float = 1.0, batch_size: int = 1, sample_size: int = 5292032, truncate_output_to_duration: bool = True, # Low-level path: pass pre-built conditioning dicts conditioning: tp.Optional[tp.List[dict]] = None, conditioning_tensors: tp.Optional[dict] = None, negative_conditioning: tp.Optional[tp.List[dict]] = None, negative_conditioning_tensors: tp.Optional[dict] = None, seed: int = -1, # Audio inputs init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, init_noise_level: float = 1.0, inpaint_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, inpaint_mask=None, inpaint_mask_start_seconds: tp.Optional[tp.Union[float, tp.List[float]]] = None, inpaint_mask_end_seconds: tp.Optional[tp.Union[float, tp.List[float]]] = None, # Schedule options duration_padding_sec: float = 6.0, apg_scale: float = 1.0, dist_shift=None, return_latents: bool = False, chunked_decode: tp.Optional[bool] = None, **sampler_kwargs, ) -> torch.Tensor: """ Generate audio. Simple path: model.generate(prompt="...", duration=30, steps=100) Low-level path (pre-built conditioning): model.generate(conditioning=[{"prompt": "...", "seconds_total": 30}], steps=100, ...) Args: prompt: The text prompt to condition on. Ignored if conditioning dicts are provided directly. negative_prompt: The negative text prompt for classifier-free guidance. Ignored if negative_conditioning dicts are provided directly. duration: The duration of the generated audio in seconds. Only used if conditioning dicts with "seconds_total" are not provided. steps: The number of diffusion steps to use. cfg_scale: Classifier-free guidance scale batch_size: The batch size to use for generation. sample_size: The length of the audio to generate, in samples. truncate_output_to_duration: If True, truncate the output audio to the specified duration. conditioning: A dictionary of conditioning parameters to use for generation. conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. negative_conditioning: A dictionary of negative conditioning parameters for classifier-free guidance. negative_conditioning_tensors: A dictionary of precomputed negative conditioning tensors for classifier-free guidance seed: The random seed to use for generation, or -1 to use a random seed. init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. init_noise_level: The noise level to use when generating from an initial audio sample. inpaint_audio: A tuple of (sample_rate, audio) to use as the source audio for inpainting. The inpaint region will be determined by the inpaint_mask or inpaint_mask_start_seconds/inpaint_mask_end_seconds parameters. inpaint_mask: A prebuilt mask tensor for inpainting. Shape should be [batch_size, sample_size]. Ignored if inpaint_mask_start_seconds/inpaint_mask_end_seconds are provided. inpaint_mask_start_seconds: Start of the inpaint region in seconds. Can be a float for a single region, or a list of floats for multiple non-contiguous regions. inpaint_mask_end_seconds: End of the inpaint region in seconds. Can be a float for a single region, or a list of floats matching inpaint_mask_start_seconds. duration_padding_sec: Extra seconds to add when adapting duration (default 6.0). apg_scale: APG (Adaptive Projected Guidance) scale. 1.0 = full APG, 0.0 = vanilla CFG. dist_shift: Optional distribution shift override for sampling. If None, uses model.sampling_dist_shift. return_latents: Whether to return the latents used for generation instead of the decoded audio. chunked_decode: Whether to decode latents in overlapping chunks to reduce peak VRAM. True forces chunked decoding on, False forces it off, None (default) uses the value set in the model config. **sampler_kwargs: Additional keyword arguments to pass to the sampler. """ device = str(self.device) # Build conditioning from prompt string if not provided directly if conditioning is None and conditioning_tensors is None: assert prompt is not None, "Must provide either prompt or conditioning" conditioning, negative_conditioning = self._build_conditioning_dicts( prompt, negative_prompt, duration, batch_size ) # Adapt sample size based on seconds_total in conditioning audio_sample_size = sample_size if conditioning is not None: audio_sample_size = self._adapt_sample_size( conditioning, sample_size, duration_padding_sec, ) # Convert audio sample size to latent size latent_sample_size = audio_sample_size if self.model.pretransform is not None: latent_sample_size = ( audio_sample_size // self.model.pretransform.downsampling_ratio ) # Build inpaint mask from seconds if provided if ( inpaint_mask_start_seconds is not None and inpaint_mask_end_seconds is not None ): start_is_list = isinstance(inpaint_mask_start_seconds, list) end_is_list = isinstance(inpaint_mask_end_seconds, list) if start_is_list != end_is_list: raise ValueError( "inpaint_mask_start_seconds and inpaint_mask_end_seconds must both be " "scalars or both be lists, got " f"{type(inpaint_mask_start_seconds).__name__} and " f"{type(inpaint_mask_end_seconds).__name__}." ) starts = ( inpaint_mask_start_seconds if start_is_list else [inpaint_mask_start_seconds] ) ends = ( inpaint_mask_end_seconds if end_is_list else [inpaint_mask_end_seconds] ) if len(starts) != len(ends): raise ValueError( f"inpaint_mask_start_seconds and inpaint_mask_end_seconds must have the same " f"length, got {len(starts)} and {len(ends)}." ) inpaint_mask = torch.ones(1, audio_sample_size, device=device) for start_sec, end_sec in zip(starts, ends): mask_start_samples = min( int(start_sec * self.model.sample_rate), audio_sample_size, ) mask_end_samples = min( int(end_sec * self.model.sample_rate), audio_sample_size, ) inpaint_mask[:, mask_start_samples:mask_end_samples] = 0 # If the caller passed a prebuilt mask sized to the un-adapted sample_size (or # anything longer than audio_sample_size), truncate to audio_sample_size so the # downstream nearest-neighbor interpolation preserves the mask's time-domain # positions instead of squashing the mask region. if inpaint_mask is not None and inpaint_mask.shape[-1] > audio_sample_size: inpaint_mask = inpaint_mask[:, :audio_sample_size] # Match training: when mask_padding_attention is used, random_inpaint_mask # zeroes the mask past real_sequence_length. Apply the # same convention here so the mask matches the training distribution, whether # it was built from seconds above or passed in by the caller. if inpaint_mask is not None and conditioning is not None: max_seconds = max( (c.get("seconds_total", 0.0) for c in conditioning), default=0.0 ) if max_seconds > 0: effective_audio_len = int(max_seconds * self.model.sample_rate) mask_len = inpaint_mask.shape[-1] if effective_audio_len < mask_len: inpaint_mask = inpaint_mask.clone() inpaint_mask[:, effective_audio_len:] = 0 if inpaint_mask is not None: inpaint_mask = inpaint_mask.float() # Seed and noise seed = seed if seed != -1 else np.random.randint(0, 99999) torch.manual_seed(seed) noise = torch.randn( [batch_size, self.model.io_channels, latent_sample_size], device=device ) # Encode conditioning if conditioning_tensors is None: conditioning_tensors = self.model.conditioner(conditioning, device) if ( negative_conditioning is not None or negative_conditioning_tensors is not None ): if negative_conditioning_tensors is None: negative_conditioning_tensors = self.model.conditioner( negative_conditioning, device ) else: negative_conditioning_tensors = {} # Process init audio if init_audio is not None: init_audio, inpaint_mask = self._encode_audio_input( init_audio, audio_sample_size, inpaint_mask ) init_audio = init_audio.repeat(batch_size, 1, 1) # Process inpaint audio if inpaint_audio is not None: inpaint_audio, inpaint_mask = self._encode_audio_input( inpaint_audio, audio_sample_size, inpaint_mask ) inpaint_audio = inpaint_audio.repeat(batch_size, 1, 1) else: if inpaint_mask is not None: inpaint_mask = interpolate( inpaint_mask.unsqueeze(1), size=latent_sample_size, mode="nearest" ).squeeze(1) # Build inpaint mask tensor and masked input if inpaint_mask is None: mask = torch.zeros((batch_size, 1, latent_sample_size), device=device) else: mask = inpaint_mask.unsqueeze(1) mask = mask.to(device) inpaint_input = ( inpaint_audio * mask.expand_as(inpaint_audio) if inpaint_audio is not None else torch.zeros( (batch_size, self.model.io_channels, latent_sample_size), device=device ) ) conditioning_tensors["inpaint_mask"] = [mask] conditioning_tensors["inpaint_masked_input"] = [inpaint_input] conditioning_inputs = self.model.get_conditioning_inputs(conditioning_tensors) if negative_conditioning_tensors: negative_conditioning_tensors["inpaint_mask"] = [mask] negative_conditioning_tensors["inpaint_masked_input"] = [inpaint_input] negative_conditioning_tensors = self.model.get_conditioning_inputs( negative_conditioning_tensors, negative=True ) model_dtype = next(self.model.model.parameters()).dtype noise = noise.type(model_dtype) conditioning_inputs = { k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items() } cond_inputs = {**conditioning_inputs, **negative_conditioning_tensors} sampler_type = sampler_kwargs.pop("sampler_type", None) result = sample_diffusion( model=self.model.model, noise=noise, cond_inputs=cond_inputs, diffusion_objective=self.model.diffusion_objective, steps=steps, cfg_scale=cfg_scale, conditioning=conditioning, sample_rate=self.model.sample_rate, pretransform=self.model.pretransform, mask_padding_attention=True, use_effective_length_for_schedule=True, headroom_seconds=duration_padding_sec, dist_shift=dist_shift if dist_shift is not None else self.model.sampling_dist_shift, sampler_type=sampler_type, batch_cfg=True, rescale_cfg=True, apg_scale=apg_scale, init_data=init_audio, init_noise_level=init_noise_level, decode=not return_latents, chunked_decode=chunked_decode, **sampler_kwargs, ) if not return_latents: result = result.to(torch.float32).clamp(-1, 1) if not return_latents and truncate_output_to_duration: if isinstance(duration, (int, float)): max_length_samples = int(duration * self.model.sample_rate) result = result[:, :, :max_length_samples] else: if torch.all(torch.tensor(duration) == duration[0]): max_length_samples = int(duration[0] * self.model.sample_rate) result = result[:, :, :max_length_samples] else: # Warn that we can't truncate to a single duration if the durations are different, and return the full length output print( "Warning: Cannot truncate output to a single duration when passing a list of different durations" ) return result # --- generate() helpers --- @staticmethod def _build_conditioning_dicts(prompt, negative_prompt, duration, batch_size): """Returns (conditioning, negative_conditioning) lists of dicts.""" def _to_list(value, name): """Broadcast a scalar or validate a sequence to length batch_size.""" if isinstance(value, (list, tuple)): assert len(value) == batch_size, ( f"Length of {name} ({len(value)}) must match batch_size ({batch_size})" ) return list(value) return [value] * batch_size prompts = _to_list(prompt, "prompt") durations = _to_list(duration, "duration") conditioning = [ {"prompt": p, "seconds_total": d} for p, d in zip(prompts, durations) ] negative_conditioning = None if negative_prompt is not None: neg_prompts = _to_list(negative_prompt, "negative_prompt") negative_conditioning = [ {"prompt": p, "seconds_total": d} for p, d in zip(neg_prompts, durations) ] return conditioning, negative_conditioning def _adapt_sample_size(self, conditioning, sample_size, duration_padding_sec): """Returns audio_sample_size adapted from conditioning, clamped to sample_size.""" max_seconds = 0.0 for cond_dict in conditioning: if "seconds_total" in cond_dict: max_seconds = max(max_seconds, cond_dict["seconds_total"]) if max_seconds <= 0: return sample_size target_audio_samples = int( (max_seconds + duration_padding_sec) * self.model.sample_rate ) if self.model.pretransform is not None: ds_ratio = self.model.pretransform.downsampling_ratio # Round up to nearest multiple of downsampling ratio target_audio_samples = ( (target_audio_samples + ds_ratio - 1) // ds_ratio ) * ds_ratio encoder_config = self.model_config["model"]["pretransform"]["config"][ "encoder" ]["config"] chunk_size = encoder_config.get("chunk_size", 32) stride = encoder_config["strides"][0] # or min(strides) if multiple # For chunked attention with latent space, align to chunk size after downsampling latent_align = chunk_size // stride align = ds_ratio * latent_align target_audio_samples = ((target_audio_samples + align - 1) // align) * align return min(target_audio_samples, sample_size) def _encode_audio_input(self, audio_input, audio_sample_size, inpaint_mask=None): """ Converts a (sample_rate, audio) tuple to an encoded latent tensor. If model has a pretransform, encodes to latent space and downsamples inpaint_mask to match. Returns (encoded_audio, updated_inpaint_mask). encoded_audio is not yet repeated to batch size. """ device = str(self.device) in_sr, audio_data = audio_input if isinstance(audio_data, np.ndarray): audio_data = numpy_audio_to_tensor(audio_data) io_channels = ( self.model.pretransform.io_channels if self.model.pretransform is not None else self.model.io_channels ) audio = prepare_audio( audio_data, in_sr=in_sr, target_sr=self.model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device, ) if self.model.pretransform is not None: audio = audio.to(next(self.model.pretransform.parameters()).dtype) audio = self.model.pretransform.encode(audio) if inpaint_mask is not None: inpaint_mask = interpolate( inpaint_mask.unsqueeze(1), size=audio.shape[-1], mode="nearest", ).squeeze(1) return audio, inpaint_mask class AutoencoderModel: def __init__(self, autoencoder, sample_rate, device): self.autoencoder = autoencoder self.sample_rate = sample_rate self.device = device @staticmethod def from_pretrained(model_name, device=None): if device is None: if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" if not torch.cuda.is_available(): if model_name == "same-l": print( f"Warning: You are loading the {model_name} model without a GPU. This model is not designed to run on cpu" ) if model_name not in ae_models: raise ValueError( f"Unknown autoencoder '{model_name}'. Valid models: {list(ae_models)}" ) cfg = ae_models[model_name] local_config, local_ckpt = cfg.resolve() with open(local_config) as f: sample_rate = json.load(f)["sample_rate"] autoencoder = load_autoencoder(local_config, local_ckpt, device=device) autoencoder.eval().requires_grad_(False) return AutoencoderModel(autoencoder, sample_rate, device) @torch.inference_mode() def encode(self, audio, sr, chunked=False, chunk_size=128, overlap=32): """Encode audio to latents. Args: audio: A single waveform tensor (C, T), a list of waveform tensors, or a pre-batched tensor (B, C, T). Resampling, channel conversion, and padding are handled automatically; passing sr=ae.sample_rate for already-preprocessed audio skips resampling. sr: Sample rate of the input audio, or a list of sample rates when audio is a list. chunked: If True, encode in overlapping chunks to save memory. chunk_size: Chunk size in latent frames (only used when chunked=True). overlap: Overlap in latent frames between chunks (only used when chunked=True). Returns: Latent tensor of shape (B, latent_dim, latent_time). """ if isinstance(audio, list): preprocessed = self.autoencoder.preprocess_audio_list_for_encoder( audio, in_sr_list=sr ) elif isinstance(audio, torch.Tensor) and audio.dim() == 3: sr_list = sr if isinstance(sr, list) else [sr] * audio.shape[0] preprocessed = self.autoencoder.preprocess_audio_list_for_encoder( list(audio), in_sr_list=sr_list ) else: preprocessed = self.autoencoder.preprocess_audio_for_encoder( audio, in_sr=sr ) return self.autoencoder.encode_audio( preprocessed.to(self.device), chunked=chunked, chunk_size=chunk_size, overlap=overlap, ) @torch.inference_mode() def decode(self, latents, chunked=False, chunk_size=128, overlap=32): """Decode latents to audio. Args: latents: Latent tensor of shape (B, latent_dim, latent_time). chunked: If True, decode in overlapping chunks to save memory. chunk_size: Chunk size in latent frames (only used when chunked=True). overlap: Overlap in latent frames between chunks (only used when chunked=True). Returns: Audio tensor of shape (B, channels, samples). """ return self.autoencoder.decode_audio( latents, chunked=chunked, chunk_size=chunk_size, overlap=overlap, )