| |
| |
| """ |
| HF-driven inference for MAGEL with segment-level autoregressive generation. |
| |
| Uses from HF sample: |
| - text instruction/template tokens (token_ids scaffold) |
| - control tokens: chord_ids/structure_ids |
| |
| Does NOT use: |
| - ground-truth audio token values as input (audio codebook positions are masked) |
| """ |
|
|
| import argparse |
| import contextlib |
| import importlib |
| import json |
| import os |
| import sys |
| from dataclasses import dataclass |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from runtime_utils import ( |
| load_magel_checkpoint, |
| load_music_dataset, |
| maybe_compile_model, |
| maybe_mark_compile_step_begin, |
| resolve_device, |
| seed_everything, |
| ) |
| from vocab import ( |
| CHORD_BOS_ID, |
| CHORD_EOS_ID, |
| STRUCTURE_EOS_ID, |
| chord_id_to_label, |
| structure_id_to_label, |
| ) |
| from modelling_qwen3 import MAGEL |
|
|
# Repo layout: this file lives at the repo root, with the vendored MuCodec
# checkout expected as a sibling directory.
REPO_ROOT = Path(__file__).resolve().parent
MUCODEC_ROOT = REPO_ROOT / "MuCodec"
|
|
|
|
@dataclass
class TokenLayout:
    """Maps the joint text+audio vocabulary into named regions.

    The audio codebook occupies ids in
    [num_text_token, num_text_token + num_audio_codebook), followed by three
    special ids: audio mask, audio BOS, and audio EOS.
    """

    num_text_token: int
    num_audio_codebook: int = 16384

    @property
    def audio_start(self) -> int:
        """First id of the audio codebook region."""
        return self.num_text_token

    @property
    def audio_end(self) -> int:
        """One past the last audio codebook id."""
        return self.audio_start + self.num_audio_codebook

    @property
    def mask_audio(self) -> int:
        """Placeholder id used to hide ground-truth audio token values."""
        return self.audio_end

    @property
    def bos_audio(self) -> int:
        """Begin-of-audio-segment marker id."""
        return self.audio_end + 1

    @property
    def eos_audio(self) -> int:
        """End-of-audio-segment marker id."""
        return self.audio_end + 2
|
|
|
|
@dataclass
class SegmentSpan:
    """Positions describing one BOS/EOS-delimited audio segment in a template."""

    # 0-based index of this segment within the song.
    seg_idx: int
    # Sequence position of the segment's BOS-audio marker.
    bos_pos: int
    # Sequence position of the segment's EOS-audio marker.
    eos_pos: int
    # Sequence positions strictly between bos_pos and eos_pos that hold
    # audio codebook tokens.
    audio_positions: list[int]
