import gc
import platform
import os
import subprocess as sp
import gradio as gr
import json
import torch
import torchaudio

from aeiou.viz import audio_spectrogram_image
from einops import rearrange
from safetensors.torch import load_file
from torch.nn import functional as F
from torchaudio import transforms as T

from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
from ..models.factory import create_model_from_config
from ..models.pretrained import get_pretrained_model
from ..models.utils import load_ckpt_state_dict
from ..inference.utils import prepare_audio
from ..training.utils import copy_state_dict
from ..data.utils import read_video, merge_video_audio

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Default to CPU; create_ui() and generate_cond() switch to MPS/CUDA when available.
device = torch.device("cpu")

os.environ['TMPDIR'] = './tmp'

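# Module-level cache: generate_cond() loads the model lazily on its first call and
# reuses it for subsequent generations (see the global statements in that function).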
current_model_name = None
current_model = None
current_sample_rate = None
current_sample_size = None


def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
    global model_configurations

    if pretrained_name is not None:
        print(f"Loading pretrained model {pretrained_name}")
        model, model_config = get_pretrained_model(pretrained_name)
    elif model_config is not None and model_ckpt_path is not None:
        print("Creating model from config")
        model = create_model_from_config(model_config)
        print(f"Loading model checkpoint from {model_ckpt_path}")
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))

    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    if pretransform_ckpt_path is not None:
        print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
        model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
        print("Done loading pretransform")

    model.to(device).eval().requires_grad_(False)

    if model_half:
        model.to(torch.float16)

    print("Done loading model")

    return model, model_config, sample_rate, sample_size

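# Example (hypothetical paths, illustration only): loading a custom checkpoint directly,
# mirroring what generate_cond() does for the "default" configuration registered in create_ui().
#
#   with open("./model/config.json") as f:
#       model_config = json.load(f)
#   model, model_config, sample_rate, sample_size = load_model(
#       model_name="default",
#       model_config=model_config,
#       model_ckpt_path="./model/model.ckpt",
#       device="cuda",
#   )
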
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
    # With no audio prompt, return a silent stereo placeholder of the requested length.
    if audio_path is None:
        return torch.zeros((2, int(sample_rate * seconds_total)))
    audio_tensor, sr = torchaudio.load(audio_path)
    # Resample to the model sample rate so the index arithmetic below is in the right units.
    if sr != sample_rate:
        audio_tensor = T.Resample(sr, sample_rate)(audio_tensor)
    start_index = int(sample_rate * seconds_start)
    target_length = int(sample_rate * seconds_total)
    end_index = start_index + target_length
    audio_tensor = audio_tensor[:, start_index:end_index]
    # Left-pad with zeros if the clip is shorter than the requested duration.
    if audio_tensor.shape[1] < target_length:
        pad_length = target_length - audio_tensor.shape[1]
        audio_tensor = F.pad(audio_tensor, (pad_length, 0))
    return audio_tensor

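# For example, load_and_process_audio(path, sample_rate=44100, seconds_start=0, seconds_total=10)
# returns a [channels, 441000] tensor, zero-padded on the left when the file is shorter.
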
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    video_path=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=6.0,
    steps=250,
    preview_every=None,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=1000,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=1.0,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
    # Free any memory left over from the previous generation.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None

    # Pick the best available device for this generation.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False
    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Resolve the active model configuration (this demo only registers "default").
    model_name = 'default'
    cfg = model_configurations[model_name]
    model_config_path = cfg.get("model_config")
    ckpt_path = cfg.get("ckpt_path")
    pretrained_name = cfg.get("pretrained_name")
    pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
    model_type = cfg.get("model_type", "diffusion_cond")

    if model_config_path:
        with open(model_config_path) as f:
            model_config = json.load(f)
    else:
        model_config = None

    target_fps = model_config.get("video_fps", 5) if model_config else 5

    # Load the model on first use and cache it in the module-level globals.
    global current_model_name, current_model, current_sample_rate, current_sample_size
    if current_model is None or model_name != current_model_name:
        current_model, model_config, sample_rate, sample_size = load_model(
            model_name=model_name,
            model_config=model_config,
            model_ckpt_path=ckpt_path,
            pretrained_name=pretrained_name,
            pretransform_ckpt_path=pretransform_ckpt_path,
            device=device,
            model_half=False
        )
        current_model_name = model_name
        model = current_model
        current_sample_rate = sample_rate
        current_sample_size = sample_size
    else:
        model = current_model
        sample_rate = current_sample_rate
        sample_size = current_sample_size

    if video_file is not None:
        video_path = video_file.name
    elif video_path:
        video_path = video_path.strip()
    else:
        video_path = None

    if audio_prompt_file is not None:
        print(f'audio_prompt_file: {audio_prompt_file}')
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None

    Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate
    print(f'video_path: {video_path}')

    if not prompt:
        prompt = ""

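    # The conditioning combines three modalities: video frames read at target_fps,
    # the text prompt, and the (possibly silent) audio prompt, plus timing metadata.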
    conditioning = [{
        "video_prompt": [Video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }] * batch_size

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [Video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }] * batch_size
    else:
        negative_conditioning = None

    # Generate on whichever device the model actually lives on.
    try:
        device = next(model.parameters()).device
    except Exception:
        device = next(current_model.parameters()).device

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    if init_audio is not None:
        # Gradio audio components return (sample_rate, int16 numpy array).
        in_sr, init_audio = init_audio
        init_audio = torch.from_numpy(init_audio).float().div(32767)
        if init_audio.dim() == 1:
            init_audio = init_audio.unsqueeze(0)  # [n] -> [1, n]
        elif init_audio.dim() == 2:
            init_audio = init_audio.transpose(0, 1)  # [n, 2] -> [2, n]
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
            init_audio = resample_tf(init_audio)
        audio_length = init_audio.shape[-1]
        if audio_length > sample_size:
            input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
        init_audio = (sample_rate, init_audio)
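    # The rounding above pads input_sample_size up to the next multiple of
    # model.min_input_length. For example, audio_length=100000 with an illustrative
    # min_input_length=2048 gives 100000 + (2048 - 1696) % 2048 = 100352 = 49 * 2048.
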
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))

    if mask_cropfrom is not None:
        mask_args = {
            "cropfrom": mask_cropfrom,
            "pastefrom": mask_pastefrom,
            "pasteto": mask_pasteto,
            "maskstart": mask_maskstart,
            "maskend": mask_maskend,
            "softnessL": mask_softnessL,
            "softnessR": mask_softnessR,
            "marination": mask_marination,
        }
    else:
        mask_args = None

    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=mask_args,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )
    elif model_type == "diffusion_uncond":
        audio = generate_diffusion_uncond(
            model,
            steps=steps,
            batch_size=batch_size,
            sample_size=input_sample_size,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            callback=progress_callback if preview_every is not None else None
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Flatten the batch, peak-normalize, and convert to int16 for saving.
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    file_name = os.path.basename(video_path) if video_path else "output"
    output_dir = "demo_result"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_video_path = f"{output_dir}/{file_name}"

    torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)

    if video_path:
        merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)

    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    del video_path
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return (output_video_path, f"{output_dir}/output.wav")


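# Sketch (illustrative only): once create_ui() below has registered model_configurations,
# generate_cond can also be called directly, e.g. for the bundled video-to-audio example:
#
#   video_out, wav_out = generate_cond(
#       prompt="Generate general audio for the video",
#       video_path="example/V2A_sample-1.mp4",
#       seconds_total=10,
#       steps=100,
#       cfg_scale=7.0,
#   )
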
def toggle_custom_model(selected_model):
    # Unused in this demo layout; kept as a helper for a custom-model selector.
    return gr.Row.update(visible=(selected_model == "Custom Model"))


def create_sampling_ui(model_config_map, inpainting=False):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
            **[Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
            """
        )

        with gr.Tab("Generation"):

            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt")
                    negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False)
                    video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path")
                    video_file = gr.File(label="Upload Video File")
                    audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False)
                    audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False)

            with gr.Row():
                with gr.Column(scale=6):
                    with gr.Accordion("Video Params", open=False):
                        seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start")
                        seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False)

            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Accordion("Sampler Params", open=False):
                        steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
                        preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
                        cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale")
                        seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
                        sampler_type_dropdown = gr.Dropdown(
                            ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
                            label="Sampler Type",
                            value="dpmpp-3m-sde"
                        )
                        sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min")
                        sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max")
                        cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount")

            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Accordion("Init Audio", open=False, visible=False):
                        init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                        init_audio_input = gr.Audio(label="Init Audio")
                        init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level")

| gr.Markdown("## Examples") |
| with gr.Accordion("Click to show examples", open=False): |
| with gr.Row(): |
| gr.Markdown("**📝 Task: Text-to-Audio**") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *Typing on a keyboard*") |
| ex1 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *Ocean waves crashing*") |
| ex2 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *Footsteps in snow*") |
| ex3 = gr.Button("Load Example") |
| with gr.Row(): |
| gr.Markdown("**🎶 Task: Text-to-Music**") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*") |
| ex4 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*") |
| ex5 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*") |
| ex6 = gr.Button("Load Example") |
| with gr.Row(): |
| gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2A_sample-1.mp4") |
| ex7 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2A_sample-2.mp4") |
| ex8 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2A_sample-3.mp4") |
| ex9 = gr.Button("Load Example") |
| with gr.Row(): |
| gr.Markdown("**🎵 Task: Video-to-Music**\nPrompt: *Generate music for the video*") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2M_sample-1.mp4") |
| ex10 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2M_sample-2.mp4") |
| ex11 = gr.Button("Load Example") |
| with gr.Column(scale=1.2): |
| gr.Video("example/V2M_sample-3.mp4") |
| ex12 = gr.Button("Load Example") |
            with gr.Row():
                generate_button = gr.Button("Generate", variant='primary', scale=1)

            with gr.Row():
                with gr.Column(scale=6):
                    video_output = gr.Video(label="Output Video", interactive=False)
                    audio_output = gr.Audio(label="Output Audio", interactive=False)
                    send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)

            send_to_init_button.click(
                fn=lambda audio: audio,
                inputs=[audio_output],
                outputs=[init_audio_input]
            )

            inputs = [
                prompt,
                negative_prompt,
                video_file,
                video_path,
                audio_prompt_file,
                audio_prompt_path,
                seconds_start_slider,
                seconds_total_slider,
                cfg_scale_slider,
                steps_slider,
                preview_every_slider,
                seed_textbox,
                sampler_type_dropdown,
                sigma_min_slider,
                sigma_max_slider,
                cfg_rescale_slider,
                init_audio_checkbox,
                init_audio_input,
                init_noise_level_slider
            ]

            generate_button.click(
                fn=generate_cond,
                inputs=inputs,
                outputs=[
                    video_output,
                    audio_output
                ],
                api_name="generate"
            )

            ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
            ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)

    return demo


def create_txt2audio_ui(model_config_map):
    with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
        with gr.Tab("Generation"):
            create_sampling_ui(model_config_map)
    return ui


def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
    global model_configurations
    global device

    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print("Using device:", device)

    # Register the model configuration, falling back to the bundled demo checkpoint
    # when no explicit paths are given.
    model_configurations = {
        "default": {
            "model_config": model_config_path or "./model/config.json",
            "ckpt_path": ckpt_path or "./model/model.ckpt",
            "pretrained_name": pretrained_name,
            "pretransform_ckpt_path": pretransform_ckpt_path,
        }
    }

    ui = create_txt2audio_ui(model_configurations)
    return ui


if __name__ == "__main__":
    # share is a launch() option, not a create_ui() argument.
    ui = create_ui(model_config_path='./model/config.json')
    ui.launch(share=True)
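# To serve on a fixed host/port instead of (or alongside) a share link, Gradio's launch()
# also accepts, for example:
#
#   ui.launch(server_name="0.0.0.0", server_port=7860, share=False)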