#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pure Qwen3 autoregressive training/inference for music token generation.
This variant keeps:
- style
- section structure in the prompt
- section lyrics/description
This variant removes:
- frame-level chord conditioning
- frame-level structure conditioning
- condition encoder / AdaLN injection
"""
from __future__ import annotations
import argparse
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import datasets
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoTokenizer, Trainer, TrainingArguments
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from audio_tokens import (
EOA_TOKEN,
MASK_AUDIO_TOKEN,
SOA_TOKEN,
add_audio_special_tokens,
audio_id_to_token,
)
from dataset import normalize_section_text, SECTION_NAME_MAP, SINGLETON_SECTION_NAMES
from inference_full import build_mucodec_decoder, decode_mucodec_codes
from runtime_utils import resolve_device, seed_everything
from vocab import normalize_structure_label
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Pure Qwen3 autoregressive training/inference without frame-level conditioning."
)
subparsers = parser.add_subparsers(dest="command", required=True)
train_parser = subparsers.add_parser(
"train",
help="Train a plain Qwen3 autoregressive model on section prompts and audio tokens.",
)
add_train_args(train_parser)
infer_parser = subparsers.add_parser(
"infer",
help="Run section-wise autoregressive inference with a plain Qwen3 checkpoint.",
)
add_infer_args(infer_parser)
return parser.parse_args()
def add_train_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds")
parser.add_argument(
"--model_path",
type=str,
default="checkpoints/Qwen3-0.6B",
help="Local Qwen3 base checkpoint path.",
)
parser.add_argument(
"--tokenizer_path",
type=str,
default="checkpoints/Qwen3-0.6B",
help="Local tokenizer checkpoint path.",
)
parser.add_argument(
"--num_audio_token",
type=int,
default=None,
help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.",
)
parser.add_argument(
"--model_dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
)
parser.add_argument(
"--attn_implementation",
type=str,
default="sdpa",
choices=["eager", "sdpa", "flash_attention_2"],
)
parser.add_argument("--output_dir", type=str, default="./output_qwen3_plain_ar")
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--num_train_epochs", type=float, default=20)
parser.add_argument("--warmup_steps", type=int, default=1000)
parser.add_argument("--max_grad_norm", type=float, default=5.0)
parser.add_argument("--logging_steps", type=int, default=10)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
)
parser.add_argument("--dataloader_num_workers", type=int, default=12)
parser.add_argument(
"--gradient_checkpointing",
dest="gradient_checkpointing",
action="store_true",
)
parser.add_argument(
"--deepspeed",
type=str,
default=None,
help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
)
parser.add_argument("--report_to", type=str, default="wandb")
parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
parser.add_argument("--wandb_run_name", type=str, default=None)
def add_infer_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument(
"--tokenizer_path",
type=str,
default=None,
help="Tokenizer path. Defaults to --model_path.",
)
parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds")
parser.add_argument("--split", type=str, default="validation")
parser.add_argument("--sample_idx", type=int, default=0)
parser.add_argument(
"--num_audio_token",
type=int,
default=None,
help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.",
)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--device", type=str, default="auto")
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
)
parser.add_argument(
"--attn_implementation",
type=str,
default="sdpa",
choices=["eager", "sdpa", "flash_attention_2"],
)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--top_p", type=float, default=0.90)
parser.add_argument("--greedy", action="store_true", default=False)
parser.add_argument("--use_cache", action="store_true", default=True)
parser.add_argument("--no_cache", action="store_true", default=False)
parser.add_argument(
"--max_new_tokens_per_section",
type=int,
default=2048,
help="Upper bound for each section decode before forcing a failure.",
)
parser.add_argument("--output_dir", type=str, default="plain_ar_predictions")
parser.add_argument("--output_prefix", type=str, default="")
parser.add_argument("--skip_decode", action="store_true", default=False)
parser.add_argument("--mucodec_device", type=str, default="auto")
parser.add_argument("--mucodec_layer_num", type=int, default=7)
parser.add_argument("--mucodec_duration", type=float, default=40.96)
parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
parser.add_argument("--mucodec_num_steps", type=int, default=20)
parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
if not resume_from_checkpoint:
return model_path
if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint):
print(
"Ignoring --model_path during resume and loading config/model from: "
f"{resume_from_checkpoint}"
)
return resume_from_checkpoint
def get_model_dtype(name: str) -> torch.dtype:
return {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[name]
def resolve_num_audio_token(checkpoint_path: str, explicit_value: int | None) -> int:
if explicit_value is not None:
return int(explicit_value)
config = AutoConfig.from_pretrained(
checkpoint_path,
local_files_only=True,
)
return int(getattr(config, "magel_num_audio_token", 16384))
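# One song section after time-to-frame conversion. `start_frame`/`end_frame` index into the
# MuCodec code sequence (at `fps` frames per second) and `index` is the 1-based occurrence
# count of this structure label within the song, used by `format_section_label`.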
@dataclass
class PreparedSection:
text: str
desc: str
start_frame: int
end_frame: int
structure: str
index: int
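# Right-pads variable-length examples into a batch. Labels are padded with -100 so padded
# positions are ignored by the causal-LM loss.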
class PlainARDataCollator:
def __init__(self, pad_token_id: int = 0):
self.pad_token_id = int(pad_token_id)
def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
return {
"input_ids": pad_sequence(
[row["input_ids"] for row in batch],
batch_first=True,
padding_value=self.pad_token_id,
),
"attention_mask": pad_sequence(
[row["attention_mask"] for row in batch],
batch_first=True,
padding_value=0,
),
"labels": pad_sequence(
[row["labels"] for row in batch],
batch_first=True,
padding_value=-100,
),
}
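# Each example is rendered as a multi-turn chat: one user turn per section (the style prompt
# on the first turn, then "[<section>][desc:...][lyrics:...]") followed by an assistant turn
# containing the SOA/EOA placeholder. The placeholder is then expanded with that section's
# MuCodec codes, shifted by `audio_prefix_length` into the audio token id range.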
class PlainARMusicDataset(torch.utils.data.Dataset):
def __init__(
self,
datasets_obj,
split: str,
tokenizer_path: str,
num_audio_token: int = 16384,
fps: int = 25,
use_fast: bool = True,
):
self._data = datasets_obj[split]
self.tokenizer_path = tokenizer_path
self.num_audio_token = int(num_audio_token)
self.fps = int(fps)
self.use_fast = bool(use_fast)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
local_files_only=True,
use_fast=use_fast,
)
add_audio_special_tokens(self.tokenizer, self.num_audio_token)
self.tokenizer_vocab_size = len(self.tokenizer)
self.audio_prefix_length = int(
self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
)
self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
self.pad_token_id = (
int(self.tokenizer.pad_token_id)
if self.tokenizer.pad_token_id is not None
else 0
)
self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
self._chat_template_kwargs = {"enable_thinking": False}
def __len__(self) -> int:
return len(self._data)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
sample = self.raw_sample(idx)
sections = self.prepare_sections(sample)
token_ids, attention_mask = self.tokenize_messages(
self.build_messages(sample, sections),
sample["mucodec_codes"],
sections,
)
labels = self.build_labels(token_ids)
return {
"input_ids": token_ids,
"attention_mask": attention_mask,
"labels": labels,
}
def raw_sample(self, idx: int) -> dict[str, Any]:
return self._data[idx]
@staticmethod
def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)
@staticmethod
def _sorted_sections(sample: dict[str, Any]) -> list[dict[str, Any]]:
return sorted(
(
{
"raw_index": raw_index,
"text": str(seg["text"]),
"desc": str(seg["desc"]).strip(),
"start": float(seg["start"]),
"end": float(seg["end"]),
"structure": normalize_structure_label(seg["section"]),
}
for raw_index, seg in enumerate(sample.get("sections", []))
),
key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
)
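    # Sort sections by time, convert each end time to a frame index at `self.fps`, clamp the
    # boundaries to be monotonic and within the code sequence, and extend the final section
    # so that every frame of `mucodec_codes` is covered.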
def prepare_sections(self, sample: dict[str, Any]) -> list[PreparedSection]:
sections: list[PreparedSection] = []
section_counts: dict[str, int] = {}
total_frames = len(sample["mucodec_codes"])
prev_end_idx = 0
sample_language = sample.get("language")
for seg in self._sorted_sections(sample):
structure = seg["structure"]
section_counts[structure] = section_counts.get(structure, 0) + 1
raw_end_idx = max(
prev_end_idx,
min(total_frames, int(np.ceil(seg["end"] * self.fps))),
)
sections.append(
PreparedSection(
text=normalize_section_text(
seg["text"],
structure,
language=sample_language,
),
desc=seg["desc"],
start_frame=prev_end_idx,
end_frame=raw_end_idx,
structure=structure,
index=section_counts[structure],
)
)
prev_end_idx = raw_end_idx
if sections:
last = sections[-1]
sections[-1] = PreparedSection(
text=last.text,
desc=last.desc,
start_frame=last.start_frame,
end_frame=total_frames,
structure=last.structure,
index=last.index,
)
return sections
def format_section_label(self, section: PreparedSection) -> str:
label = SECTION_NAME_MAP[section.structure]
if section.structure in SINGLETON_SECTION_NAMES and section.index == 1:
return label
return f"{label} {section.index}"
def build_section_user_content(
self,
sample: dict[str, Any],
section: PreparedSection,
is_first_turn: bool,
) -> str:
parts: list[str] = []
if is_first_turn:
style = str(sample.get("style", "")).strip()
if style:
parts.append(
f"Please generate a song in the following style:{style}\n"
"Next, I will tell you the requirements and lyrics for the song "
"fragment to be generated, section by section."
)
else:
parts.append(
"Please generate the song section by section. "
"Next, I will tell you the requirements and lyrics for each fragment."
)
section_parts = [f"[{self.format_section_label(section)}]"]
if section.desc:
section_parts.append(f"[desc:{section.desc}]")
if section.text:
section_parts.append(f"[lyrics:{section.text}]")
parts.append("".join(section_parts))
return "\n".join(part for part in parts if part)
def build_messages(
self,
sample: dict[str, Any],
sections: list[PreparedSection],
) -> list[dict[str, str]]:
messages: list[dict[str, str]] = []
for idx, section in enumerate(sections):
messages.append(
{
"role": "user",
"content": self.build_section_user_content(
sample=sample,
section=section,
is_first_turn=(idx == 0),
),
}
)
messages.append(
{
"role": "assistant",
"content": self._assistant_audio_placeholder,
}
)
return messages
def tokenize_messages(
self,
messages: list[dict[str, str]],
full_audio_codes,
sections: list[PreparedSection],
) -> tuple[torch.Tensor, torch.Tensor]:
chat_inputs = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=False,
return_tensors="pt",
return_dict=True,
**self._chat_template_kwargs,
)
token_ids = chat_inputs["input_ids"].squeeze(0).to(torch.long)
attention_mask = chat_inputs["attention_mask"].squeeze(0).to(torch.long)
return self.expand_audio_tokens(
token_ids=token_ids,
attention_mask=attention_mask,
full_audio_codes=full_audio_codes,
sections=sections,
)
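    # Splice the real audio codes into the templated token sequence: for every SOA/EOA pair
    # produced by the chat template, insert the section's MuCodec codes (offset by
    # `audio_prefix_length` into the audio vocabulary) between them and extend the attention
    # mask to match.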
def expand_audio_tokens(
self,
token_ids: torch.Tensor,
attention_mask: torch.Tensor,
full_audio_codes,
sections: list[PreparedSection],
) -> tuple[torch.Tensor, torch.Tensor]:
if not sections:
return token_ids, attention_mask
bos_positions = self._positions(token_ids, self.BOS_AUDIO)
eos_positions = self._positions(token_ids, self.EOS_AUDIO)
audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
extra_audio_tokens = sum(
int(section.end_frame) - int(section.start_frame)
for section in sections
)
final_len = token_ids.numel() + extra_audio_tokens
expanded_token_ids = torch.empty(final_len, dtype=torch.long)
expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
read_pos = 0
write_pos = 0
for bos_pos, eos_pos, section in zip(
bos_positions.tolist(),
eos_positions.tolist(),
sections,
):
start_idx = int(section.start_frame)
end_idx = int(section.end_frame)
audio_len = max(0, end_idx - start_idx)
prefix_len = bos_pos + 1 - read_pos
next_write = write_pos + prefix_len
expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
expanded_attention_mask[write_pos:next_write] = attention_mask[
read_pos : bos_pos + 1
]
write_pos = next_write
if audio_len > 0:
next_write = write_pos + audio_len
expanded_token_ids[write_pos:next_write] = audio_code_tensor[
start_idx:end_idx
]
expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
expanded_attention_mask[write_pos:next_write] = 1
write_pos = next_write
expanded_token_ids[write_pos] = token_ids[eos_pos]
expanded_attention_mask[write_pos] = attention_mask[eos_pos]
write_pos += 1
read_pos = eos_pos + 1
tail_len = token_ids.numel() - read_pos
if tail_len > 0:
expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
read_pos:
]
return expanded_token_ids, expanded_attention_mask
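    # Supervise only audio codebook tokens and the [EOA] terminator; all prompt, template,
    # and other special tokens are set to -100 and excluded from the loss.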
def build_labels(self, token_ids: torch.Tensor) -> torch.Tensor:
audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
token_ids < self.MASK_AUDIO
)
eos_mask = token_ids == self.EOS_AUDIO
label_mask = audio_codebook_mask | eos_mask
labels = token_ids.clone()
labels[~label_mask] = -100
return labels
def render_audio_token_string(
self,
audio_token_ids: list[int],
include_bos: bool = True,
include_eos: bool = True,
) -> str:
parts: list[str] = []
if include_bos:
parts.append(SOA_TOKEN)
parts.extend(self.tokenizer.convert_ids_to_tokens(audio_token_ids))
if include_eos:
parts.append(EOA_TOKEN)
return "".join(parts)
def create_plain_qwen3_model(
model_path: str,
model_dtype: torch.dtype,
attn_implementation: str,
target_vocab_size: int,
) -> Qwen3ForCausalLM:
print(f"Loading Qwen3 model from: {model_path}")
config = AutoConfig.from_pretrained(
model_path,
local_files_only=True,
)
model = Qwen3ForCausalLM.from_pretrained(
model_path,
config=config,
torch_dtype=model_dtype,
attn_implementation=attn_implementation,
ignore_mismatched_sizes=True,
local_files_only=True,
)
model.resize_token_embeddings(target_vocab_size)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
return model
def load_plain_qwen3_for_inference(
model_path: str,
device: torch.device,
dtype: torch.dtype,
attn_implementation: str,
target_vocab_size: int,
) -> Qwen3ForCausalLM:
model = create_plain_qwen3_model(
model_path=model_path,
model_dtype=dtype,
attn_implementation=attn_implementation,
target_vocab_size=target_vocab_size,
)
model.to(device=device)
model.eval()
return model
def create_plain_ar_dataset(
dataset_path: str,
split: str,
tokenizer_path: str,
num_audio_token: int,
) -> PlainARMusicDataset:
hf_ds = datasets.load_from_disk(dataset_path)
if isinstance(hf_ds, datasets.DatasetDict):
container = hf_ds
else:
container = {split: hf_ds}
return PlainARMusicDataset(
datasets_obj=container,
split=split,
tokenizer_path=tokenizer_path,
num_audio_token=num_audio_token,
use_fast=True,
)
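# Standard HF Trainer loop with causal-LM loss; checkpoints are saved every epoch and the
# final model plus tokenizer are exported to <output_dir>/final.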
def run_train(args: argparse.Namespace) -> None:
model_dtype = get_model_dtype(args.model_dtype)
model_source = resolve_model_source(
model_path=args.model_path,
resume_from_checkpoint=args.resume_from_checkpoint,
)
num_audio_token = resolve_num_audio_token(model_source, args.num_audio_token)
print(f"Using num_audio_token={num_audio_token}")
train_dataset = create_plain_ar_dataset(
dataset_path=args.dataset_path,
split="train",
tokenizer_path=args.tokenizer_path,
num_audio_token=num_audio_token,
)
print(f"Dataset size: {len(train_dataset)}")
model = create_plain_qwen3_model(
model_path=model_source,
model_dtype=model_dtype,
attn_implementation=args.attn_implementation,
target_vocab_size=train_dataset.tokenizer_vocab_size,
)
training_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
num_train_epochs=args.num_train_epochs,
warmup_steps=args.warmup_steps,
max_grad_norm=args.max_grad_norm,
logging_steps=args.logging_steps,
save_strategy="epoch",
dataloader_num_workers=args.dataloader_num_workers,
bf16=(args.model_dtype == "bfloat16"),
fp16=(args.model_dtype == "float16"),
gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
deepspeed=args.deepspeed,
remove_unused_columns=False,
dataloader_drop_last=True,
report_to=args.report_to,
logging_dir=None,
run_name=args.wandb_run_name,
)
if args.wandb_project and "wandb" in args.report_to:
os.environ["WANDB_PROJECT"] = args.wandb_project
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=PlainARDataCollator(
pad_token_id=train_dataset.pad_token_id,
),
)
if args.resume_from_checkpoint:
print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
else:
print("Starting training from current model initialization.")
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
final_dir = os.path.join(args.output_dir, "final")
trainer.save_model(final_dir)
train_dataset.tokenizer.save_pretrained(final_dir)
print(f"Training complete. Final model saved to: {final_dir}")
def sanitize_generated_section(
generated_ids: list[int],
eos_audio_id: int,
audio_start: int,
audio_end: int,
) -> list[int]:
if not generated_ids:
raise RuntimeError("Generation returned no new tokens for the current section.")
if generated_ids[-1] != eos_audio_id:
raise RuntimeError(
"Section generation did not terminate with [EOA]. "
"Increase --max_new_tokens_per_section or inspect the checkpoint."
)
invalid_ids = [
tid for tid in generated_ids[:-1] if not (audio_start <= tid < audio_end)
]
if invalid_ids:
preview = invalid_ids[:8]
raise RuntimeError(
"Section generation produced non-audio tokens before [EOA]: "
f"{preview}"
)
audio_ids = generated_ids[:-1]
return audio_ids
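# Section-wise autoregressive decoding: for each section, append its user prompt, seed the
# assistant turn with SOA_TOKEN, and let the model generate until it emits [EOA] (otherwise
# sanitize_generated_section raises). The assistant turn is then rewritten with the full
# generated audio-token string so later sections condition on the audio of earlier ones.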
@torch.inference_mode()
def generate_sections_autoregressively(
model: Qwen3ForCausalLM,
music_ds: PlainARMusicDataset,
sample: dict[str, Any],
device: torch.device,
args: argparse.Namespace,
) -> tuple[list[list[int]], list[dict[str, Any]]]:
messages: list[dict[str, str]] = []
sections = music_ds.prepare_sections(sample)
section_records: list[dict[str, Any]] = []
use_cache = args.use_cache and not args.no_cache
eos_token_id = music_ds.EOS_AUDIO
pad_token_id = (
int(music_ds.tokenizer.eos_token_id)
if music_ds.tokenizer.eos_token_id is not None
else music_ds.pad_token_id
)
all_section_audio_ids: list[list[int]] = []
for section_idx, section in enumerate(sections):
user_content = music_ds.build_section_user_content(
sample=sample,
section=section,
is_first_turn=(section_idx == 0),
)
messages.append({"role": "user", "content": user_content})
messages.append({"role": "assistant", "content": SOA_TOKEN})
chat_inputs = music_ds.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=False,
return_tensors="pt",
return_dict=True,
**music_ds._chat_template_kwargs,
)
input_ids = chat_inputs["input_ids"].to(device)
attention_mask = chat_inputs["attention_mask"].to(device)
generated = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
do_sample=not bool(args.greedy),
temperature=float(args.temperature),
top_k=int(args.top_k),
top_p=float(args.top_p),
max_new_tokens=int(args.max_new_tokens_per_section),
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
use_cache=use_cache,
)
new_ids = generated[0, input_ids.shape[1] :].tolist()
audio_ids = sanitize_generated_section(
generated_ids=new_ids,
eos_audio_id=music_ds.EOS_AUDIO,
audio_start=music_ds.audio_prefix_length,
audio_end=music_ds.MASK_AUDIO,
)
all_section_audio_ids.append(audio_ids)
messages[-1]["content"] = music_ds.render_audio_token_string(
audio_token_ids=audio_ids,
include_bos=True,
include_eos=True,
)
section_records.append(
{
"section_index": section_idx,
"section_label": music_ds.format_section_label(section),
"desc": section.desc,
"lyrics": section.text,
"generated_audio_tokens": len(audio_ids),
}
)
print(
f"[INFO] section={section_idx} "
f"label={music_ds.format_section_label(section)!r} "
f"generated_audio_tokens={len(audio_ids)}"
)
return all_section_audio_ids, section_records
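# Writes a JSON summary of the generated sections and, unless --skip_decode is set, shifts
# the generated ids back by `audio_prefix_length` to raw MuCodec codes and decodes them to
# a waveform with the MuCodec decoder.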
def save_inference_outputs(
output_dir: str,
output_prefix: str,
sample_idx: int,
sample: dict[str, Any],
section_audio_ids: list[list[int]],
section_records: list[dict[str, Any]],
music_ds: PlainARMusicDataset,
args: argparse.Namespace,
) -> None:
Path(output_dir).mkdir(parents=True, exist_ok=True)
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    song_id = str(sample.get("song_id", f"sample_{sample_idx}"))
    prefix = output_prefix or f"{song_id}_{sample_idx}_{stamp}"
json_path = Path(output_dir) / f"{prefix}.json"
wav_path = Path(output_dir) / f"{prefix}.wav"
flat_token_ids: list[int] = []
for section_ids in section_audio_ids:
flat_token_ids.extend(section_ids)
payload = {
"song_id": str(sample.get("song_id", f"sample_{sample_idx}")),
"sample_idx": int(sample_idx),
"num_sections": len(section_audio_ids),
"generated_audio_tokens": len(flat_token_ids),
"sections": section_records,
}
    import json

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
print(f"[OK] {json_path}")
if args.skip_decode:
return
shifted_codes = np.asarray(flat_token_ids, dtype=np.int64) - music_ds.audio_prefix_length
if shifted_codes.size == 0:
print("[WARN] No generated MuCodec tokens; skipping wav decode.")
return
import torchaudio
mucodec_decoder = build_mucodec_decoder(args)
wave = decode_mucodec_codes(
mucodec_decoder=mucodec_decoder,
shifted_codes=shifted_codes,
args=args,
)
torchaudio.save(str(wav_path), wave, int(args.mucodec_sample_rate))
print(f"[OK] {wav_path}")
def run_infer(args: argparse.Namespace) -> None:
seed_everything(args.seed)
tokenizer_path = args.tokenizer_path or args.model_path
num_audio_token = resolve_num_audio_token(args.model_path, args.num_audio_token)
print(f"Using num_audio_token={num_audio_token}")
device = resolve_device(args.device)
dtype = get_model_dtype(args.dtype)
if device.type == "cpu" and dtype != torch.float32:
print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
dtype = torch.float32
music_ds = create_plain_ar_dataset(
dataset_path=args.dataset_path,
split=args.split,
tokenizer_path=tokenizer_path,
num_audio_token=num_audio_token,
)
sample = music_ds.raw_sample(args.sample_idx)
model = load_plain_qwen3_for_inference(
model_path=args.model_path,
device=device,
dtype=dtype,
attn_implementation=args.attn_implementation,
target_vocab_size=music_ds.tokenizer_vocab_size,
)
section_audio_ids, section_records = generate_sections_autoregressively(
model=model,
music_ds=music_ds,
sample=sample,
device=device,
args=args,
)
save_inference_outputs(
output_dir=args.output_dir,
output_prefix=args.output_prefix,
sample_idx=args.sample_idx,
sample=sample,
section_audio_ids=section_audio_ids,
section_records=section_records,
music_ds=music_ds,
args=args,
)
def main() -> None:
args = parse_args()
if args.command == "train":
run_train(args)
return
if args.command == "infer":
run_infer(args)
return
raise ValueError(f"Unknown command: {args.command}")
if __name__ == "__main__":
main()