# Hugging Face page residue (not code): file cond_gen/inference_full.py,
# uploaded by Leon299 via the upload-large-folder tool, commit 8337fa0 (verified).
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
HF-driven inference for MAGEL with segment-level autoregressive generation.
Uses from HF sample:
- text instruction/template tokens (token_ids scaffold)
- control tokens: chord_ids/structure_ids
Does NOT use:
- ground-truth audio token values as input (audio codebook positions are masked)
"""
import argparse
import contextlib
import importlib
import json
import os
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from runtime_utils import (
load_magel_checkpoint,
load_music_dataset,
maybe_compile_model,
maybe_mark_compile_step_begin,
resolve_device,
seed_everything,
)
from vocab import (
CHORD_BOS_ID,
CHORD_EOS_ID,
STRUCTURE_EOS_ID,
chord_id_to_label,
structure_id_to_label,
)
from modelling_qwen3 import MAGEL
# Directory containing this script; MuCodec is expected as a checkout
# directly under it (see get_mucodec_root for validation).
REPO_ROOT = Path(__file__).resolve().parent
MUCODEC_ROOT = REPO_ROOT / "MuCodec"
@dataclass
class TokenLayout:
    """Flat vocabulary layout: text ids first, then the audio codebook,
    then three special audio tokens (MASK, BOS, EOS) in that order."""

    num_text_token: int
    num_audio_codebook: int = 16384

    @property
    def audio_start(self) -> int:
        """First id of the audio-codebook range (== text vocab size)."""
        return self.num_text_token

    @property
    def audio_end(self) -> int:
        """One past the last audio-codebook id (exclusive bound)."""
        return self.audio_start + self.num_audio_codebook

    @property
    def mask_audio(self) -> int:
        """Id of the audio MASK placeholder token."""
        return self.audio_end

    @property
    def bos_audio(self) -> int:
        """Id of the audio segment BOS token."""
        return self.mask_audio + 1

    @property
    def eos_audio(self) -> int:
        """Id of the audio segment EOS token."""
        return self.mask_audio + 2
@dataclass
class SegmentSpan:
    """One audio segment inside the token template.

    All positions are absolute indices into the flat token sequence.
    """

    # 0-based ordinal of this segment within the sample.
    seg_idx: int
    # Position of the segment's audio-BOS token.
    bos_pos: int
    # Position of the matching audio-EOS token (strictly after bos_pos).
    eos_pos: int
    # Positions strictly between BOS and EOS holding audio-codebook tokens.
    audio_positions: list[int]
@dataclass
class HFTemplateSample:
    """One HF dataset row converted into an inference scaffold.

    All per-token tensors share the same length T; `input_ids` is
    `template_ids` with ground-truth audio-codebook values masked out.
    """

    song_id: str
    num_text_token: int
    template_ids: torch.Tensor  # [T], original token_ids
    input_ids: torch.Tensor  # [T], audio codebook replaced with MASK_AUDIO
    chord_ids: torch.Tensor  # [T]
    structure_ids: torch.Tensor  # [T]
    condition_mask: torch.Tensor  # [T]
    is_audio_codebook: torch.Tensor  # [T]
    is_eos: torch.Tensor  # [T]
    segments: list[SegmentSpan]
    raw_item: dict[str, Any]

    @property
    def seq_len(self) -> int:
        # Template length in tokens (T above).
        return int(self.input_ids.numel())
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for segment-wise AR generation.

    Returns:
        Populated argparse.Namespace.

    Notes:
        - ``--use_cache`` is a store_true flag that already defaults to True,
          so passing it is a no-op; ``--no_cache`` disables KV caching.
        - ``--json_output_dir`` / ``--wav_output_dir`` default to empty, which
          lets save_outputs fall back to ``<output_dir>/json`` and
          ``<output_dir>/wav`` as documented in their help strings.
    """
    parser = argparse.ArgumentParser(
        description="Segment-wise AR generation from HF controls/scaffold."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="./output_qwen3_0p6b_train/final",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--sample_idx", type=int, default=0)
    parser.add_argument(
        "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B"
    )
    parser.add_argument(
        "--num_audio_codebook",
        type=int,
        default=None,
        help="Audio codebook size. Defaults to checkpoint metadata when available.",
    )
    # Sampling controls.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    parser.add_argument("--output_dir", type=str, default="predictions")
    parser.add_argument("--output_prefix", type=str, default="")
    # Fix: empty default so the documented "<output_dir>/json" fallback in
    # save_outputs actually applies; previously the hard-coded
    # "predictions/json" default made that fallback unreachable.
    parser.add_argument(
        "--json_output_dir",
        type=str,
        default="",
        help="Directory for chord/segment json. Default: <output_dir>/json",
    )
    parser.add_argument(
        "--mucodec_device",
        type=str,
        default="auto",
        help="Device string for MuCodec, for example cuda:0.",
    )
    parser.add_argument(
        "--mucodec_layer_num",
        type=int,
        default=7,
        help="MuCodec layer_num passed to the official decoder.",
    )
    parser.add_argument(
        "--mucodec_duration",
        type=float,
        default=40.96,
        help="Chunk duration argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_guidance_scale",
        type=float,
        default=1.5,
        help="Guidance scale argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_num_steps",
        type=int,
        default=20,
        help="Sampling steps argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_sample_rate",
        type=int,
        default=48000,
        help="Sample rate used when saving decoded wav.",
    )
    # Fix: same empty-default treatment as --json_output_dir.
    parser.add_argument(
        "--wav_output_dir",
        type=str,
        default="",
        help="Directory for decoded wav. Default: <output_dir>/wav",
    )
    return parser.parse_args()
def resolve_runtime_device_str(device_arg: str) -> str:
    """Return *device_arg* unchanged unless it is "auto", in which case pick
    the best available backend: CUDA, then MPS, then CPU."""
    if device_arg != "auto":
        return device_arg
    if torch.cuda.is_available():
        return "cuda:0"
    return "mps" if torch.backends.mps.is_available() else "cpu"
@contextlib.contextmanager
def pushd(path: str):
    """Temporarily change the working directory to *path*, restoring the
    previous cwd on exit (even on exception)."""
    original_cwd = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(original_cwd)
def ensure_sys_path(path: str) -> None:
    """Prepend *path* to sys.path unless it is empty or already present."""
    if not path:
        return
    if path in sys.path:
        return
    sys.path.insert(0, path)
def get_mucodec_root() -> str:
    """Validate the bundled MuCodec checkout and return its path as a string.

    Raises:
        FileNotFoundError: when the directory or its generate.py is missing.
    """
    if not MUCODEC_ROOT.is_dir():
        raise FileNotFoundError(f"MuCodec directory not found: {MUCODEC_ROOT}")
    entrypoint = MUCODEC_ROOT / "generate.py"
    if not entrypoint.is_file():
        raise FileNotFoundError(
            f"MuCodec entrypoint not found: {entrypoint}"
        )
    return str(MUCODEC_ROOT)
def import_mucodec_class():
    """Import the official MuCodec class from <repo>/generate.py.

    Returns:
        Tuple of (MuCodec class, resolved repo path string).

    Raises:
        ImportError: when generate.py cannot be imported or lacks MuCodec.
            The original exception is chained (``from exc``) so the real
            failure is preserved in the traceback — the previous bare
            ``raise`` lost the chain.
    """
    repo_path = get_mucodec_root()
    ensure_sys_path(repo_path)
    try:
        module = importlib.import_module("generate")
        return getattr(module, "MuCodec"), repo_path
    except Exception as exc:  # pragma: no cover - env dependent
        raise ImportError(
            f"Could not import MuCodec from {repo_path}/generate.py: {exc}"
        ) from exc
def build_mucodec_decoder(args: argparse.Namespace) -> Any:
    """Instantiate the official MuCodec decoder from the bundled checkout.

    Validates that the checkpoint and the auxiliary weight files exist,
    then constructs MuCodec with the cwd set to the repo root (the decoder
    presumably resolves asset paths relative to its repo — see pushd use).
    The resolved repo path is stashed on the decoder instance for later use
    by decode_mucodec_codes.

    Raises:
        FileNotFoundError: when the checkpoint or a required asset is absent.
    """
    MuCodec, resolved_repo = import_mucodec_class()
    ckpt_path = os.path.join(resolved_repo, "ckpt", "mucodec.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"MuCodec checkpoint not found: {ckpt_path}")
    # Auxiliary weights the decoder needs under the current folder layout.
    required_local_files = [
        os.path.join(resolved_repo, "tools", "audioldm_48k.pth"),
        os.path.join(resolved_repo, "muq_dev", "muq.pt"),
    ]
    for path in required_local_files:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required MuCodec dependency not found for current folder structure: {path}"
            )
    mucodec_device = resolve_runtime_device_str(args.mucodec_device)
    if resolved_repo:
        print(f"[INFO] resolved MuCodec repo: {resolved_repo}")
    print(f"[INFO] loading MuCodec from {ckpt_path} on {mucodec_device}")
    # Construct from inside the repo so relative paths inside MuCodec resolve.
    with pushd(resolved_repo):
        decoder = MuCodec(
            model_path=ckpt_path,
            layer_num=int(args.mucodec_layer_num),
            load_main_model=True,
            device=mucodec_device,
        )
    # Remember the repo location so decoding can chdir back into it.
    setattr(decoder, "_magel_mucodec_repo", resolved_repo)
    return decoder
def decode_mucodec_codes(
    mucodec_decoder: Any,
    shifted_codes: np.ndarray,
    args: argparse.Namespace,
) -> torch.Tensor:
    """Decode a 1-D stream of zero-based MuCodec codes into a waveform.

    Returns a float32 CPU tensor; 1-D decoder output is promoted to
    [1, samples].
    """
    if shifted_codes.ndim != 1:
        raise ValueError(
            f"Expected 1D MuCodec token stream, got shape {shifted_codes.shape}"
        )
    # code2sound expects an int64 tensor shaped [batch, stream, time].
    codes = torch.from_numpy(shifted_codes.astype(np.int64, copy=False))[None, None, :]
    repo_path = getattr(mucodec_decoder, "_magel_mucodec_repo", "")
    # Decode from inside the repo when its location is known, so any
    # relative asset paths inside the decoder resolve.
    if repo_path:
        decode_ctx = pushd(repo_path)
    else:
        decode_ctx = contextlib.nullcontext()
    with decode_ctx:
        wave = mucodec_decoder.code2sound(
            codes,
            prompt=None,
            duration=float(args.mucodec_duration),
            guidance_scale=float(args.mucodec_guidance_scale),
            num_steps=int(args.mucodec_num_steps),
            disable_progress=True,
        )
    if not torch.is_tensor(wave):
        wave = torch.as_tensor(wave)
    if wave.ndim == 1:
        wave = wave[None, :]
    return wave.detach().cpu().to(torch.float32)
def build_segment_spans(
    template_ids: torch.Tensor,
    is_audio_codebook: torch.Tensor,
    layout: TokenLayout,
) -> list[SegmentSpan]:
    """Pair each audio-BOS token with the next audio-EOS after it and collect
    the audio-codebook positions strictly between them.

    A BOS with no following EOS terminates pairing; each EOS is consumed at
    most once.

    Returns:
        One SegmentSpan per matched BOS/EOS pair, in sequence order.
    """
    bos_positions = torch.where(template_ids.eq(layout.bos_audio))[0].tolist()
    eos_positions = torch.where(template_ids.eq(layout.eos_audio))[0].tolist()
    if not bos_positions or not eos_positions:
        return []
    # Fix: the position-index vector is loop-invariant; previously it was
    # rebuilt with torch.arange on every iteration.
    idx = torch.arange(template_ids.numel(), device=template_ids.device)
    spans: list[SegmentSpan] = []
    eos_ptr = 0
    for b in bos_positions:
        # Advance to the first unconsumed EOS strictly after this BOS.
        while eos_ptr < len(eos_positions) and eos_positions[eos_ptr] <= b:
            eos_ptr += 1
        if eos_ptr >= len(eos_positions):
            break
        e = eos_positions[eos_ptr]
        eos_ptr += 1
        # Audio-codebook positions strictly inside (b, e).
        mask = is_audio_codebook & (idx > b) & (idx < e)
        audio_positions = torch.where(mask)[0].tolist()
        spans.append(
            SegmentSpan(
                seg_idx=len(spans),
                bos_pos=int(b),
                eos_pos=int(e),
                audio_positions=[int(p) for p in audio_positions],
            )
        )
    return spans
def load_hf_template_sample(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Load one row of the music dataset and convert it into an inference
    template (see load_hf_template_sample_from_music_dataset)."""
    dataset = load_music_dataset(
        dataset_path=dataset_path,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_codebook,
        use_fast=True,
    )
    return load_hf_template_sample_from_music_dataset(
        music_ds=dataset,
        sample_idx=sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
def load_hf_template_sample_from_music_dataset(
    music_ds,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Build an HFTemplateSample from one row of a loaded music dataset.

    Ground-truth audio-codebook values in the row are replaced by the layout
    MASK_AUDIO id in `input_ids`, so GT audio never enters the scaffold.

    Raises:
        ValueError: when a control tensor's length differs from token_ids.
    """
    layout = TokenLayout(
        num_text_token=music_ds.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    raw_item = music_ds._data[sample_idx]
    row = music_ds[sample_idx]
    template_ids = row["token_ids"].to(torch.long)
    chord_ids = row["chord_ids"].to(torch.long)
    structure_ids = row["structure_ids"].to(torch.long)
    condition_mask = row["condition_mask"].to(torch.bool)
    seq_len = int(template_ids.numel())
    # Every control tensor must be aligned 1:1 with the token template.
    aligned = {
        "chord_ids": chord_ids,
        "structure_ids": structure_ids,
        "condition_mask": condition_mask,
    }
    for name, tensor in aligned.items():
        length = int(tensor.numel())
        if length != seq_len:
            raise ValueError(f"{name} length mismatch: {length} != {seq_len}")
    in_codebook = template_ids.ge(layout.audio_start) & template_ids.lt(layout.audio_end)
    is_eos = template_ids.eq(layout.eos_audio)
    # Remove GT audio token values from the input scaffold.
    input_ids = template_ids.clone()
    input_ids[in_codebook] = layout.mask_audio
    return HFTemplateSample(
        song_id=str(raw_item.get("song_id", f"sample_{sample_idx}")),
        num_text_token=music_ds.num_text_token,
        template_ids=template_ids,
        input_ids=input_ids,
        chord_ids=chord_ids,
        structure_ids=structure_ids,
        condition_mask=condition_mask,
        is_audio_codebook=in_codebook,
        is_eos=is_eos,
        segments=build_segment_spans(template_ids, in_codebook, layout),
        raw_item=raw_item,
    )
def apply_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Mask logits with top-k filtering, then nucleus (top-p) filtering.

    Removed entries become -inf. top_k <= 0 (or None) disables top-k;
    top_p outside (0, 1) disables nucleus filtering, which always keeps at
    least the single most likely token per row.
    """
    neg_inf = float("-inf")
    if top_k is not None and top_k > 0:
        kept = torch.topk(logits, min(top_k, logits.shape[-1]), dim=-1).values
        # Everything strictly below the k-th largest value is dropped.
        logits = logits.masked_fill(logits < kept[:, -1:], neg_inf)
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop = cum_probs > top_p
        drop[:, 0] = False  # never drop the most likely token
        sorted_logits = sorted_logits.masked_fill(drop, neg_inf)
        # Scatter the filtered values back into original vocabulary order.
        logits = torch.full_like(logits, neg_inf).scatter_(
            dim=-1, index=sorted_idx, src=sorted_logits
        )
    return logits
def sample_from_logits(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Pick one token index from [1, V] logits.

    Uses argmax when greedy (or temperature <= 0); otherwise applies
    temperature scaling plus top-k/top-p filtering and samples.

    Raises:
        RuntimeError: if filtering removed every candidate.
    """
    if greedy or temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())
    scaled = logits / max(temperature, 1e-6)
    filtered = apply_top_k_top_p(scaled, top_k=top_k, top_p=top_p)
    if not torch.isfinite(filtered).any():
        raise RuntimeError("All logits are -inf after filtering.")
    probs = torch.softmax(filtered, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
def sample_audio_token_from_logits(
    logits: torch.Tensor,
    layout: TokenLayout,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Sample restricted to the audio-codebook slice of the vocabulary and
    return the result as a global token id."""
    start, end = layout.audio_start, layout.audio_end
    codebook_logits = logits[:, start:end]
    # Offset within the codebook slice; shift back to the global id space.
    offset = sample_from_logits(
        codebook_logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
    )
    return int(start + offset)
def chord_id_to_type(chord_id: int) -> str:
    """Map a chord id to its label; out-of-vocabulary ids become "unknown_<id>"."""
    decoded = chord_id_to_label(chord_id)
    # "N" is only trusted when the id is 1 or the chord BOS/EOS control ids
    # (presumably 1 is the dedicated no-chord id — confirm against vocab.py);
    # any other id that decodes to "N" is treated as unknown.
    return decoded if decoded != "N" or chord_id in {1, CHORD_BOS_ID, CHORD_EOS_ID} else f"unknown_{chord_id}"
def segment_id_to_type(segment_id: int) -> str:
    """Map a structure id to its label; ids outside [0, STRUCTURE_EOS_ID]
    become "unknown_<id>"."""
    decoded = structure_id_to_label(segment_id)
    return decoded if 0 <= segment_id <= STRUCTURE_EOS_ID else f"unknown_{segment_id}"
def to_intervals(type_ids: list[int], fps: int, mapper) -> list[dict[str, Any]]:
    """Run-length encode frame-level ids into [start, end) second intervals.

    Args:
        type_ids: one id per frame.
        fps: frames per second used to convert frame indices to seconds.
        mapper: callable mapping an int id to its label string.

    Returns:
        Intervals with "start"/"end" in seconds (rounded to 6 decimals)
        and "type" from *mapper*, in temporal order.
    """
    if not type_ids:
        return []
    frame_rate = float(fps)
    intervals: list[dict[str, Any]] = []
    run_start = 0
    total = len(type_ids)
    for pos in range(1, total + 1):
        at_end = pos == total
        # Close the current run at the sequence end or on an id change.
        if at_end or type_ids[pos] != type_ids[run_start]:
            intervals.append(
                {
                    "start": round(run_start / frame_rate, 6),
                    "end": round(pos / frame_rate, 6),
                    "type": mapper(int(type_ids[run_start])),
                }
            )
            if not at_end:
                run_start = pos
    return intervals
def merge_same_type_with_small_gap(
    intervals: list[dict[str, Any]], fps: int, max_gap_frames: int = 1
) -> list[dict[str, Any]]:
    """Coalesce consecutive same-type intervals whose gap is at most
    *max_gap_frames* frames.

    Input dicts are copied, never mutated. A small epsilon absorbs float
    rounding in the gap comparison.
    """
    if not intervals:
        return []
    gap_limit = float(max_gap_frames) / float(fps)
    result = [dict(intervals[0])]
    for interval in intervals[1:]:
        last = result[-1]
        same_type = last.get("type") == interval.get("type")
        gap = float(interval["start"]) - float(last["end"])
        if same_type and gap <= gap_limit + 1e-9:
            # Extend the previous run instead of opening a new interval.
            last["end"] = interval["end"]
        else:
            result.append(dict(interval))
    return result
@torch.inference_mode()
def generate_segmentwise(
    model: MAGEL,
    sample: HFTemplateSample,
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> tuple[torch.Tensor, int, list[int], list[int]]:
    """Autoregressively fill the audio slots of one template sample.

    Walks positions from the first audio/EOS slot to the last segment EOS:
    audio-codebook slots are sampled from the model, EOS slots are forced
    to the layout EOS id, and every other position is copied verbatim from
    the scaffold. With use_cache, decoding feeds one token at a time into
    the KV cache; otherwise the full prefix is re-run each step.

    Returns:
        Tuple of (generated_ids on CPU, number of sampled audio tokens,
        chord ids aligned to the sampled tokens, structure ids likewise).
    """
    import time
    seq_template = sample.input_ids.to(device)
    chord_template = sample.chord_ids.to(device)
    structure_template = sample.structure_ids.to(device)
    condition_mask_template = sample.condition_mask.to(device)
    is_audio_code = sample.is_audio_codebook.to(device)
    is_eos = sample.is_eos.to(device)
    # Generation slots are audio-codebook positions plus segment EOS markers.
    slot_positions = torch.where(is_audio_code | is_eos)[0]
    if slot_positions.numel() == 0:
        # No generation slot: return scaffold as-is.
        return seq_template.detach().cpu(), 0, [], []
    start_pos = int(slot_positions[0].item())
    if sample.segments:
        end_pos = int(sample.segments[-1].eos_pos)
    else:
        end_pos = int(slot_positions[-1].item())
    sampled_chord_ids: list[int] = []
    sampled_segment_ids: list[int] = []
    generated_ids = seq_template.clone()
    sampled_count = 0
    past_key_values: Optional[tuple] = None
    # Precompute full-sequence condition once so cached decoding keeps
    # the same global condition-encoder context as training.
    cond_template: torch.Tensor = model.condition_encoder(
        chord_template.unsqueeze(0),
        structure_template.unsqueeze(0),
    )
    # Prefill with fixed prefix.
    full_attention_mask = torch.ones(
        (1, sample.seq_len), dtype=torch.long, device=device
    )
    prefix_ids = generated_ids[:start_pos].unsqueeze(0)
    prefix_attn = full_attention_mask[:, :start_pos]
    model_kwargs = dict(
        input_ids=prefix_ids,
        attention_mask=prefix_attn,
        condition_mask=condition_mask_template[:start_pos].unsqueeze(0),
        cond_precomputed=cond_template[:, :start_pos, :],
        use_cache=use_cache,
    )
    maybe_mark_compile_step_begin(model)
    prefill_t0 = time.perf_counter()
    out = model(**model_kwargs)
    prefill_time_s = time.perf_counter() - prefill_t0
    # Logits predicting the token at start_pos.
    logits_next = out.logits[:, -1, :]
    if use_cache:
        past_key_values = out.past_key_values
    # Reusable [1, 1] buffer for single-token cached decode steps.
    step_ids = torch.empty((1, 1), dtype=torch.long, device=device)
    decode_time_s = 0.0
    for i in range(start_pos, end_pos + 1):
        if bool(is_audio_code[i].item()):
            if max_audio_tokens > 0 and sampled_count >= max_audio_tokens:
                # Audio budget exhausted: stop generating entirely.
                break
            next_id = sample_audio_token_from_logits(
                logits_next,
                layout=layout,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
            )
            sampled_count += 1
            # Controls are input-aligned to the token sequence.
            cond_pos = i
            sampled_chord_ids.append(int(chord_template[cond_pos].item()))
            sampled_segment_ids.append(int(structure_template[cond_pos].item()))
        elif bool(is_eos[i].item()):
            # EOS slots are forced, never sampled.
            next_id = layout.eos_audio
        else:
            # Non-slot position inside the span: copy the scaffold token.
            next_id = int(seq_template[i].item())
        generated_ids[i] = int(next_id)
        if i >= end_pos:
            break
        if use_cache:
            # Feed only the newly decided token; attention covers prefix + it.
            step_ids[0, 0] = int(next_id)
            step_attn = full_attention_mask[:, : i + 2]
            model_kwargs = dict(
                input_ids=step_ids,
                attention_mask=step_attn,
                condition_mask=condition_mask_template[i : i + 1].unsqueeze(0),
                cond_precomputed=cond_template[:, i : i + 1, :],
                past_key_values=past_key_values,
                use_cache=True,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
            past_key_values = out.past_key_values
        else:
            # Cache-less path: re-run the whole prefix up to position i.
            cur_len = i + 1
            model_kwargs = dict(
                input_ids=generated_ids[:cur_len].unsqueeze(0),
                attention_mask=full_attention_mask[:, :cur_len],
                condition_mask=condition_mask_template[:cur_len].unsqueeze(0),
                cond_precomputed=cond_template[:, :cur_len, :],
                use_cache=False,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(sampled_count) / decode_time_s if decode_time_s > 0 and sampled_count > 0 else 0.0
    )
    print(
        "[PROFILE] generation "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={sampled_count} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )
    return (
        generated_ids.detach().cpu(),
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    )
@torch.inference_mode()
def batch_generate_segmentwise(
    model: MAGEL,
    samples: list[HFTemplateSample],
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> list[tuple[torch.Tensor, int, list[int], list[int]]]:
    """Batched variant of generate_segmentwise sharing one KV cache.

    Samples are right-padded to the longest sequence, prefilled in a single
    forward pass, then decoded in lockstep; rows finish independently via a
    per-row active mask. Without use_cache this falls back to running
    generate_segmentwise per sample.

    Returns:
        One (generated_ids, sampled_count, chord_ids, segment_ids) tuple per
        input sample, in input order.
    """
    import time
    if not samples:
        return []
    if not use_cache:
        # Cache-less decoding re-runs each full prefix per step, so batching
        # would not share work; defer to the single-sample path.
        return [
            generate_segmentwise(
                model=model,
                sample=sample,
                layout=layout,
                device=device,
                use_cache=use_cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                max_audio_tokens=max_audio_tokens,
            )
            for sample in samples
        ]
    batch_size = len(samples)
    seq_lens = [sample.seq_len for sample in samples]
    max_seq_len = max(seq_lens)
    # Right-padded batch tensors; padding stays zero and is masked out below.
    seq_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    generated_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    chord_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    structure_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    condition_mask_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_audio_code_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_eos_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)
    start_positions: list[int] = []
    end_positions: list[int] = []
    sampled_counts = [0 for _ in samples]
    sampled_chord_ids: list[list[int]] = [[] for _ in samples]
    sampled_segment_ids: list[list[int]] = [[] for _ in samples]
    valid_sample_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
    for row_idx, sample in enumerate(samples):
        seq_templates[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        generated_ids[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        chord_templates[row_idx, : sample.seq_len] = sample.chord_ids.to(device)
        structure_templates[row_idx, : sample.seq_len] = sample.structure_ids.to(device)
        condition_mask_templates[row_idx, : sample.seq_len] = sample.condition_mask.to(device)
        is_audio_code_templates[row_idx, : sample.seq_len] = sample.is_audio_codebook.to(device)
        is_eos_templates[row_idx, : sample.seq_len] = sample.is_eos.to(device)
        slot_positions = torch.where(
            is_audio_code_templates[row_idx, : sample.seq_len]
            | is_eos_templates[row_idx, : sample.seq_len]
        )[0]
        if slot_positions.numel() == 0:
            # Row has nothing to generate; mark it invalid but append
            # placeholder bounds so the position tensors stay batch-aligned.
            valid_sample_mask[row_idx] = False
            start_positions.append(sample.seq_len)
            end_positions.append(sample.seq_len - 1)
            continue
        start_pos = int(slot_positions[0].item())
        if sample.segments:
            end_pos = int(sample.segments[-1].eos_pos)
        else:
            end_pos = int(slot_positions[-1].item())
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    if not bool(valid_sample_mask.any().item()):
        # Nothing to decode anywhere: every row returns its scaffold.
        return [
            (sample.input_ids.detach().cpu(), 0, [], [])
            for sample in samples
        ]
    start_positions_t = torch.tensor(start_positions, dtype=torch.long, device=device)
    end_positions_t = torch.tensor(end_positions, dtype=torch.long, device=device)
    prefix_lens = start_positions_t.clone()
    max_prefix_len = int(prefix_lens.max().item())
    max_decode_steps = int((end_positions_t - start_positions_t + 1).clamp_min(0).max().item())
    # Full-sequence condition computed once, as in training.
    cond_template = model.condition_encoder(chord_templates, structure_templates)
    # 1 for real prefix tokens, 0 for right padding beyond each row's prefix.
    prefix_attention_mask = (
        torch.arange(max_prefix_len, device=device).unsqueeze(0) < prefix_lens.unsqueeze(1)
    ).to(torch.long)
    prefill_t0 = time.perf_counter()
    maybe_mark_compile_step_begin(model)
    out = model(
        input_ids=generated_ids[:, :max_prefix_len],
        attention_mask=prefix_attention_mask,
        condition_mask=condition_mask_templates[:, :max_prefix_len],
        cond_precomputed=cond_template[:, :max_prefix_len, :],
        use_cache=True,
    )
    prefill_time_s = time.perf_counter() - prefill_t0
    # Each row's next-token logits sit at its own last real prefix position.
    gather_idx = (prefix_lens - 1).clamp_min(0)
    batch_indices = torch.arange(batch_size, device=device)
    logits_next = out.logits[batch_indices, gather_idx, :]
    past_key_values = out.past_key_values
    step_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    # Records which rows actually produced a token at each decode step,
    # used to extend the attention mask for the shared cache.
    decode_valid_mask = torch.zeros(
        (batch_size, max_decode_steps), dtype=torch.bool, device=device
    )
    decode_time_s = 0.0
    for step_idx in range(max_decode_steps):
        cur_positions = start_positions_t + step_idx
        active_mask = valid_sample_mask & cur_positions.le(end_positions_t)
        if not bool(active_mask.any().item()):
            break
        next_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        for row_idx in range(batch_size):
            if not bool(active_mask[row_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            if bool(is_audio_code_templates[row_idx, cur_pos].item()):
                if max_audio_tokens > 0 and sampled_counts[row_idx] >= max_audio_tokens:
                    # Row hit its audio budget: retire it from decoding.
                    valid_sample_mask[row_idx] = False
                    continue
                next_id = sample_audio_token_from_logits(
                    logits_next[row_idx : row_idx + 1],
                    layout=layout,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    greedy=greedy,
                )
                sampled_counts[row_idx] += 1
                sampled_chord_ids[row_idx].append(
                    int(chord_templates[row_idx, cur_pos].item())
                )
                sampled_segment_ids[row_idx].append(
                    int(structure_templates[row_idx, cur_pos].item())
                )
            elif bool(is_eos_templates[row_idx, cur_pos].item()):
                # Forced EOS slot.
                next_id = layout.eos_audio
            else:
                # Non-slot position: copy the scaffold token verbatim.
                next_id = int(seq_templates[row_idx, cur_pos].item())
            generated_ids[row_idx, cur_pos] = int(next_id)
            next_ids[row_idx] = int(next_id)
            decode_valid_mask[row_idx, step_idx] = True
        if step_idx >= max_decode_steps - 1:
            break
        step_ids[:, 0] = next_ids
        # Attention covers each row's real prefix plus its valid decode steps.
        step_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                decode_valid_mask[:, : step_idx + 1].to(torch.long),
            ],
            dim=1,
        )
        step_condition_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        step_cond = torch.zeros(
            (batch_size, 1, cond_template.shape[-1]),
            dtype=cond_template.dtype,
            device=device,
        )
        for row_idx in range(batch_size):
            if not bool(decode_valid_mask[row_idx, step_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            step_condition_mask[row_idx, 0] = condition_mask_templates[row_idx, cur_pos]
            step_cond[row_idx, 0, :] = cond_template[row_idx, cur_pos, :]
        step_t0 = time.perf_counter()
        maybe_mark_compile_step_begin(model)
        out = model(
            input_ids=step_ids,
            attention_mask=step_attention_mask,
            condition_mask=step_condition_mask,
            cond_precomputed=step_cond,
            past_key_values=past_key_values,
            use_cache=True,
        )
        decode_time_s += time.perf_counter() - step_t0
        logits_next = out.logits[:, -1, :]
        past_key_values = out.past_key_values
    total_sampled_tokens = sum(sampled_counts)
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(total_sampled_tokens) / decode_time_s
        if decode_time_s > 0 and total_sampled_tokens > 0
        else 0.0
    )
    print(
        "[PROFILE] batch_generation "
        f"batch_size={batch_size} "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={total_sampled_tokens} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )
    outputs: list[tuple[torch.Tensor, int, list[int], list[int]]] = []
    for row_idx, sample in enumerate(samples):
        if not bool((torch.where(sample.is_audio_codebook | sample.is_eos)[0]).numel()):
            # Rows with no slots return their untouched scaffold.
            outputs.append((sample.input_ids.detach().cpu(), 0, [], []))
            continue
        outputs.append(
            (
                generated_ids[row_idx, : sample.seq_len].detach().cpu(),
                sampled_counts[row_idx],
                sampled_chord_ids[row_idx],
                sampled_segment_ids[row_idx],
            )
        )
    return outputs
def save_outputs(
    output_dir: str,
    output_prefix: str,
    sample: HFTemplateSample,
    layout: TokenLayout,
    generated_ids: torch.Tensor,
    sampled_chord_ids: list[int],
    sampled_segment_ids: list[int],
    args: argparse.Namespace,
    mucodec_decoder: Any = None,
) -> None:
    """Persist generation results: a decoded wav plus a chord/segment json.

    The wav is decoded via MuCodec from the generated audio-codebook tokens;
    the json contains frame-accurate chord and segment intervals derived
    from the controls aligned to the sampled tokens.
    """
    import time
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = output_prefix or f"{sample.song_id}_{args.sample_idx}_{stamp}"
    # Empty dir args fall back to subdirectories of output_dir.
    json_dir = args.json_output_dir or os.path.join(output_dir, "json")
    wav_dir = args.wav_output_dir or os.path.join(output_dir, "wav")
    Path(json_dir).mkdir(parents=True, exist_ok=True)
    Path(wav_dir).mkdir(parents=True, exist_ok=True)
    json_path = os.path.join(json_dir, f"{prefix}.chord_segment.json")
    wav_path = os.path.join(wav_dir, f"{prefix}.wav")
    gen_full = generated_ids.cpu().numpy().astype(np.int64)
    # Keep only audio-codebook ids, then shift to MuCodec's 0-based range.
    gen_audio_raw = gen_full[
        (gen_full >= layout.audio_start) & (gen_full < layout.audio_end)
    ]
    gen_audio_shift = gen_audio_raw - layout.audio_start
    save_t0 = time.perf_counter()
    if gen_audio_shift.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
    else:
        import torchaudio
        wave = decode_mucodec_codes(mucodec_decoder, gen_audio_shift, args)
        torchaudio.save(wav_path, wave, int(args.mucodec_sample_rate))
        print(f"[OK] {wav_path}")
    chord_intervals = to_intervals(
        sampled_chord_ids, fps=int(args.fps), mapper=chord_id_to_type
    )
    segment_intervals = to_intervals(
        sampled_segment_ids, fps=int(args.fps), mapper=segment_id_to_type
    )
    # PAD is used for EOS-related conditioning; drop it in exported json.
    chord_intervals = [x for x in chord_intervals if x.get("type") != "pad"]
    segment_intervals = [x for x in segment_intervals if x.get("type") != "pad"]
    # Dropping PAD leaves small gaps; re-merge adjacent same-type intervals.
    chord_intervals = merge_same_type_with_small_gap(
        chord_intervals, fps=int(args.fps), max_gap_frames=1
    )
    segment_intervals = merge_same_type_with_small_gap(
        segment_intervals, fps=int(args.fps), max_gap_frames=1
    )
    chord_segment = {
        "song_id": sample.song_id,
        "sample_idx": int(args.sample_idx),
        "fps": int(args.fps),
        "generated_audio_count": int(gen_audio_raw.shape[0]),
        "chord": chord_intervals,
        "segment": segment_intervals,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(chord_segment, f, ensure_ascii=False, indent=2)
    print(f"[OK] {json_path}")
    save_time_s = time.perf_counter() - save_t0
    print(
        "[PROFILE] save "
        f"save_s={save_time_s:.3f} "
        f"generated_audio_count={int(gen_audio_raw.shape[0])}"
    )
def main() -> None:
    """End-to-end driver: load model and one HF sample, generate, save.

    Pipeline: parse args -> seed -> resolve device/dtype -> load MAGEL
    checkpoint (optionally torch.compile) -> load the template sample ->
    build MuCodec -> run segment-wise AR generation -> save wav/json.
    """
    import time
    args = parse_args()
    seed_everything(args.seed)
    # --no_cache wins over the always-True --use_cache flag.
    use_cache = args.use_cache and not args.no_cache
    device = resolve_device(args.device)
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32
    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
    print(f"[INFO] loading model from {args.model_path}")
    model = load_magel_checkpoint(
        checkpoint_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model = maybe_compile_model(
        model,
        enabled=bool(args.compile),
        mode=str(args.compile_mode),
    )
    # CLI override takes precedence over checkpoint metadata for vocab size.
    num_audio_codebook = (
        int(args.num_audio_codebook)
        if args.num_audio_codebook is not None
        else int(getattr(model.config, "magel_num_audio_token", 16384))
    )
    print(f"[INFO] num_audio_codebook={num_audio_codebook}")
    print(f"[INFO] loading HF sample idx={args.sample_idx} from {args.dataset_path}")
    sample = load_hf_template_sample(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=args.tokenizer_path,
        sample_idx=args.sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
    layout = TokenLayout(
        num_text_token=sample.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    print(
        f"[INFO] song_id={sample.song_id}, seq_len={sample.seq_len}, segments={len(sample.segments)}"
    )
    # Build the decoder up front so asset problems surface before generation.
    mucodec_decoder = build_mucodec_decoder(args)
    print("[INFO] running segment-level autoregressive generation...")
    t1 = time.time()
    (
        generated_ids,
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    ) = generate_segmentwise(
        model=model,
        sample=sample,
        layout=layout,
        device=device,
        use_cache=use_cache,
        temperature=float(args.temperature),
        top_k=int(args.top_k),
        top_p=float(args.top_p),
        greedy=bool(args.greedy),
        max_audio_tokens=max(0, int(args.max_audio_tokens)),
    )
    print(f"[INFO] sampled audio tokens: {sampled_count}")
    print(f"[INFO] output sequence length: {generated_ids.numel()}")
    t2 = time.time()
    print("total time:", t2 - t1)
    save_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample=sample,
        layout=layout,
        generated_ids=generated_ids,
        sampled_chord_ids=sampled_chord_ids,
        sampled_segment_ids=sampled_segment_ids,
        args=args,
        mucodec_decoder=mucodec_decoder,
    )
# Script entry point: run the full inference pipeline when executed directly.
if __name__ == "__main__":
    main()