|
|
|
|
@dataclass
class HFTemplateSample:
    """One dataset row prepared for segment-wise autoregressive generation."""

    # Identifier of the source song (falls back to "sample_<idx>" when absent).
    song_id: str
    # Size of the text vocabulary; audio codebook ids start right after it.
    num_text_token: int
    # Original token sequence, including ground-truth audio token values.
    template_ids: torch.Tensor
    # template_ids with every audio codebook position replaced by the mask id.
    input_ids: torch.Tensor
    # Per-position chord control ids (same length as template_ids).
    chord_ids: torch.Tensor
    # Per-position structure control ids (same length as template_ids).
    structure_ids: torch.Tensor
    # Boolean mask of positions carrying conditioning information.
    condition_mask: torch.Tensor
    # Boolean mask: position holds an audio codebook token in the template.
    is_audio_codebook: torch.Tensor
    # Boolean mask: position holds the audio EOS marker in the template.
    is_eos: torch.Tensor
    # BOS/EOS-paired audio segments discovered in the template.
    segments: list[SegmentSpan]
    # Untouched raw dataset record backing this sample.
    raw_item: dict[str, Any]

    @property
    def seq_len(self) -> int:
        """Total number of positions in the token template."""
        return int(self.input_ids.numel())
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Segment-wise AR generation from HF controls/scaffold."
    )
    add = parser.add_argument

    # Model / data selection.
    add("--model_path", type=str, default="./output_qwen3_0p6b_train/final")
    add("--dataset_path", type=str, default="muse_mucodec_chord.ds")
    add("--split", type=str, default="validation")
    add("--sample_idx", type=int, default=0)
    add("--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B")
    add(
        "--num_audio_codebook",
        type=int,
        default=None,
        help="Audio codebook size. Defaults to checkpoint metadata when available.",
    )

    # Sampling controls.
    add("--temperature", type=float, default=1.0)
    add("--top_k", type=int, default=50)
    add("--top_p", type=float, default=0.90)
    add("--greedy", action="store_true", default=False)
    add("--max_audio_tokens", type=int, default=0)
    add("--fps", type=int, default=25)

    # Runtime options.
    add("--seed", type=int, default=1234)
    add("--device", type=str, default="auto")
    add(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    add("--use_cache", action="store_true", default=True)
    add("--no_cache", action="store_true", default=False)
    add("--compile", action="store_true", default=False)
    add(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    add(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )

    # Output locations.
    add("--output_dir", type=str, default="predictions")
    add("--output_prefix", type=str, default="")
    add(
        "--json_output_dir",
        type=str,
        default="predictions/json",
        help="Directory for chord/segment json. Default: <output_dir>/json",
    )

    # MuCodec decoding options.
    add(
        "--mucodec_device",
        type=str,
        default="auto",
        help="Device string for MuCodec, for example cuda:0.",
    )
    add(
        "--mucodec_layer_num",
        type=int,
        default=7,
        help="MuCodec layer_num passed to the official decoder.",
    )
    add(
        "--mucodec_duration",
        type=float,
        default=40.96,
        help="Chunk duration argument passed to MuCodec code2sound.",
    )
    add(
        "--mucodec_guidance_scale",
        type=float,
        default=1.5,
        help="Guidance scale argument passed to MuCodec code2sound.",
    )
    add(
        "--mucodec_num_steps",
        type=int,
        default=20,
        help="Sampling steps argument passed to MuCodec code2sound.",
    )
    add(
        "--mucodec_sample_rate",
        type=int,
        default=48000,
        help="Sample rate used when saving decoded wav.",
    )
    add(
        "--wav_output_dir",
        type=str,
        default="predictions/wav",
        help="Directory for decoded wav. Default: <output_dir>/wav",
    )
    return parser.parse_args()
|
|
|
|
def resolve_runtime_device_str(device_arg: str) -> str:
    """Resolve "auto" to the best available backend; pass explicit devices through."""
    if device_arg != "auto":
        return device_arg
    for is_available, name in (
        (torch.cuda.is_available, "cuda:0"),
        (torch.backends.mps.is_available, "mps"),
    ):
        if is_available():
            return name
    return "cpu"
|
|
|
|
@contextlib.contextmanager
def pushd(path: str):
    """Temporarily chdir into *path*, restoring the previous cwd on exit."""
    original_cwd = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(original_cwd)
|
|
|
|
def ensure_sys_path(path: str) -> None:
    """Prepend *path* to sys.path once; empty strings are ignored."""
    if not path:
        return
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
|
|
|
def get_mucodec_root() -> str:
    """Return the MuCodec repo path after verifying the expected layout.

    Raises:
        FileNotFoundError: when the directory or its generate.py is missing.
    """
    entrypoint = MUCODEC_ROOT / "generate.py"
    if not MUCODEC_ROOT.is_dir():
        raise FileNotFoundError(f"MuCodec directory not found: {MUCODEC_ROOT}")
    if not entrypoint.is_file():
        raise FileNotFoundError(f"MuCodec entrypoint not found: {entrypoint}")
    return str(MUCODEC_ROOT)
|
|
|
|
def import_mucodec_class():
    """Import the MuCodec class from the vendored repo's generate.py.

    Returns:
        Tuple of (MuCodec class, resolved repo path string).

    Raises:
        ImportError: when generate.py cannot be imported or lacks a MuCodec
            attribute; chained to the original failure so the real cause
            stays visible in the traceback.
    """
    repo_path = get_mucodec_root()
    ensure_sys_path(repo_path)
    try:
        module = importlib.import_module("generate")
        return getattr(module, "MuCodec"), repo_path
    except Exception as exc:
        # Fix: chain the original exception (PEP 3134) instead of dropping it.
        raise ImportError(
            f"Could not import MuCodec from {repo_path}/generate.py: {exc}"
        ) from exc
|
|
|
|
def build_mucodec_decoder(args: argparse.Namespace) -> Any:
    """Instantiate the official MuCodec decoder from the vendored repo.

    Checks all required local assets up front, then constructs the decoder
    with the cwd switched into the repo.

    Raises:
        FileNotFoundError: when the checkpoint or a required auxiliary file
            is missing from the expected folder structure.
    """
    MuCodec, resolved_repo = import_mucodec_class()

    ckpt_path = os.path.join(resolved_repo, "ckpt", "mucodec.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"MuCodec checkpoint not found: {ckpt_path}")

    # Fail fast on auxiliary assets rather than erroring deep inside MuCodec.
    required_local_files = [
        os.path.join(resolved_repo, "tools", "audioldm_48k.pth"),
        os.path.join(resolved_repo, "muq_dev", "muq.pt"),
    ]
    for path in required_local_files:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required MuCodec dependency not found for current folder structure: {path}"
            )

    mucodec_device = resolve_runtime_device_str(args.mucodec_device)
    if resolved_repo:
        print(f"[INFO] resolved MuCodec repo: {resolved_repo}")
    print(f"[INFO] loading MuCodec from {ckpt_path} on {mucodec_device}")
    # Construct inside the repo: MuCodec presumably resolves auxiliary paths
    # relative to the cwd — TODO confirm against MuCodec/generate.py.
    with pushd(resolved_repo):
        decoder = MuCodec(
            model_path=ckpt_path,
            layer_num=int(args.mucodec_layer_num),
            load_main_model=True,
            device=mucodec_device,
        )
    # Remember the repo so decode_mucodec_codes can chdir back when decoding.
    setattr(decoder, "_magel_mucodec_repo", resolved_repo)
    return decoder
|
|
|
|
def decode_mucodec_codes(
    mucodec_decoder: Any,
    shifted_codes: np.ndarray,
    args: argparse.Namespace,
) -> torch.Tensor:
    """Decode a 1-D stream of zero-based MuCodec codes into a waveform.

    Returns a float32 CPU tensor; a 1-D decoder output is promoted to
    (1, samples).
    """
    if shifted_codes.ndim != 1:
        raise ValueError(
            f"Expected 1D MuCodec token stream, got shape {shifted_codes.shape}"
        )

    # code2sound expects a (batch, layer, time) layout.
    codes = torch.from_numpy(shifted_codes.astype(np.int64, copy=False))[None, None, :]

    repo_path = getattr(mucodec_decoder, "_magel_mucodec_repo", "")
    # Decode from inside the repo when we know where it is, since MuCodec
    # resolves auxiliary files relative to the cwd; otherwise stay put.
    decode_ctx = pushd(repo_path) if repo_path else contextlib.nullcontext()
    with decode_ctx:
        wave = mucodec_decoder.code2sound(
            codes,
            prompt=None,
            duration=float(args.mucodec_duration),
            guidance_scale=float(args.mucodec_guidance_scale),
            num_steps=int(args.mucodec_num_steps),
            disable_progress=True,
        )

    if not torch.is_tensor(wave):
        wave = torch.as_tensor(wave)
    if wave.ndim == 1:
        wave = wave.unsqueeze(0)
    return wave.detach().cpu().to(torch.float32)
|
|
|
|
def build_segment_spans(
    template_ids: torch.Tensor,
    is_audio_codebook: torch.Tensor,
    layout: TokenLayout,
) -> list[SegmentSpan]:
    """Pair each BOS-audio marker with the next EOS-audio marker after it.

    Args:
        template_ids: 1-D tensor of template token ids.
        is_audio_codebook: boolean mask (same length) flagging audio slots.
        layout: vocabulary layout providing the BOS/EOS audio ids.

    Returns:
        One SegmentSpan per matched BOS/EOS pair, in order. BOS markers with
        no later EOS are dropped; if either marker list is empty the result
        is empty.
    """
    bos_positions = torch.where(template_ids.eq(layout.bos_audio))[0].tolist()
    eos_positions = torch.where(template_ids.eq(layout.eos_audio))[0].tolist()
    if not bos_positions or not eos_positions:
        return []

    # Fix: the position-index vector is loop-invariant — build it once
    # instead of re-allocating it for every segment.
    idx = torch.arange(template_ids.numel(), device=template_ids.device)

    spans: list[SegmentSpan] = []
    eos_ptr = 0
    for b in bos_positions:
        # Advance to the first EOS strictly after this BOS.
        while eos_ptr < len(eos_positions) and eos_positions[eos_ptr] <= b:
            eos_ptr += 1
        if eos_ptr >= len(eos_positions):
            break
        e = eos_positions[eos_ptr]
        eos_ptr += 1
        # Audio slots strictly inside (b, e) belong to this segment.
        mask = is_audio_codebook & (idx > b) & (idx < e)
        audio_positions = torch.where(mask)[0].tolist()
        spans.append(
            SegmentSpan(
                seg_idx=len(spans),
                bos_pos=int(b),
                eos_pos=int(e),
                audio_positions=[int(p) for p in audio_positions],
            )
        )
    return spans
|
|
|
|
def load_hf_template_sample(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Load one dataset row and convert it into an HFTemplateSample."""
    dataset = load_music_dataset(
        dataset_path=dataset_path,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_codebook,
        use_fast=True,
    )
    return load_hf_template_sample_from_music_dataset(
        music_ds=dataset,
        sample_idx=sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
|
|
|
|
def load_hf_template_sample_from_music_dataset(
    music_ds,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Convert one row of an already-loaded music dataset into an HFTemplateSample.

    Audio codebook positions in the returned ``input_ids`` are replaced with
    the mask token so generation never sees ground-truth audio values.

    Raises:
        ValueError: when a control stream's length does not match the template.
    """
    layout = TokenLayout(
        num_text_token=music_ds.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )

    # NOTE(review): reaches into the private backing store for the raw record
    # (e.g. song_id), while __getitem__ yields the tensorized view.
    raw_item = music_ds._data[sample_idx]
    row = music_ds[sample_idx]

    template_ids = row["token_ids"].to(torch.long)
    chord_ids = row["chord_ids"].to(torch.long)
    structure_ids = row["structure_ids"].to(torch.long)
    condition_mask = row["condition_mask"].to(torch.bool)

    # All control/conditioning streams must align 1:1 with the token template.
    seq_len = int(template_ids.numel())
    for name, t in [
        ("chord_ids", chord_ids),
        ("structure_ids", structure_ids),
        ("condition_mask", condition_mask),
    ]:
        if int(t.numel()) != seq_len:
            raise ValueError(f"{name} length mismatch: {int(t.numel())} != {seq_len}")

    is_audio_codebook = (template_ids >= layout.audio_start) & (
        template_ids < layout.audio_end
    )
    is_eos = template_ids.eq(layout.eos_audio)

    # Hide ground-truth audio token values behind the mask id.
    input_ids = template_ids.clone()
    input_ids[is_audio_codebook] = layout.mask_audio

    spans = build_segment_spans(template_ids, is_audio_codebook, layout)

    return HFTemplateSample(
        song_id=str(raw_item.get("song_id", f"sample_{sample_idx}")),
        num_text_token=music_ds.num_text_token,
        template_ids=template_ids,
        input_ids=input_ids,
        chord_ids=chord_ids,
        structure_ids=structure_ids,
        condition_mask=condition_mask,
        is_audio_codebook=is_audio_codebook,
        is_eos=is_eos,
        segments=spans,
        raw_item=raw_item,
    )
|
|
|
|
def apply_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Apply top-k and nucleus (top-p) truncation to a logits batch.

    top_k is skipped when None or <= 0; top_p is skipped when None or
    outside (0, 1). Filtered entries are set to -inf; the most likely token
    always survives nucleus filtering.
    """
    neg_inf = float("-inf")

    if top_k is not None and top_k > 0:
        keep = min(top_k, logits.shape[-1])
        threshold = torch.topk(logits, keep, dim=-1).values[:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < threshold, neg_inf)

    if top_p is not None and 0.0 < top_p < 1.0:
        ordered, order = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(ordered, dim=-1).cumsum(dim=-1)
        drop = cumulative > top_p
        drop[:, 0] = False  # never drop the single most likely token
        ordered = ordered.masked_fill(drop, neg_inf)
        result = torch.full_like(logits, neg_inf)
        result.scatter_(dim=-1, index=order, src=ordered)
        logits = result

    return logits
|
|
|
|
def sample_from_logits(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Pick one token index from a (1, vocab) logits row.

    Greedy mode (or a non-positive temperature) returns the argmax; otherwise
    logits are temperature-scaled, top-k/top-p filtered, and sampled.

    Raises:
        RuntimeError: when filtering leaves no finite logit to sample from.
    """
    if greedy or temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())

    scaled = logits / max(temperature, 1e-6)
    filtered = apply_top_k_top_p(scaled, top_k=top_k, top_p=top_p)
    if not torch.isfinite(filtered).any():
        raise RuntimeError("All logits are -inf after filtering.")
    probs = torch.softmax(filtered, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
|
|
|
|
def sample_audio_token_from_logits(
    logits: torch.Tensor,
    layout: TokenLayout,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Sample a token restricted to the audio codebook region of the vocab.

    Slices the audio span out of the full-vocabulary logits, samples an index
    within it, then shifts that index back to a global token id.
    """
    restricted = logits[:, layout.audio_start : layout.audio_end]
    local_idx = sample_from_logits(
        restricted,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
    )
    return int(layout.audio_start + local_idx)
|
|
|
|
def chord_id_to_type(chord_id: int) -> str:
    """Decode a chord id to its label, flagging undecodable ids.

    An "N" (no-chord) label is only trusted for the ids that legitimately
    decode to it; any other id that falls back to "N" becomes "unknown_<id>".
    """
    decoded = chord_id_to_label(chord_id)
    if decoded != "N":
        return decoded
    if chord_id in {1, CHORD_BOS_ID, CHORD_EOS_ID}:
        return decoded
    return f"unknown_{chord_id}"
|
|
|
|
def segment_id_to_type(segment_id: int) -> str:
    """Decode a structure id to its label; out-of-range ids become "unknown_<id>"."""
    decoded = structure_id_to_label(segment_id)
    if 0 <= segment_id <= STRUCTURE_EOS_ID:
        return decoded
    return f"unknown_{segment_id}"
|
|
|
|
def to_intervals(type_ids: list[int], fps: int, mapper) -> list[dict[str, Any]]:
    """Run-length encode frame-level ids into [start, end) time intervals.

    Args:
        type_ids: per-frame type ids.
        fps: frames per second used to convert frame indices to seconds.
        mapper: callable turning an int id into its label string.

    Returns:
        Dicts with "start"/"end" seconds (rounded to 6 places) and "type".
    """
    if not type_ids:
        return []

    intervals: list[dict[str, Any]] = []
    frame_rate = float(fps)
    total = len(type_ids)
    run_start = 0
    for pos in range(1, total + 1):
        # Extend the current run while the id is unchanged.
        if pos < total and type_ids[pos] == type_ids[run_start]:
            continue
        intervals.append(
            {
                "start": round(run_start / frame_rate, 6),
                "end": round(pos / frame_rate, 6),
                "type": mapper(int(type_ids[run_start])),
            }
        )
        run_start = pos
    return intervals
|
|
|
|
def merge_same_type_with_small_gap(
    intervals: list[dict[str, Any]], fps: int, max_gap_frames: int = 1
) -> list[dict[str, Any]]:
    """Merge consecutive same-type intervals separated by at most max_gap_frames.

    Input intervals are assumed time-ordered; the input dicts are not mutated
    (shallow copies are returned).
    """
    if not intervals:
        return []

    gap_budget_s = float(max_gap_frames) / float(fps)
    result = [dict(intervals[0])]
    for interval in intervals[1:]:
        last = result[-1]
        gap = float(interval["start"]) - float(last["end"])
        same_type = last.get("type") == interval.get("type")
        # Small epsilon absorbs float rounding from to_intervals.
        if same_type and gap <= gap_budget_s + 1e-9:
            last["end"] = interval["end"]
        else:
            result.append(dict(interval))
    return result
|
|
|
|
@torch.inference_mode()
def generate_segmentwise(
    model: MAGEL,
    sample: HFTemplateSample,
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> tuple[torch.Tensor, int, list[int], list[int]]:
    """Autoregressively fill the audio slots of one template sample.

    Walks positions from the first audio/EOS slot up to the last segment's
    EOS: audio codebook slots are sampled from the model, EOS slots are
    forced to the EOS id, and every other position is copied verbatim from
    the template.

    Returns:
        (generated_ids on CPU, number of sampled audio tokens,
         chord id recorded per sampled token, structure id per sampled token).
    """
    import time

    seq_template = sample.input_ids.to(device)
    chord_template = sample.chord_ids.to(device)
    structure_template = sample.structure_ids.to(device)
    condition_mask_template = sample.condition_mask.to(device)
    is_audio_code = sample.is_audio_codebook.to(device)
    is_eos = sample.is_eos.to(device)

    # Positions the loop must produce: audio codebook slots plus forced EOS.
    slot_positions = torch.where(is_audio_code | is_eos)[0]
    if slot_positions.numel() == 0:
        # Nothing to generate; echo the (masked) template unchanged.
        return seq_template.detach().cpu(), 0, [], []

    start_pos = int(slot_positions[0].item())
    if sample.segments:
        end_pos = int(sample.segments[-1].eos_pos)
    else:
        end_pos = int(slot_positions[-1].item())

    sampled_chord_ids: list[int] = []
    sampled_segment_ids: list[int] = []

    generated_ids = seq_template.clone()
    sampled_count = 0
    past_key_values: Optional[tuple] = None

    # Conditioning embeddings for the full sequence, computed once up front.
    cond_template: torch.Tensor = model.condition_encoder(
        chord_template.unsqueeze(0),
        structure_template.unsqueeze(0),
    )

    full_attention_mask = torch.ones(
        (1, sample.seq_len), dtype=torch.long, device=device
    )
    # Prefill: run the text/control prefix once to seed the KV cache.
    prefix_ids = generated_ids[:start_pos].unsqueeze(0)
    prefix_attn = full_attention_mask[:, :start_pos]
    model_kwargs = dict(
        input_ids=prefix_ids,
        attention_mask=prefix_attn,
        condition_mask=condition_mask_template[:start_pos].unsqueeze(0),
        cond_precomputed=cond_template[:, :start_pos, :],
        use_cache=use_cache,
    )
    maybe_mark_compile_step_begin(model)
    prefill_t0 = time.perf_counter()
    out = model(**model_kwargs)
    prefill_time_s = time.perf_counter() - prefill_t0
    logits_next = out.logits[:, -1, :]
    if use_cache:
        past_key_values = out.past_key_values
    # Reusable (1, 1) buffer for incremental decoding steps.
    step_ids = torch.empty((1, 1), dtype=torch.long, device=device)

    decode_time_s = 0.0
    for i in range(start_pos, end_pos + 1):
        if bool(is_audio_code[i].item()):
            # Stop early once the optional audio-token budget is exhausted.
            if max_audio_tokens > 0 and sampled_count >= max_audio_tokens:
                break
            next_id = sample_audio_token_from_logits(
                logits_next,
                layout=layout,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
            )
            sampled_count += 1
            # Record the control ids aligned with this sampled audio token.
            cond_pos = i
            sampled_chord_ids.append(int(chord_template[cond_pos].item()))
            sampled_segment_ids.append(int(structure_template[cond_pos].item()))
        elif bool(is_eos[i].item()):
            # EOS slots are forced, never sampled.
            next_id = layout.eos_audio
        else:
            # Non-slot positions are copied verbatim from the template.
            next_id = int(seq_template[i].item())

        generated_ids[i] = int(next_id)

        # No forward pass needed after the final position was filled.
        if i >= end_pos:
            break

        if use_cache:
            # Incremental decode: feed only the newly placed token.
            step_ids[0, 0] = int(next_id)
            step_attn = full_attention_mask[:, : i + 2]
            model_kwargs = dict(
                input_ids=step_ids,
                attention_mask=step_attn,
                condition_mask=condition_mask_template[i : i + 1].unsqueeze(0),
                cond_precomputed=cond_template[:, i : i + 1, :],
                past_key_values=past_key_values,
                use_cache=True,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
            past_key_values = out.past_key_values
        else:
            # No-cache fallback: re-run the whole prefix each step (quadratic).
            cur_len = i + 1
            model_kwargs = dict(
                input_ids=generated_ids[:cur_len].unsqueeze(0),
                attention_mask=full_attention_mask[:, :cur_len],
                condition_mask=condition_mask_template[:cur_len].unsqueeze(0),
                cond_precomputed=cond_template[:, :cur_len, :],
                use_cache=False,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]

    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(sampled_count) / decode_time_s if decode_time_s > 0 and sampled_count > 0 else 0.0
    )
    print(
        "[PROFILE] generation "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={sampled_count} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    return (
        generated_ids.detach().cpu(),
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    )
|
|
|
|
@torch.inference_mode()
def batch_generate_segmentwise(
    model: MAGEL,
    samples: list[HFTemplateSample],
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> list[tuple[torch.Tensor, int, list[int], list[int]]]:
    """Batched variant of generate_segmentwise over right-padded samples.

    Requires the KV cache; without use_cache each sample falls back to the
    single-sample path. Returns one
    (generated_ids, sampled_count, chord_ids, segment_ids) tuple per input
    sample, in order.
    """
    import time

    if not samples:
        return []
    if not use_cache:
        # Batched incremental decoding relies on the KV cache; without it,
        # fall back to independent single-sample generation.
        return [
            generate_segmentwise(
                model=model,
                sample=sample,
                layout=layout,
                device=device,
                use_cache=use_cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                max_audio_tokens=max_audio_tokens,
            )
            for sample in samples
        ]

    batch_size = len(samples)
    seq_lens = [sample.seq_len for sample in samples]
    max_seq_len = max(seq_lens)

    # Right-padded batch buffers; padding stays zero / False.
    seq_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    generated_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    chord_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    structure_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    condition_mask_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_audio_code_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_eos_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)

    start_positions: list[int] = []
    end_positions: list[int] = []
    sampled_counts = [0 for _ in samples]
    sampled_chord_ids: list[list[int]] = [[] for _ in samples]
    sampled_segment_ids: list[list[int]] = [[] for _ in samples]
    valid_sample_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    for row_idx, sample in enumerate(samples):
        seq_templates[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        generated_ids[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        chord_templates[row_idx, : sample.seq_len] = sample.chord_ids.to(device)
        structure_templates[row_idx, : sample.seq_len] = sample.structure_ids.to(device)
        condition_mask_templates[row_idx, : sample.seq_len] = sample.condition_mask.to(device)
        is_audio_code_templates[row_idx, : sample.seq_len] = sample.is_audio_codebook.to(device)
        is_eos_templates[row_idx, : sample.seq_len] = sample.is_eos.to(device)

        # Per-row decode window: first slot position to last segment EOS.
        slot_positions = torch.where(
            is_audio_code_templates[row_idx, : sample.seq_len]
            | is_eos_templates[row_idx, : sample.seq_len]
        )[0]
        if slot_positions.numel() == 0:
            # Row has nothing to generate: mark invalid with an empty window.
            valid_sample_mask[row_idx] = False
            start_positions.append(sample.seq_len)
            end_positions.append(sample.seq_len - 1)
            continue
        start_pos = int(slot_positions[0].item())
        if sample.segments:
            end_pos = int(sample.segments[-1].eos_pos)
        else:
            end_pos = int(slot_positions[-1].item())
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    if not bool(valid_sample_mask.any().item()):
        # No row needs generation: echo the masked templates back.
        return [
            (sample.input_ids.detach().cpu(), 0, [], [])
            for sample in samples
        ]

    start_positions_t = torch.tensor(start_positions, dtype=torch.long, device=device)
    end_positions_t = torch.tensor(end_positions, dtype=torch.long, device=device)
    prefix_lens = start_positions_t.clone()
    max_prefix_len = int(prefix_lens.max().item())
    max_decode_steps = int((end_positions_t - start_positions_t + 1).clamp_min(0).max().item())

    # Conditioning embeddings for the full batch, computed once up front.
    cond_template = model.condition_encoder(chord_templates, structure_templates)

    # Prefill over the longest prefix; shorter rows are masked out.
    prefix_attention_mask = (
        torch.arange(max_prefix_len, device=device).unsqueeze(0) < prefix_lens.unsqueeze(1)
    ).to(torch.long)
    prefill_t0 = time.perf_counter()
    maybe_mark_compile_step_begin(model)
    out = model(
        input_ids=generated_ids[:, :max_prefix_len],
        attention_mask=prefix_attention_mask,
        condition_mask=condition_mask_templates[:, :max_prefix_len],
        cond_precomputed=cond_template[:, :max_prefix_len, :],
        use_cache=True,
    )
    prefill_time_s = time.perf_counter() - prefill_t0

    # Each row's next-token logits live at its own last real prefix position.
    gather_idx = (prefix_lens - 1).clamp_min(0)
    batch_indices = torch.arange(batch_size, device=device)
    logits_next = out.logits[batch_indices, gather_idx, :]
    past_key_values = out.past_key_values

    step_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    # decode_valid_mask[r, s] records whether row r produced a token at step s.
    decode_valid_mask = torch.zeros(
        (batch_size, max_decode_steps), dtype=torch.bool, device=device
    )
    decode_time_s = 0.0

    for step_idx in range(max_decode_steps):
        cur_positions = start_positions_t + step_idx
        active_mask = valid_sample_mask & cur_positions.le(end_positions_t)
        if not bool(active_mask.any().item()):
            break

        next_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        for row_idx in range(batch_size):
            if not bool(active_mask[row_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            if bool(is_audio_code_templates[row_idx, cur_pos].item()):
                if max_audio_tokens > 0 and sampled_counts[row_idx] >= max_audio_tokens:
                    # Budget exhausted: retire this row from further decoding.
                    valid_sample_mask[row_idx] = False
                    continue
                next_id = sample_audio_token_from_logits(
                    logits_next[row_idx : row_idx + 1],
                    layout=layout,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    greedy=greedy,
                )
                sampled_counts[row_idx] += 1
                # Record control ids aligned with this sampled audio token.
                sampled_chord_ids[row_idx].append(
                    int(chord_templates[row_idx, cur_pos].item())
                )
                sampled_segment_ids[row_idx].append(
                    int(structure_templates[row_idx, cur_pos].item())
                )
            elif bool(is_eos_templates[row_idx, cur_pos].item()):
                # EOS slots are forced, never sampled.
                next_id = layout.eos_audio
            else:
                # Non-slot positions are copied verbatim from the template.
                next_id = int(seq_templates[row_idx, cur_pos].item())

            generated_ids[row_idx, cur_pos] = int(next_id)
            next_ids[row_idx] = int(next_id)
            decode_valid_mask[row_idx, step_idx] = True

        # Last step's logits are never used; skip the final forward pass.
        if step_idx >= max_decode_steps - 1:
            break

        # One incremental forward for the whole batch; finished rows feed
        # zero tokens and are masked out of attention.
        step_ids[:, 0] = next_ids
        step_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                decode_valid_mask[:, : step_idx + 1].to(torch.long),
            ],
            dim=1,
        )
        step_condition_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        step_cond = torch.zeros(
            (batch_size, 1, cond_template.shape[-1]),
            dtype=cond_template.dtype,
            device=device,
        )
        for row_idx in range(batch_size):
            if not bool(decode_valid_mask[row_idx, step_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            step_condition_mask[row_idx, 0] = condition_mask_templates[row_idx, cur_pos]
            step_cond[row_idx, 0, :] = cond_template[row_idx, cur_pos, :]

        step_t0 = time.perf_counter()
        maybe_mark_compile_step_begin(model)
        out = model(
            input_ids=step_ids,
            attention_mask=step_attention_mask,
            condition_mask=step_condition_mask,
            cond_precomputed=step_cond,
            past_key_values=past_key_values,
            use_cache=True,
        )
        decode_time_s += time.perf_counter() - step_t0
        logits_next = out.logits[:, -1, :]
        past_key_values = out.past_key_values

    total_sampled_tokens = sum(sampled_counts)
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(total_sampled_tokens) / decode_time_s
        if decode_time_s > 0 and total_sampled_tokens > 0
        else 0.0
    )
    print(
        "[PROFILE] batch_generation "
        f"batch_size={batch_size} "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={total_sampled_tokens} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    outputs: list[tuple[torch.Tensor, int, list[int], list[int]]] = []
    for row_idx, sample in enumerate(samples):
        if not bool((torch.where(sample.is_audio_codebook | sample.is_eos)[0]).numel()):
            # Rows with no slots return their original (masked) template.
            outputs.append((sample.input_ids.detach().cpu(), 0, [], []))
            continue
        outputs.append(
            (
                generated_ids[row_idx, : sample.seq_len].detach().cpu(),
                sampled_counts[row_idx],
                sampled_chord_ids[row_idx],
                sampled_segment_ids[row_idx],
            )
        )
    return outputs
|
|
|
|
def save_outputs(
    output_dir: str,
    output_prefix: str,
    sample: HFTemplateSample,
    layout: TokenLayout,
    generated_ids: torch.Tensor,
    sampled_chord_ids: list[int],
    sampled_segment_ids: list[int],
    args: argparse.Namespace,
    mucodec_decoder: Any = None,
) -> None:
    """Persist generation results: a decoded wav and a chord/segment json.

    Audio tokens are extracted from generated_ids, shifted back to zero-based
    MuCodec codes, and decoded via mucodec_decoder; the recorded control
    streams are converted to time intervals and written as json.
    """
    import time

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = output_prefix or f"{sample.song_id}_{args.sample_idx}_{stamp}"

    json_dir = args.json_output_dir or os.path.join(output_dir, "json")
    wav_dir = args.wav_output_dir or os.path.join(output_dir, "wav")
    Path(json_dir).mkdir(parents=True, exist_ok=True)
    Path(wav_dir).mkdir(parents=True, exist_ok=True)

    json_path = os.path.join(json_dir, f"{prefix}.chord_segment.json")
    wav_path = os.path.join(wav_dir, f"{prefix}.wav")

    gen_full = generated_ids.cpu().numpy().astype(np.int64)

    # Keep only audio codebook tokens and shift them to zero-based codes.
    gen_audio_raw = gen_full[
        (gen_full >= layout.audio_start) & (gen_full < layout.audio_end)
    ]
    gen_audio_shift = gen_audio_raw - layout.audio_start

    save_t0 = time.perf_counter()
    if gen_audio_shift.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
    else:
        # Imported lazily so json-only failures don't require torchaudio.
        import torchaudio

        wave = decode_mucodec_codes(mucodec_decoder, gen_audio_shift, args)
        torchaudio.save(wav_path, wave, int(args.mucodec_sample_rate))
        print(f"[OK] {wav_path}")

    chord_intervals = to_intervals(
        sampled_chord_ids, fps=int(args.fps), mapper=chord_id_to_type
    )
    segment_intervals = to_intervals(
        sampled_segment_ids, fps=int(args.fps), mapper=segment_id_to_type
    )

    # Drop "pad" intervals, then close single-frame gaps between runs of the
    # same type that the removal may have created.
    chord_intervals = [x for x in chord_intervals if x.get("type") != "pad"]
    segment_intervals = [x for x in segment_intervals if x.get("type") != "pad"]
    chord_intervals = merge_same_type_with_small_gap(
        chord_intervals, fps=int(args.fps), max_gap_frames=1
    )
    segment_intervals = merge_same_type_with_small_gap(
        segment_intervals, fps=int(args.fps), max_gap_frames=1
    )

    chord_segment = {
        "song_id": sample.song_id,
        "sample_idx": int(args.sample_idx),
        "fps": int(args.fps),
        "generated_audio_count": int(gen_audio_raw.shape[0]),
        "chord": chord_intervals,
        "segment": segment_intervals,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(chord_segment, f, ensure_ascii=False, indent=2)

    print(f"[OK] {json_path}")
    save_time_s = time.perf_counter() - save_t0
    print(
        "[PROFILE] save "
        f"save_s={save_time_s:.3f} "
        f"generated_audio_count={int(gen_audio_raw.shape[0])}"
    )
|
|
|
|
def main() -> None:
    """CLI entry point: load model and sample, generate, decode, and save."""
    import time

    args = parse_args()
    seed_everything(args.seed)

    # --no_cache overrides the (default-on) --use_cache flag.
    use_cache = args.use_cache and not args.no_cache

    device = resolve_device(args.device)
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
    print(f"[INFO] loading model from {args.model_path}")
    model = load_magel_checkpoint(
        checkpoint_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model = maybe_compile_model(
        model,
        enabled=bool(args.compile),
        mode=str(args.compile_mode),
    )
    # Codebook size: CLI override first, then checkpoint metadata, then 16384.
    num_audio_codebook = (
        int(args.num_audio_codebook)
        if args.num_audio_codebook is not None
        else int(getattr(model.config, "magel_num_audio_token", 16384))
    )
    print(f"[INFO] num_audio_codebook={num_audio_codebook}")

    print(f"[INFO] loading HF sample idx={args.sample_idx} from {args.dataset_path}")
    sample = load_hf_template_sample(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=args.tokenizer_path,
        sample_idx=args.sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
    layout = TokenLayout(
        num_text_token=sample.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    print(
        f"[INFO] song_id={sample.song_id}, seq_len={sample.seq_len}, segments={len(sample.segments)}"
    )
    # Build the MuCodec decoder before generation so asset errors fail early.
    mucodec_decoder = build_mucodec_decoder(args)
    print("[INFO] running segment-level autoregressive generation...")
    t1 = time.time()
    (
        generated_ids,
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    ) = generate_segmentwise(
        model=model,
        sample=sample,
        layout=layout,
        device=device,
        use_cache=use_cache,
        temperature=float(args.temperature),
        top_k=int(args.top_k),
        top_p=float(args.top_p),
        greedy=bool(args.greedy),
        max_audio_tokens=max(0, int(args.max_audio_tokens)),
    )

    print(f"[INFO] sampled audio tokens: {sampled_count}")
    print(f"[INFO] output sequence length: {generated_ids.numel()}")
    t2 = time.time()

    print("total time:", t2 - t1)

    save_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample=sample,
        layout=layout,
        generated_ids=generated_ids,
        sampled_chord_ids=sampled_chord_ids,
        sampled_segment_ids=sampled_segment_ids,
        args=args,
        mucodec_decoder=mucodec_decoder,
    )
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|