import os
from typing import List, Tuple
import uuid
import json
import argparse

import gradio as gr
import torch
import torchaudio
from safetensors.torch import load_file
from tqdm import tqdm

from model import LocalSongModel
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE


class TagEmbedder:
    """Loads the tag-name -> class-index mapping used to condition generation."""

    def __init__(self, mapping_file: str = "checkpoints/tag_mapping.json"):
        with open(mapping_file, 'r', encoding='utf-8') as f:
            self.tag_mapping = json.load(f)

        # Size of the tag vocabulary expected by the pretrained checkpoint.
        self.num_classes = 2304


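# Illustrative note: tag_mapping.json is expected to map lowercase tag strings to
# integer class ids, e.g. {"rock": 0, "piano": 1, ...} (example values only).
# Its keys populate the tag dropdown below and its values are passed to the model
# as conditioning indices.

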
class AudioVAE:
    """Wraps the MusicDCAE decoder and de-normalizes latents before decoding."""

    def __init__(self, device: torch.device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel latent statistics; decode() uses them to undo the
        # normalization before running the DCAE decoder.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            latents = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(latents, sr=48000)
            audio_batch = torch.stack(audio_list).to(self.device)
        return audio_batch


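# Illustrative usage sketch (not executed here): decode() takes normalized latents
# shaped (batch, 8, height, width) and returns a batch of waveforms at 48 kHz, e.g.
#   audio = vae.decode(torch.randn(1, 8, 16, 512, device=device))

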
class RF:
    """Euler sampler for the model's velocity field, with optional classifier-free guidance."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def sample(
        self,
        z: torch.Tensor,
        cond: List[List[int]],
        null_cond: List[List[int]] | None = None,
        sample_steps: int = 100,
        cfg: float = 3.0,
    ) -> List[torch.Tensor]:
        batch = z.size(0)
        dt = 1.0 / sample_steps
        dt = torch.tensor([dt] * batch, device=z.device).view([batch, *([1] * len(z.shape[1:]))])
        images = [z]
        for i in tqdm(range(sample_steps, 0, -1), desc="Generating", unit="step"):
            t = torch.tensor([i / sample_steps] * batch, device=z.device)

            if null_cond is not None:
                # Classifier-free guidance: run the conditional and unconditional
                # branches in one batched forward pass, then mix the velocities.
                z_batched = torch.cat([z, z], dim=0)
                t_batched = torch.cat([t, t], dim=0)
                cond_batched = cond + null_cond
                v_batched = self.model(z_batched, t_batched, cond_batched)
                vc, vu = v_batched.chunk(2, dim=0)
                vc = vu + cfg * (vc - vu)
            else:
                vc = self.model(z, t, cond)

            z = z - dt * vc
            images.append(z)
        return images


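# Sampling sketch: starting from noise at t = 1, each Euler step evaluates the
# guided velocity v = v_uncond + cfg * (v_cond - v_uncond) and updates
# z <- z - dt * v, stepping t from 1 toward 0. The returned list holds every
# intermediate latent, with the final sample last.

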
# Module-level resources, populated lazily by load_resources().
model: torch.nn.Module | None = None
vae: AudioVAE | None = None
tag_embedder: TagEmbedder | None = None
rf_sampler: RF | None = None
device: torch.device | None = None
_available_tags: List[str] | None = None


def load_resources(checkpoint_path: str) -> List[str]:
    """Load the model, VAE, sampler, and tag mapping once; return the sorted tag list."""
    global model, vae, tag_embedder, rf_sampler, device, _available_tags

    # Allow reduced-precision (TF32) float32 matmuls for faster inference on supported GPUs.
    torch.set_float32_matmul_precision('high')

    # Already initialized; note that a different checkpoint_path is ignored on later calls.
    if _available_tags is not None:
        return _available_tags

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tag_embedder = TagEmbedder()

    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=tag_embedder.num_classes,
        max_tags=8,
    ).to(device)

    print(f"Loading checkpoint: {checkpoint_path}")
    state_dict = load_file(checkpoint_path, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    vae = AudioVAE(device)
    rf_sampler = RF(model)

    _available_tags = sorted(tag_embedder.tag_mapping.keys())
    return _available_tags


def _tags_to_indices(tags: List[str]) -> List[int]:
    """Convert tag strings to class indices, skipping tags missing from the mapping."""
    assert tag_embedder is not None
    indices = []
    for tag in tags:
        tag_lower = tag.lower().strip()
        if tag_lower in tag_embedder.tag_mapping:
            indices.append(tag_embedder.tag_mapping[tag_lower])
    return indices


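# Illustrative example: with a mapping like {"rock": 0, "piano": 1} (made-up ids),
# _tags_to_indices(["Rock", " piano ", "unknown tag"]) would return [0, 1];
# tags not present in the mapping are silently dropped.

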
def generate_audio(
    tags: List[str],
    cfg: float,
    sample_steps: int,
) -> Tuple[Tuple[int, object], str]:
    """Sample latents conditioned on the given tags, decode to audio, and save a WAV file."""
    assert model is not None and vae is not None and rf_sampler is not None and device is not None

    tags = tags or []
    if len(tags) > 8:
        raise gr.Error("A maximum of 8 tags is supported.")

    tag_indices = _tags_to_indices(tags)

    # Initial noise latents: (batch, channels, height, width).
    batch = 1
    channels = 8
    height = 16
    width = 512
    z = torch.randn(batch, channels, height, width, device=device)

    cond = [tag_indices]
    # An empty tag list serves as the unconditional branch for classifier-free guidance.
    null_cond = [[]]

    with torch.no_grad():
        sampled_latents = rf_sampler.sample(
            z=z,
            cond=cond,
            null_cond=null_cond,
            sample_steps=sample_steps,
            cfg=cfg,
        )[-1]
        audio = vae.decode(sampled_latents)

    audio_tensor = audio[0].cpu()
    sr = 48000
    # Gradio's numpy audio format expects (samples, channels).
    audio_numpy = audio_tensor.transpose(0, 1).numpy()

    os.makedirs("generated", exist_ok=True)
    output_path = f"generated/generated_{uuid.uuid4().hex}.wav"
    # torchaudio.save expects (channels, samples).
    torchaudio.save(output_path, audio_tensor, sr)

    return (sr, audio_numpy), output_path


def build_interface(checkpoint_path: str) -> gr.Blocks:
    available_tags = load_resources(checkpoint_path)

    presets = [
        ["soundtrack1", "female vocalist", "rock", "melodic"],
        ["soundtrack", "chrono trigger", "emotional", "piano", "strings"],
        ["soundtrack", "touhou 10", "trumpet"],
        ["soundtrack", "christmas music", "winter", "melodic"],
        ["soundtrack2", "male vocalist", "pop", "melodic", "acoustic guitar", "ballad"],
    ]

    with gr.Blocks(title="LocalSong") as demo:
        gr.Markdown("# LocalSong")

        with gr.Row():
            tag_input = gr.Dropdown(
                label="Tags (select up to 8)",
                choices=available_tags,
                multiselect=True,
                max_choices=8,
                value=presets[0],
            )

        gr.Markdown("**Presets:**")
        with gr.Row():
            for preset in presets:
                btn = gr.Button(" + ".join(preset), size="sm")
                # Bind the current preset by value so each button fills the
                # dropdown with its own tag list.
                btn.click(fn=lambda p=preset: p, inputs=None, outputs=tag_input)

        with gr.Row():
            cfg_slider = gr.Slider(
                label="CFG Scale",
                minimum=1.0,
                maximum=7.0,
                step=0.5,
                value=3.5,
            )
            sample_steps_slider = gr.Slider(
                label="Sample Steps",
                minimum=50,
                maximum=200,
                step=10,
                value=200,
            )

        with gr.Row():
            seed_input = gr.Number(
                label="Seed",
                value=45,
                precision=0,
            )

        generate_button = gr.Button("Generate Audio", variant="primary")
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        download_output = gr.File(label="Download WAV")

        def generate_wrapper(tags, cfg, steps, seed):
            # Seed both CPU and CUDA RNGs so results are reproducible per seed.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            return generate_audio(tags, cfg, steps)

        generate_button.click(
            fn=generate_wrapper,
            inputs=[
                tag_input,
                cfg_slider,
                sample_steps_slider,
                seed_input,
            ],
            outputs=[
                audio_output,
                download_output,
            ],
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LocalSong Gradio Interface")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the model checkpoint",
    )
    args = parser.parse_args()

    demo = build_interface(args.checkpoint)
    demo.launch()
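# Typical invocation (script filename assumed; adjust to match this file's name):
#   python app.py --checkpoint checkpoints/checkpoint_461260.safetensors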